Skip to content

Commit

Permalink
Add the IntDType trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 23, 2023
1 parent 17f8d0f commit f286296
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 30 deletions.
66 changes: 37 additions & 29 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
Expand Down Expand Up @@ -133,9 +133,9 @@ impl Map2U8 for Cmp {
}
}

struct WCond<'a>(&'a [u32], &'a Layout);
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<'a> Map2 for WCond<'a> {
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
Expand All @@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> {
let f = &f[o_f1..o_f2];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.map(|(p, (&t, &f))| if p.is_true() { t } else { f })
.collect::<Vec<_>>()
}
_ => self
.1
.strided_index()
.zip(t_l.strided_index().zip(f_l.strided_index()))
.map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
.map(|(i_p, (i_t, i_f))| {
if self.0[i_p].is_true() {
t[i_t]
} else {
f[i_f]
}
})
.collect::<Vec<_>>(),
};
Ok(vs)
Expand Down Expand Up @@ -681,13 +687,13 @@ impl<'a> Map1 for Gather<'a> {
}
}

struct IndexSelect<'a> {
ids: &'a [u32],
struct IndexSelect<'a, T: IntDType> {
ids: &'a [T],
ids_l: &'a Layout,
dim: usize,
}

impl<'a> Map1 for IndexSelect<'a> {
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
Expand All @@ -714,7 +720,7 @@ impl<'a> Map1 for IndexSelect<'a> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
Expand All @@ -733,13 +739,13 @@ impl<'a> Map1 for IndexSelect<'a> {
}
}

struct ScatterAdd<'a> {
ids: &'a [u32],
struct ScatterAdd<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}

impl<'a> Map2 for ScatterAdd<'a> {
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
Expand Down Expand Up @@ -771,7 +777,7 @@ impl<'a> Map2 for ScatterAdd<'a> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx] as usize;
let index = ids[ids_idx].as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
Expand All @@ -790,12 +796,12 @@ impl<'a> Map2 for ScatterAdd<'a> {
}
}

struct IndexAdd<'a> {
ids: &'a [u32],
struct IndexAdd<'a, I: IntDType> {
ids: &'a [I],
dim: usize,
}

impl<'a> Map2 for IndexAdd<'a> {
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
Expand All @@ -811,8 +817,8 @@ impl<'a> Map2 for IndexAdd<'a> {
let max_idx = l1.dims()[dim];
let stride = src_l.stride()[dim];
if dim == 0 {
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx as usize;
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
Expand All @@ -831,8 +837,8 @@ impl<'a> Map2 for IndexAdd<'a> {
} else {
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx as usize;
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
Expand All @@ -856,14 +862,14 @@ impl<'a> Map2 for IndexAdd<'a> {
}
}

struct Embedding<'a> {
struct Embedding<'a, I: IntDType> {
vocab_size: usize,
hidden_size: usize,
ids: &'a [u32],
ids: &'a [I],
ids_l: &'a Layout,
}

impl<'a> Map1 for Embedding<'a> {
impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
if !layout.is_contiguous() {
Err(Error::RequiresContiguous { op: "embedding" })?
Expand All @@ -872,8 +878,8 @@ impl<'a> Map1 for Embedding<'a> {
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
match self.ids_l.contiguous_offsets() {
Some((o1, o2)) => {
for &index in self.ids[o1..o2].iter() {
let index = index as usize;
for index in self.ids[o1..o2].iter() {
let index = index.as_usize();
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
Expand All @@ -889,7 +895,7 @@ impl<'a> Map1 for Embedding<'a> {
}
None => {
for index in self.ids_l.strided_index() {
let index = self.ids[index].try_into()?;
let index = self.ids[index].as_usize();
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
Expand Down Expand Up @@ -1692,9 +1698,11 @@ impl BackendStorage for CpuStorage {
f: &Self,
f_l: &Layout,
) -> Result<Self> {
// TODO: Support types that could be casted to a boolean.
let pred = self.as_slice::<u32>()?;
WCond(pred, layout).map(t, t_l, f, f_l)
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
}

fn conv1d(
Expand Down
23 changes: 23 additions & 0 deletions candle-core/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,26 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);

/// Trait implemented by integer dtypes that can serve as indices
/// (gather/scatter/index-select/embedding) or as boolean-like
/// predicates (e.g. the condition tensor of `where-cond`).
pub trait IntDType {
    /// Boolean interpretation of the value: any non-zero value is true.
    fn is_true(&self) -> bool;
    /// Widen the value to `usize` so it can be used for slice indexing.
    fn as_usize(&self) -> usize;
}

// Implement `IntDType` for an unsigned integer primitive. Mirrors the
// `with_dtype!` macro used above for the `WithDType` impls, avoiding
// duplicating the identical method bodies per type.
macro_rules! int_dtype {
    ($ty:ty) => {
        impl IntDType for $ty {
            fn is_true(&self) -> bool {
                *self != 0
            }
            fn as_usize(&self) -> usize {
                // Lossless on all supported targets (usize >= 32 bits).
                *self as usize
            }
        }
    };
}

int_dtype!(u32);
int_dtype!(u8);
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ mod variable;

pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use dtype::{DType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
Expand Down

0 comments on commit f286296

Please sign in to comment.