Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
LaurentMazare authored Jul 24, 2023
1 parent 581b104 commit 74a6a76
Showing 3 changed files with 255 additions and 30 deletions.
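For context, this commit wires the CUDA backend into the tensor-level `index_add` and `scatter_add` ops. A minimal usage sketch, not part of this diff — assuming the current `candle_core` crate layout and that the tensor-level wrappers take `(ids, src, dim)` in that order:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Accumulate two source rows into rows 0 and 2 of a 4x2 destination.
    let dst = Tensor::zeros((4, 2), DType::F32, &dev)?;
    let src = Tensor::new(&[[1f32, 1.], [2., 2.]], &dev)?;
    let ids = Tensor::new(&[0u32, 2], &dev)?;
    let out = dst.index_add(&ids, &src, 0)?; // out-of-place: dst is unchanged
    println!("{out}");
    Ok(())
}

`scatter_add` follows the same pattern, except `ids` must have the same shape as `src` (see the reference sketches further down).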
182 changes: 153 additions & 29 deletions candle-core/src/cuda_backend.rs
@@ -398,12 +398,42 @@ trait Map2 {
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}

trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;

fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
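`Map2InPlace` mirrors the existing `Map2` trait, but mutates an existing destination slice instead of allocating a new output — which is what accumulation kernels need. A toy, self-contained sketch of the same dispatch pattern (illustrative names, plain `Vec`s standing in for `CudaSlice`; not part of this diff):

enum Slice {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait MapInPlace {
    // Monomorphized per element type, like `f` above.
    fn f<T: Copy + std::ops::AddAssign>(&self, dst: &mut [T], src: &[T]);

    // Dispatches on the runtime dtype pair, like `map` above.
    fn map(&self, dst: &mut Slice, src: &Slice) -> Result<(), &'static str> {
        match (dst, src) {
            (Slice::F32(d), Slice::F32(s)) => Ok(self.f(d, s)),
            (Slice::F64(d), Slice::F64(s)) => Ok(self.f(d, s)),
            _ => Err("dtype mismatch"),
        }
    }
}

struct Add;
impl MapInPlace for Add {
    fn f<T: Copy + std::ops::AddAssign>(&self, dst: &mut [T], src: &[T]) {
        for (d, s) in dst.iter_mut().zip(src) {
            *d += *s;
        }
    }
}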

trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
@@ -651,7 +681,7 @@ impl<'a> Map1 for Embedding<'a> {
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
@@ -696,7 +726,7 @@ impl<'a> Map1 for IndexSelect<'a> {
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = src_l.dims()[self.2];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = (
@@ -752,7 +782,7 @@ impl<'a> Map1 for Gather<'a> {
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[dim];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
@@ -764,6 +794,97 @@ impl<'a> Map1 for Gather<'a> {
}
}

struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for IndexAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let ids_dim_sz = ids_l.dims()[0];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
let params = (
ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}
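The params above describe both tensors flattened into `(left_sz, dim, right_sz)` blocks. A CPU reference for what the `ia_*` kernels compute — a sketch, assuming the kernel parallelizes over the `left_sz * right_sz` positions and keeps the loop over `ids` serial (so no two threads write the same `dst` element and no atomics are needed); not part of this diff:

// dst[.., ids[j], ..] += src[.., j, ..] for each j, over contiguous buffers.
fn index_add_ref(
    ids: &[u32],     // src's size along `dim` must equal ids.len()
    src: &[f32],     // left_sz * ids.len() * right_sz elements
    dst: &mut [f32], // left_sz * dst_dim_sz * right_sz elements
    left_sz: usize,
    dst_dim_sz: usize,
    right_sz: usize,
) {
    for left in 0..left_sz {
        for (j, &id) in ids.iter().enumerate() {
            for right in 0..right_sz {
                let s = (left * ids.len() + j) * right_sz + right;
                let d = (left * dst_dim_sz + id as usize) * right_sz + right;
                dst[d] += src[s];
            }
        }
    }
}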

struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for ScatterAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}
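Scatter-add differs from index-add in one way: `ids` has the same shape as `src`, so every source element picks its own destination row along `dim`. A matching CPU reference sketch (same flattening and the same serial-inner-loop assumption as above; not part of this diff):

// dst[left, ids[left, j, right], right] += src[left, j, right].
fn scatter_add_ref(
    ids: &[u32],     // same layout and length as src
    src: &[f32],     // left_sz * src_dim_sz * right_sz elements
    dst: &mut [f32], // left_sz * dst_dim_sz * right_sz elements
    left_sz: usize,
    src_dim_sz: usize,
    dst_dim_sz: usize,
    right_sz: usize,
) {
    for left in 0..left_sz {
        for j in 0..src_dim_sz {
            for right in 0..right_sz {
                let s = (left * src_dim_sz + j) * right_sz + right;
                let d = (left * dst_dim_sz + ids[s] as usize) * right_sz + right;
                dst[d] += src[s];
            }
        }
    }
}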

struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1004,8 +1125,7 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?
})?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -1017,8 +1137,7 @@
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?
})?
};
// The setup below was copied from:
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@@ -1043,8 +1162,7 @@
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?,
})?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@@ -1054,8 +1172,7 @@
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?,
})?,
};

Ok(StridedBatchedConfig {
@@ -1281,25 +1398,33 @@ impl BackendStorage for CudaStorage {
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement scatter-add").into())
let device = self.device().clone();
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-add").into())
let device = self.device().clone();
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
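Design note: the kernels accumulate in place, while `BackendStorage::scatter_add` and `index_add` are out-of-place, so both methods first materialize the destination by zero-allocating a contiguous buffer and copying `self` into it with `copy_strided_src`, then run the in-place kernel on that accumulator.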

fn matmul(
@@ -1364,7 +1489,7 @@ impl BackendStorage for CudaStorage {
.w()?;
CudaStorageSlice::F64(out)
}
_ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
_ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
};
let device = dev.clone();
Ok(Self { slice, device })
@@ -1452,8 +1577,7 @@ impl BackendStorage for CudaStorage {
}
_ => Err(CudaError::InternalError(
"dtype mismatch in copy_strided op",
))
.w()?,
))?,
}
Ok(())
}
