Skip to content

Commit

Permalink
Use index-add in the backprop.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 21, 2023
1 parent 6e74d30 commit 98d1db0
Showing 1 changed file with 1 addition and 19 deletions.
20 changes: 1 addition & 19 deletions candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,26 +164,8 @@ impl Tensor {
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let dim = *dim;
let sum_grad = grads.or_insert(arg)?;
// TODO: This is very very very inefficient, have some dedicated kernel for this.
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
let indexes = indexes.to_vec1::<u32>()?;
for (dst_index, src_index) in indexes.iter().enumerate() {
let src_index = *src_index as usize;
let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
let mut pre_dims = arg.dims().to_vec();
pre_dims[dim] = src_index;
let pre_zeros =
Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
let mut post_dims = arg.dims().to_vec();
post_dims[dim] = post_dims[dim] - src_index - 1;
let post_zeros =
Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
let src_grad =
Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
*sum_grad = sum_grad.add(&src_grad)?;
}
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
}
Op::Embedding(_lhs, _rhs) => {
Err(Error::BackwardNotSupported { op: "embedding" })?
Expand Down

0 comments on commit 98d1db0

Please sign in to comment.