Skip to content

Commit

Permalink
Start adding gather.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 22, 2023
1 parent 6eeea1b commit e7432fb
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;

fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl Tensor {
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
Expand Down Expand Up @@ -162,6 +163,7 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Gather(..) => Err(Error::BackwardNotSupported { op: "gather" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
Expand Down
5 changes: 5 additions & 0 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,11 @@ impl BackendStorage for CpuStorage {
IndexSelect { ids, ids_l, dim }.map(self, l)
}

fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let ids = ids.as_slice::<u32>()?;
IndexSelect { ids, ids_l, dim }.map(self, l)
}

fn index_add(
&self,
l: &Layout,
Expand Down
3 changes: 3 additions & 0 deletions candle-core/src/cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,9 @@ impl BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-select").into())
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement gather").into())
}
fn index_add(
&self,
_: &Layout,
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn index_add(
&self,
_: &Layout,
Expand Down
1 change: 1 addition & 0 deletions candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub(crate) enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
Expand Down
21 changes: 21 additions & 0 deletions candle-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,27 @@ impl Storage {
}
}

pub(crate) fn gather(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
match (self, indexes) {
(Self::Cpu(s), Self::Cpu(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}

pub(crate) fn index_add(
&self,
l: &Layout,
Expand Down
13 changes: 13 additions & 0 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,19 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}

pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let storage =
self.storage()
.gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::Gather(self.clone(), indexes.clone(), dim))
} else {
None
};
Ok(from_storage(storage, indexes.shape(), op, false))
}

pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {
Expand Down

0 comments on commit e7432fb

Please sign in to comment.