Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into feat/comparison-and…
Browse files Browse the repository at this point in the history
…-equality-operations
  • Loading branch information
LLukas22 committed Oct 30, 2023
2 parents 8671814 + 9699608 commit 885e5c4
Show file tree
Hide file tree
Showing 59 changed files with 2,414 additions and 247 deletions.
Binary file added .github/workflows/maturin.yml
Binary file not shown.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ These online demos run entirely in your browser:
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [Blip](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

We also provide a some command line based examples using state of the art models:

Expand Down Expand Up @@ -96,7 +96,8 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
Expand Down
7 changes: 7 additions & 0 deletions candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
Expand Down
20 changes: 10 additions & 10 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
Expand Down Expand Up @@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" })?,
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
Expand Down Expand Up @@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};

Expand All @@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
Expand Down Expand Up @@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
Expand Down Expand Up @@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
}
}

Expand Down
12 changes: 12 additions & 0 deletions candle-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}
35 changes: 34 additions & 1 deletion candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
Expand Down Expand Up @@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
}
}

impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}

impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
Expand Down
5 changes: 0 additions & 5 deletions candle-core/src/quantized/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}

unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
Expand Down
5 changes: 0 additions & 5 deletions candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {

fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}

// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
Expand Down
55 changes: 4 additions & 51 deletions candle-core/src/quantized/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,71 +19,46 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}

unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];

let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);

let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());

// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);

// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));

let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}

Expand All @@ -94,57 +69,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];

let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));

let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}

Expand Down
8 changes: 0 additions & 8 deletions candle-core/src/quantized/simd128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
Expand Down Expand Up @@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
Expand Down
Loading

0 comments on commit 885e5c4

Please sign in to comment.