From 410654525f36e95aebb52462c3ec9bb25826523c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 21 Jul 2023 12:41:08 +0200 Subject: [PATCH] Refactor the reduce ops in order to introduce argmin/argmax. (#212) * Refactor the reduce ops in order to introduce argmin/argmax. * Clippy fixes. * Use the newly introduced argmax. * Fix the strided case. * Handle the non-contiguous case. --- candle-core/src/backprop.rs | 6 + candle-core/src/cpu_backend.rs | 221 ++++++++++++++---- candle-core/src/cuda_backend.rs | 2 + candle-core/src/error.rs | 3 + candle-core/src/op.rs | 14 ++ candle-core/src/tensor.rs | 76 +++--- .../examples/simple-training/main.rs | 29 +-- 7 files changed, 241 insertions(+), 110 deletions(-) diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 4afaf23b46..62cbc488c8 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -304,6 +304,12 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } + Op::Reduce(_, ReduceOp::ArgMin, _) => { + Err(Error::BackwardNotSupported { op: "argmin" })? + } + Op::Reduce(_, ReduceOp::ArgMax, _) => { + Err(Error::BackwardNotSupported { op: "argmax" })? + } Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?, Op::Reshape(arg) => { let arg_grad = grad.reshape(arg.dims())?; diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index b7060f5031..7901a7dae8 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -33,6 +33,26 @@ trait Map1 { } } +trait Map1Any { + fn f) -> CpuStorage>( + &self, + vs: &[T], + layout: &Layout, + wrap: W, + ) -> Result; + + fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { + match vs { + CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?), + CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?), + CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?), + CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?), + CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?), + CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?), + } + } +} + type C = CpuStorage; trait Map2 { const OP: &'static str; @@ -144,11 +164,118 @@ impl<'a> Map2 for WCond<'a> { } } +struct ReduceIndex { + reduce_dim_index: usize, + use_min: bool, + return_index: bool, +} + +impl ReduceIndex { + // The value gets replaced if f(s[current_acc], s[i]) returns true. + #[inline(always)] + fn fold_impl(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result> + where + T: Clone + Copy, + U: Clone + Copy, + F: Fn(T, T) -> bool, + G: Fn(T, usize) -> U, + { + let reduce_dim_size = src_l.dims()[self.reduce_dim_index]; + let reduce_dim_stride = src_l.stride()[self.reduce_dim_index]; + let dst_len = src_l.shape().elem_count() / reduce_dim_size; + let mut dst: Vec = Vec::with_capacity(dst_len); + let dst_to_set = dst.spare_capacity_mut(); + let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) }; + match src_l.contiguous_offsets() { + Some((o1, o2)) => { + let src = &src[o1..o2]; + if reduce_dim_stride == 1 { + for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { + let start_src_i = start_src_i * reduce_dim_size; + let src = &src[start_src_i..start_src_i + reduce_dim_size]; + let mut acc = 0; + let mut val = src[0]; + for (src_i, &s) in src.iter().enumerate() { + if f(val, s) { + acc = src_i; + val = s + } + } + *dst_v = g(val, acc) + } + } else { + for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { + let (p, q) = ( + start_src_i / reduce_dim_stride, + start_src_i % reduce_dim_stride, + ); + // start_src_i = p * reduce_dim_stride + q + let start_src_i = p * reduce_dim_stride * reduce_dim_size + q; + let src = &src[start_src_i..]; + let mut acc = 0; + let mut val = src[0]; + for src_i in 0..reduce_dim_size { + let s = src[src_i * reduce_dim_stride]; + if f(val, s) { + acc = src_i; + val = s + } + } + *dst_v = g(val, acc) + } + } + } + None => { + let l = src_l.narrow(self.reduce_dim_index, 0, 1)?; + for (unstr_index, src_index) in l.strided_index().enumerate() { + let src = &src[src_index..]; + let mut acc = 0; + let mut val = src[0]; + for src_i in 0..reduce_dim_size { + let s = src[src_i * reduce_dim_stride]; + if f(val, s) { + acc = src_i; + val = s + } + } + dst[unstr_index] = g(val, acc) + } + } + } + unsafe { dst.set_len(dst_len) }; + Ok(dst) + } +} + +impl Map1Any for ReduceIndex { + #[inline(always)] + fn f) -> CpuStorage>( + &self, + src: &[T], + src_l: &Layout, + wrap: W, + ) -> Result { + if src_l.shape().elem_count() == 0 { + Err(Error::EmptyTensor { op: "reduce" }.bt())? + } + let dst = match (self.return_index, self.use_min) { + (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?), + (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?), + (true, true) => { + CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?) + } + (true, false) => { + CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?) + } + }; + Ok(dst) + } +} + struct Reduce<'a> { dst_shape: &'a Shape, reduce_dims: &'a [usize], reduce_dims_and_stride: Vec<(usize, usize)>, - op: ReduceOp, } impl<'a> Reduce<'a> { @@ -217,25 +344,7 @@ impl<'a> Reduce<'a> { impl<'a> Map1 for Reduce<'a> { #[inline(always)] fn f(&self, src: &[T], src_l: &Layout) -> Result> { - match self.op { - ReduceOp::Min => { - let s = if src_l.shape().elem_count() != 0 { - src[src_l.start_offset()] - } else { - Err(Error::EmptyTensor { op: "min" }.bt())? - }; - self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y }) - } - ReduceOp::Max => { - let s = if src_l.shape().elem_count() != 0 { - src[src_l.start_offset()] - } else { - Err(Error::EmptyTensor { op: "max" }.bt())? - }; - self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y }) - } - ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y), - } + self.fold_impl(src, src_l, T::zero(), |x, y| x + y) } } @@ -1144,27 +1253,59 @@ impl BackendStorage for CpuStorage { } fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result { - let src_dims = layout.dims(); - let mut dst_dims = src_dims.to_vec(); - for &dim in reduce_dims.iter() { - dst_dims[dim] = 1; - } - let dst_shape = Shape::from(dst_dims); - let mut reduce_dims = reduce_dims.to_vec(); - // Sort the reduce_dims as they have to be processed from left to right when converting the - // indexes. - reduce_dims.sort(); - let reduce_dims_and_stride: Vec<_> = reduce_dims - .iter() - .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::())) - .collect(); - Reduce { - dst_shape: &dst_shape, - reduce_dims: &reduce_dims, - reduce_dims_and_stride, - op, + match op { + ReduceOp::Sum => { + let src_dims = layout.dims(); + let mut dst_dims = src_dims.to_vec(); + for &dim in reduce_dims.iter() { + dst_dims[dim] = 1; + } + let dst_shape = Shape::from(dst_dims); + let mut reduce_dims = reduce_dims.to_vec(); + // Sort the reduce_dims as they have to be processed from left to right when converting the + // indexes. + reduce_dims.sort(); + let reduce_dims_and_stride: Vec<_> = reduce_dims + .iter() + .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::())) + .collect(); + Reduce { + dst_shape: &dst_shape, + reduce_dims: &reduce_dims, + reduce_dims_and_stride, + } + .map(self, layout) + } + ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => { + let reduce_dim_index = match reduce_dims { + [reduce_dim_index] => *reduce_dim_index, + _ => { + let op = match op { + ReduceOp::Min => "min", + ReduceOp::ArgMin => "argmin", + ReduceOp::Max => "max", + ReduceOp::ArgMax => "argmax", + _ => unreachable!(), + }; + let dims = reduce_dims.to_vec(); + Err(Error::OnlySingleDimension { op, dims })? + } + }; + let (use_min, return_index) = match op { + ReduceOp::Min => (true, false), + ReduceOp::ArgMin => (true, true), + ReduceOp::Max => (false, false), + ReduceOp::ArgMax => (false, true), + _ => unreachable!(), + }; + ReduceIndex { + reduce_dim_index, + use_min, + return_index, + } + .map(self, layout) + } } - .map(self, layout) } fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index e40f5f7109..cdbfd0c6d7 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> { ReduceOp::Sum => "fast_sum", ReduceOp::Min => "fast_min", ReduceOp::Max => "fast_max", + ReduceOp::ArgMin => "fast_argmin", + ReduceOp::ArgMax => "fast_argmax", }; let func = dev.get_or_load_func(&kernel_name::(name), kernels::REDUCE)?; // SAFETY: filled in by the follow up kernel. diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index acbe28d3fb..23f2642d7c 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -79,6 +79,9 @@ pub enum Error { nth_shape: Shape, }, + #[error("{op} can only be performed on a single dimension")] + OnlySingleDimension { op: &'static str, dims: Vec }, + #[error("empty tensor for {op}")] EmptyTensor { op: &'static str }, diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 4686e57ef8..226cff4137 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -17,6 +17,20 @@ pub enum ReduceOp { Sum, Min, Max, + ArgMin, + ArgMax, +} + +impl ReduceOp { + pub(crate) fn name(&self) -> &'static str { + match self { + Self::ArgMax => "argmax", + Self::ArgMin => "argmin", + Self::Min => "min", + Self::Max => "max", + Self::Sum => "sum", + } + } } // These ops return the same type as their input type. diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f72404dff0..42d660f442 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -628,47 +628,21 @@ impl Tensor { } } - fn max_impl(&self, max_dims: D, keepdim: bool) -> Result { - let max_dims = max_dims.to_indexes(self.shape(), "max")?; - let storage = self - .storage() - .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?; - let mut dims = self.dims().to_vec(); - for &max_dim in max_dims.iter() { - dims[max_dim] = 1 - } - let op = if self.track_op() { - Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec())) - } else { - None - }; - let max = from_storage(storage, dims, op, false); - if keepdim { - Ok(max) - } else { - max.squeeze_dims(&max_dims) - } - } - - fn min_impl(&self, min_dims: D, keepdim: bool) -> Result { - let min_dims = min_dims.to_indexes(self.shape(), "min")?; - let storage = self - .storage() - .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?; + fn reduce_impl(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result { + let dim = dim.to_index(self.shape(), op.name())?; + let storage = self.storage().reduce_op(op, self.layout(), &[dim])?; let mut dims = self.dims().to_vec(); - for &min_dim in min_dims.iter() { - dims[min_dim] = 1 - } + dims[dim] = 1; let op = if self.track_op() { - Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec())) + Some(Op::Reduce(self.clone(), op, dims.to_vec())) } else { None }; - let min = from_storage(storage, dims, op, false); + let res = from_storage(storage, dims, op, false); if keepdim { - Ok(min) + Ok(res) } else { - min.squeeze_dims(&min_dims) + res.squeeze_dims(&[dim]) } } @@ -722,30 +696,36 @@ impl Tensor { self.sum_impl(sum_dims, false) } - pub fn max_keepdim(&self, max_dims: D) -> Result { - self.max_impl(max_dims, true) + pub fn max_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::Max) } - pub fn max(&self, max_dims: D) -> Result { - self.max_impl(max_dims, false) + pub fn max(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::Max) } - pub fn max_all(&self) -> Result { - let dims: Vec<_> = (0..self.rank()).collect(); - self.max(dims) + pub fn min_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::Min) } - pub fn min_keepdim(&self, min_dims: D) -> Result { - self.min_impl(min_dims, true) + pub fn min(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::Min) } - pub fn min(&self, min_dims: D) -> Result { - self.min_impl(min_dims, false) + pub fn argmax_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::ArgMax) } - pub fn min_all(&self) -> Result { - let dims: Vec<_> = (0..self.rank()).collect(); - self.min(dims) + pub fn argmax(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::ArgMax) + } + + pub fn argmin_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::ArgMin) + } + + pub fn argmin(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::ArgMin) } pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result { diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs index 767266f63d..ea2dc0cd1b 100644 --- a/candle-examples/examples/simple-training/main.rs +++ b/candle-examples/examples/simple-training/main.rs @@ -42,7 +42,7 @@ pub fn main() -> Result<()> { let bs = Var::zeros(LABELS, DType::F32, &dev)?; let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0); let test_images = m.test_images; - let test_labels = m.test_labels.to_vec1::()?; + let test_labels = m.test_labels.to_dtype(DType::U32)?; for epoch in 1..200 { let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?; let log_sm = log_softmax(&logits, D::Minus1)?; @@ -52,28 +52,13 @@ pub fn main() -> Result<()> { sgd.backward_step(&loss)?; let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?; - /* TODO: Add argmax so that the following can be computed within candle. - let test_accuracy = test_logits - .argmax(Some(-1), false) - .eq_tensor(&test_labels) - .to_kind(Kind::Float) - .mean(Kind::Float) - .double_value(&[]); - */ - let test_logits = test_logits.to_vec2::()?; let sum_ok = test_logits - .iter() - .zip(test_labels.iter()) - .map(|(logits, label)| { - let arg_max = logits - .iter() - .enumerate() - .max_by(|(_, v1), (_, v2)| v1.total_cmp(v2)) - .map(|(idx, _)| idx); - f64::from(arg_max == Some(*label as usize)) - }) - .sum::(); - let test_accuracy = sum_ok / test_labels.len() as f64; + .argmax(D::Minus1)? + .eq(&test_labels)? + .to_dtype(DType::F32)? + .sum_all()? + .to_scalar::()?; + let test_accuracy = sum_ok / test_labels.shape().r1()? as f32; println!( "{epoch:4} train loss: {:8.5} test acc: {:5.2}%", loss.to_scalar::()?,