Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
LaurentMazare authored Jul 21, 2023
1 parent c60831a commit 4106545
Showing 7 changed files with 241 additions and 110 deletions.
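As context for the diff below, here is a minimal sketch of the user-visible behavior the new ops enable. The `argmin`/`argmax` method names and the `candle_core` crate path are assumptions for illustration; this commit only shows the backend plumbing:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[[1f32, 3., 2.], [4., 0., 5.]], &Device::Cpu)?;
    // Index-returning reductions produce u32 indices whatever the input dtype.
    let amax = t.argmax(1)?; // expected: [1, 2]
    let amin = t.argmin(1)?; // expected: [0, 1]
    println!("{amax}\n{amin}");
    Ok(())
}
```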
6 changes: 6 additions & 0 deletions candle-core/src/backprop.rs
@@ -304,6 +304,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {
Err(Error::BackwardNotSupported { op: "argmin" })?
}
Op::Reduce(_, ReduceOp::ArgMax, _) => {
Err(Error::BackwardNotSupported { op: "argmax" })?
}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
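The two new arms make non-differentiability explicit: an argmin/argmax output is a discrete index, constant almost everywhere in the input, so no useful gradient exists and the backward pass fails loudly rather than silently. A sketch of what a caller would see (same assumed wrappers as above):

```rust
use candle_core::{Device, Tensor};

fn no_grad_through_argmax() -> candle_core::Result<()> {
    let t = Tensor::new(&[[1f32, 3., 2.]], &Device::Cpu)?;
    let idx = t.argmax(1)?; // u32 indices, piecewise constant in t
    // Expected: Err(BackwardNotSupported { op: "argmax" })
    assert!(idx.backward().is_err());
    Ok(())
}
```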
221 changes: 181 additions & 40 deletions candle-core/src/cpu_backend.rs
@@ -33,6 +33,26 @@ trait Map1 {
}
}

trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
&self,
vs: &[T],
layout: &Layout,
wrap: W,
) -> Result<CpuStorage>;

fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
}
}
}
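Where `Map1` fixes the output element type to the input's, `Map1Any` lets the implementation choose: value-returning reductions (min/max) pass their `Vec<T>` back through `wrap` to preserve the input dtype, while index-returning ones (argmin/argmax) can ignore `wrap` and build a `CpuStorage::U32` directly. A standalone toy model of the pattern, not part of the diff:

```rust
// Toy model: one entry point, two possible output dtypes.
enum Storage {
    F32(Vec<f32>),
    U32(Vec<u32>),
}

fn max_or_argmax(vs: &[f32], return_index: bool) -> Storage {
    // Strict `>` keeps the first occurrence on ties (empty input elided here).
    let mut best = (0usize, vs[0]);
    for (i, &v) in vs.iter().enumerate() {
        if v > best.1 {
            best = (i, v);
        }
    }
    if return_index {
        Storage::U32(vec![best.0 as u32]) // dtype pinned to u32, like ArgMax
    } else {
        Storage::F32(vec![best.1]) // dtype follows the input, like Max
    }
}
```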

type C = CpuStorage;
trait Map2 {
const OP: &'static str;
@@ -144,11 +164,118 @@ impl<'a> Map2 for WCond<'a> {
}
}

struct ReduceIndex {
reduce_dim_index: usize,
use_min: bool,
return_index: bool,
}

impl ReduceIndex {
// The value gets replaced if f(s[current_acc], s[i]) returns true.
#[inline(always)]
fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
where
T: Clone + Copy,
U: Clone + Copy,
F: Fn(T, T) -> bool,
G: Fn(T, usize) -> U,
{
let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
if reduce_dim_stride == 1 {
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
let start_src_i = start_src_i * reduce_dim_size;
let src = &src[start_src_i..start_src_i + reduce_dim_size];
let mut acc = 0;
let mut val = src[0];
for (src_i, &s) in src.iter().enumerate() {
if f(val, s) {
acc = src_i;
val = s
}
}
*dst_v = g(val, acc)
}
} else {
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
let (p, q) = (
start_src_i / reduce_dim_stride,
start_src_i % reduce_dim_stride,
);
// start_src_i = p * reduce_dim_stride + q
let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
let src = &src[start_src_i..];
let mut acc = 0;
let mut val = src[0];
for src_i in 0..reduce_dim_size {
let s = src[src_i * reduce_dim_stride];
if f(val, s) {
acc = src_i;
val = s
}
}
*dst_v = g(val, acc)
}
}
}
None => {
let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
for (unstr_index, src_index) in l.strided_index().enumerate() {
let src = &src[src_index..];
let mut acc = 0;
let mut val = src[0];
for src_i in 0..reduce_dim_size {
let s = src[src_i * reduce_dim_stride];
if f(val, s) {
acc = src_i;
val = s
}
}
dst_to_set[unstr_index] = g(val, acc) // write via the spare-capacity slice; dst.len() is still 0 here
}
}
}
unsafe { dst.set_len(dst_len) };
Ok(dst)
}
}
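In the strided branch, the `p`/`q` decomposition converts a destination index into the first source offset of its reduction lane: the `q` coordinates vary faster than the reduced dimension and keep their stride, while the `p` coordinates must additionally skip the whole reduced extent. A worked check of that arithmetic:

```rust
// Shape [2, 3, 4], row-major, reducing dim 1: stride = 4, size = 3.
// dst index 5 is (b, w) = (1, 1); its lane starts at
// p * stride * size + q = 1 * 4 * 3 + 1 = 13 and then steps by 4: {13, 17, 21}.
fn lane_start(dst_i: usize, stride: usize, size: usize) -> usize {
    let (p, q) = (dst_i / stride, dst_i % stride);
    p * stride * size + q
}

fn main() {
    assert_eq!(lane_start(5, 4, 3), 13);
    assert_eq!(lane_start(4, 4, 3), 12);
    assert_eq!(lane_start(0, 4, 3), 0);
}
```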

impl Map1Any for ReduceIndex {
#[inline(always)]
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
&self,
src: &[T],
src_l: &Layout,
wrap: W,
) -> Result<CpuStorage> {
if src_l.shape().elem_count() == 0 {
Err(Error::EmptyTensor { op: "reduce" }.bt())?
}
let dst = match (self.return_index, self.use_min) {
(false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
(false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
(true, true) => {
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
}
(true, false) => {
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
}
};
Ok(dst)
}
}
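Note the strict comparisons: for `ArgMax`, the accumulator only moves when `val < s`, so among equal values the smallest index along the reduced dimension wins (and symmetrically for `ArgMin` with `>`). With the assumed wrapper:

```rust
use candle_core::{Device, Tensor};

fn ties_keep_first() -> candle_core::Result<()> {
    let t = Tensor::new(&[1f32, 5., 5., 2.], &Device::Cpu)?;
    // Positions 1 and 2 both hold the maximum; strict comparison keeps index 1.
    assert_eq!(t.argmax(0)?.to_scalar::<u32>()?, 1);
    Ok(())
}
```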

struct Reduce<'a> {
dst_shape: &'a Shape,
reduce_dims: &'a [usize],
reduce_dims_and_stride: Vec<(usize, usize)>,
op: ReduceOp,
}

impl<'a> Reduce<'a> {
@@ -217,25 +344,7 @@ impl<'a> Reduce<'a> {
impl<'a> Map1 for Reduce<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
match self.op {
ReduceOp::Min => {
let s = if src_l.shape().elem_count() != 0 {
src[src_l.start_offset()]
} else {
Err(Error::EmptyTensor { op: "min" }.bt())?
};
self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
}
ReduceOp::Max => {
let s = if src_l.shape().elem_count() != 0 {
src[src_l.start_offset()]
} else {
Err(Error::EmptyTensor { op: "max" }.bt())?
};
self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
}
ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
}
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
}
}

@@ -1144,27 +1253,59 @@ impl BackendStorage for CpuStorage {
}

fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
let src_dims = layout.dims();
let mut dst_dims = src_dims.to_vec();
for &dim in reduce_dims.iter() {
dst_dims[dim] = 1;
}
let dst_shape = Shape::from(dst_dims);
let mut reduce_dims = reduce_dims.to_vec();
// Sort the reduce_dims as they have to be processed from left to right when converting the
// indexes.
reduce_dims.sort();
let reduce_dims_and_stride: Vec<_> = reduce_dims
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
Reduce {
dst_shape: &dst_shape,
reduce_dims: &reduce_dims,
reduce_dims_and_stride,
op,
match op {
ReduceOp::Sum => {
let src_dims = layout.dims();
let mut dst_dims = src_dims.to_vec();
for &dim in reduce_dims.iter() {
dst_dims[dim] = 1;
}
let dst_shape = Shape::from(dst_dims);
let mut reduce_dims = reduce_dims.to_vec();
// Sort the reduce_dims as they have to be processed from left to right when converting the
// indexes.
reduce_dims.sort();
let reduce_dims_and_stride: Vec<_> = reduce_dims
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
Reduce {
dst_shape: &dst_shape,
reduce_dims: &reduce_dims,
reduce_dims_and_stride,
}
.map(self, layout)
}
ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
let reduce_dim_index = match reduce_dims {
[reduce_dim_index] => *reduce_dim_index,
_ => {
let op = match op {
ReduceOp::Min => "min",
ReduceOp::ArgMin => "argmin",
ReduceOp::Max => "max",
ReduceOp::ArgMax => "argmax",
_ => unreachable!(),
};
let dims = reduce_dims.to_vec();
Err(Error::OnlySingleDimension { op, dims })?
}
};
let (use_min, return_index) = match op {
ReduceOp::Min => (true, false),
ReduceOp::ArgMin => (true, true),
ReduceOp::Max => (false, false),
ReduceOp::ArgMax => (false, true),
_ => unreachable!(),
};
ReduceIndex {
reduce_dim_index,
use_min,
return_index,
}
.map(self, layout)
}
}
.map(self, layout)
}
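The dispatch preserves the old multi-dimension behavior for `Sum` while routing the four comparison-based ops through `ReduceIndex`, which handles exactly one dimension; a multi-dim request on those now returns `Error::OnlySingleDimension` instead of a wrong result. A sketch of the resulting contract (wrapper names and signatures are assumptions):

```rust
use candle_core::{DType, Device, Tensor};

fn reduce_contract() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    let _s = t.sum(&[0, 2])?; // Reduce path: several dims at once is fine
    let _m = t.max(1)?;       // ReduceIndex path: restricted to a single dim
    // Asking the storage-level op for max/argmax over two dims at once would
    // now surface Error::OnlySingleDimension.
    Ok(())
}
```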

fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
2 changes: 2 additions & 0 deletions candle-core/src/cuda_backend.rs
@@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> {
ReduceOp::Sum => "fast_sum",
ReduceOp::Min => "fast_min",
ReduceOp::Max => "fast_max",
ReduceOp::ArgMin => "fast_argmin",
ReduceOp::ArgMax => "fast_argmax",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow-up kernel.
3 changes: 3 additions & 0 deletions candle-core/src/error.rs
@@ -79,6 +79,9 @@ pub enum Error {
nth_shape: Shape,
},

#[error("{op} can only be performed on a single dimension")]
OnlySingleDimension { op: &'static str, dims: Vec<usize> },

#[error("empty tensor for {op}")]
EmptyTensor { op: &'static str },

14 changes: 14 additions & 0 deletions candle-core/src/op.rs
@@ -17,6 +17,20 @@ pub enum ReduceOp {
Sum,
Min,
Max,
ArgMin,
ArgMax,
}

impl ReduceOp {
pub(crate) fn name(&self) -> &'static str {
match self {
Self::ArgMax => "argmax",
Self::ArgMin => "argmin",
Self::Min => "min",
Self::Max => "max",
Self::Sum => "sum",
}
}
}
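The `name()` helper gives each variant a canonical string, so backends can report errors without hand-rolled matches like the one in `cpu_backend`'s `reduce_op` above. Hypothetical use, not part of this diff:

```rust
// Hypothetical helper: build the single-dimension error from the op's name.
fn only_single_dim(op: ReduceOp, dims: &[usize]) -> Error {
    Error::OnlySingleDimension {
        op: op.name(),
        dims: dims.to_vec(),
    }
}
```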

// These ops return the same type as their input type.
