diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..280ea2e7f0 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,13 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": [ + "cuda", + ], + "files.associations": { + "random": "cpp", + "ratio": "cpp", + "cmath": "cpp" + }, } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index d6cf18614f..6dc5e85cd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"], git = "https://github.com/EricLBuehler/float8.git" } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } diff --git a/README.md b/README.md index a351ab667f..173f907d6f 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core) ![License](https://img.shields.io/crates/l/candle-core.svg) +**This is an optimized implmentation by Eric Buehler.** + Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: [whisper](https://huggingface.co/spaces/lmz/candle-whisper), diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index cbf8f2007f..6ce7e31e1c 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } libc = { workspace = true, optional = true } memmap2 = { workspace = true } @@ -39,7 +40,7 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels"] +cuda = ["cudarc", "dep:candle-kernels", "float8/cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 579c5f3f0b..d52659045c 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod random; pub(crate) mod unary; pub(crate) mod where_cond; -use candle_core::{Device, Result}; +use candle_core::{cuda::WrapErr, Device, Result}; pub(crate) trait BenchDevice { fn sync(&self) -> Result<()>; @@ -20,7 +20,7 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize()?); + return Ok(device.synchronize().w()?); #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index afe3e40754..655c7894d8 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -89,9 +89,23 @@ pub trait BackendStorage: Sized { _: usize, ) -> Result; - fn matmul( + #[allow(clippy::too_many_arguments)] + fn matmul_with_alpha_beta( + &self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: 
&Layout, + _: &Layout, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn matmul_with_alpha( &self, _: &Self, + _: Option, _: (usize, usize, usize, usize), _: &Layout, _: &Layout, @@ -144,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; /// Synchronize should block until all the operations on the device are completed. fn synchronize(&self) -> Result<()>; diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..173a96d6e6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,5 +1,6 @@ //! Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; @@ -130,6 +131,16 @@ impl Tensor { f.write_u32::(v)? } } + DType::I16 => { + for v in vs.to_vec1::()? { + f.write_i16::(v)? + } + } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? @@ -139,6 +150,11 @@ impl Tensor { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } + DType::F8E4M3 => { + for v in vs.to_vec1::()? { + f.write_u8(v.to_bits())? + } + } } Ok(()) } diff --git a/candle-core/src/cpu/avx.rs b/candle-core/src/cpu/avx.rs index 9398a3460a..113fc14ced 100644 --- a/candle-core/src/cpu/avx.rs +++ b/candle-core/src/cpu/avx.rs @@ -1,10 +1,10 @@ -use super::{Cpu, CpuF16}; +use super::{Cpu, CpuBF16, CpuF16}; #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use half::f16; +use half::{bf16, f16}; pub struct CurrentCpu {} @@ -146,3 +146,82 @@ impl CpuF16 for CurrentCpuF16 { *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); } } + +pub struct CurrentCpuBF16 {} +impl CpuBF16 for CurrentCpuBF16 { + type Unit = __m256; + type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + #[cfg(target_feature = "f16c")] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + _mm256_loadu_ps(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + #[cfg(target_feature = "f16c")] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = bf16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + let mut offset = ARR >> 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], 
x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 527646d62b..f81ad625d3 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -121,6 +121,13 @@ impl VecOps for half::bf16 { fn max(self, other: Self) -> Self { Self::max(self, other) } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + let mut res_f32 = 0f32; + super::vec_dot_bf16(lhs, rhs, &mut res_f32, len); + *res = half::bf16::from_f32(res_f32); + } } impl VecOps for u8 { #[inline(always)] @@ -144,6 +151,28 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i64 { #[inline(always)] fn min(self, other: Self) -> Self { diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index e7d8b6906f..0b77e6ecb7 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -36,14 +36,33 @@ trait CpuF16 { unsafe fn from_f32(v: f32) -> Self::Unit; unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit); } -use half::f16; + +#[allow(unused)] +trait CpuBF16 { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const bf16) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit); +} + +use half::{bf16, f16}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] pub mod avx; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] -pub use avx::{CurrentCpu, CurrentCpuF16}; +pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16}; #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] @@ -170,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f *c = sumf; } +#[cfg(target_feature = "avx")] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + let mut sumf = 0.0f32; + let np = k & !(CurrentCpuBF16::STEP - 1); + + let mut sum = CurrentCpuBF16::zero_array(); + let mut ax = CurrentCpuBF16::zero_array(); + let mut ay = CurrentCpuBF16::zero_array(); + + for i in (0..np).step_by(CurrentCpuBF16::STEP) { + for j in 0..CurrentCpuBF16::n() { + ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR)); + ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR)); + + sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpuBF16::vec_reduce(sum, &mut sumf); + + // leftovers + for i in np..k { + sumf += (*a_row.add(i)).to_f32() * 
(*b_row.add(i)).to_f32(); + } + *c = sumf; +} + #[cfg(not(target_feature = "avx"))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { @@ -180,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f } *c = sum; } + +#[cfg(not(target_feature = "avx"))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + // leftovers + let mut sum = 0.0; + for i in 0..k { + sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sum; +} diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 58773c8020..6ef74c0725 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1,12 +1,16 @@ +use std::ops::Deref; + use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; use rayon::prelude::*; mod utils; pub use utils::{ - binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2Alpha, Map2U8, + Map3, }; const USE_IM2COL_CONV1D: bool = true; @@ -19,22 +23,28 @@ const USE_IM2COL_CONV2D: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I16(Vec), + I32(Vec), I64(Vec), BF16(Vec), F16(Vec), F32(Vec), F64(Vec), + F8E4M3(Vec), } #[derive(Debug, Clone)] pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I16(&'a [i16]), + I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), + F8E4M3(&'a [F8E4M3]), } #[derive(Debug, Clone)] @@ -1529,156 +1539,927 @@ impl Map2 for MatMul { } } -fn elu(v: T, alpha: T) -> T { - if v.is_sign_positive() { - v - } else { - (v.exp() - T::one()) * alpha - } -} +struct MatMulWithBias(MatMul); -impl CpuStorage { - pub fn as_slice(&self) -> Result<&[D]> { - D::cpu_storage_as_slice(self) - } +impl Deref for MatMulWithBias { + type Target = MatMul; - pub fn concat(storages: &[CpuStorage]) -> Result { - let storage0 = &storages[0]; - let s = match storage0 { - Self::U8(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::U8(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::U8(storages) - } - Self::U32(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::U32(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::U32(storages) - } - Self::I64(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::I64(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::I64(storages) - } - Self::BF16(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::BF16(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::BF16(storages) - } - Self::F16(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::F16(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::F16(storages) - } - Self::F32(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::F32(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? 
- .concat(); - Self::F32(storages) - } - Self::F64(_) => { - let storages = storages - .iter() - .map(|s| match s { - Self::F64(s) => Ok(s.as_slice()), - _ => crate::bail!("dtype mismatch"), - }) - .collect::>>()? - .concat(); - Self::F64(storages) - } - }; - Ok(s) + fn deref(&self) -> &Self::Target { + &self.0 } } -impl BackendStorage for CpuStorage { - type Device = CpuDevice; +impl Map3 for MatMulWithBias { + const OP: &'static str = "mat_mul_ac"; - fn dtype(&self) -> DType { - match self { - Self::U8(_) => DType::U8, - Self::U32(_) => DType::U32, - Self::I64(_) => DType::I64, - Self::BF16(_) => DType::BF16, - Self::F16(_) => DType::F16, - Self::F32(_) => DType::F32, - Self::F64(_) => DType::F64, + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, } - } - fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { - // TODO: find a way around the quadratic number of cases below. - match (self, dtype) { - (Self::U8(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); - Ok(Self::BF16(data)) - } - (Self::U32(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); - Ok(Self::BF16(data)) - } - (Self::I64(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); - Ok(Self::BF16(data)) - } - (Self::BF16(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| v); - Ok(Self::BF16(data)) - } - (Self::F16(storage), DType::BF16) => { - let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); - Ok(Self::BF16(data)) + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } } - (Self::F32(storage), DType::BF16) => { - let data = unary_map(storage, layout, bf16::from_f32); - Ok(Self::BF16(data)) + None => crate::bail!("`c` has to be contiguous"), + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ true, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize 
= */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) } - (Self::F64(storage), DType::BF16) => { - let data = unary_map(storage, layout, bf16::from_f64); - Ok(Self::BF16(data)) + } + Ok(()) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } } - (Self::U8(storage), DType::F16) => { - let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); - Ok(Self::F16(data)) + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") } - (Self::U32(storage), DType::F16) => { - let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) as f64, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(()) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? 
+ }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(()) + } +} + +struct MatMulWithAlpha(MatMul); + +impl Deref for MatMulWithAlpha { + type Target = MatMul; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Map2Alpha for MatMulWithAlpha { + const OP: &'static str = "mat_mul_a"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: 
usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let mut dst = vec![T::zero(); b * m * n]; + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ true, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(dst) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } +} + +fn elu(v: T, alpha: T) -> T { + if v.is_sign_positive() { + v + } else { + (v.exp() - T::one()) * alpha + } +} + +impl CpuStorage { + pub fn as_slice(&self) -> Result<&[D]> { + D::cpu_storage_as_slice(self) + } + + pub fn concat(storages: &[CpuStorage]) -> Result { + let storage0 = &storages[0]; + let s = match storage0 { + Self::U8(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U8(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U8(storages) + } + Self::U32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U32(storages) + } + Self::I16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I16(storages) + } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } + Self::I64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I64(storages) + } + Self::BF16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::BF16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::BF16(storages) + } + Self::F16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F16(storages) + } + Self::F32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F32(storages) + } + Self::F64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F64(storages) + } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::F8E4M3(storages) + } + }; + Ok(s) + } +} + +impl BackendStorage for CpuStorage { + type Device = CpuDevice; + + fn dtype(&self) -> DType { + match self { + Self::U8(_) => DType::U8, + Self::U32(_) => DType::U32, + Self::I16(_) => DType::I16, + Self::I32(_) => DType::I32, + Self::I64(_) => DType::I64, + Self::BF16(_) => DType::BF16, + Self::F16(_) => DType::F16, + Self::F32(_) => DType::F32, + Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, + } + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + // TODO: find a way around the quadratic number of cases below. + match (self, dtype) { + (Self::U8(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::U32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I64(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::BF16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::BF16(data)) + } + (Self::F16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F32(storage), DType::BF16) => { + let data = unary_map(storage, layout, bf16::from_f32); + Ok(Self::BF16(data)) + } + (Self::F64(storage), DType::BF16) => { + let data = unary_map(storage, layout, bf16::from_f64); + Ok(Self::BF16(data)) + } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::U8(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::U32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } (Self::I64(storage), DType::F16) => { @@ -1701,6 +2482,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } (Self::U8(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1709,6 +2494,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::I16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } (Self::I64(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1729,6 +2522,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as 
f32); Ok(Self::F32(data)) } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } (Self::U8(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U8(data)) @@ -1753,10 +2550,22 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::I16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } (Self::I64(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } (Self::U8(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1765,6 +2574,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } + (Self::I16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } (Self::I64(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1785,6 +2602,90 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::U8(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I16(data)) + } + (Self::I32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::BF16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F8E4M3(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } 
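// Illustrative sketch, not part of this diff: the F8E4M3 arms in this match rely only on
// conversions that already appear in the patch (`from_f32`, `to_f32`, `to_bits`). An e4m3
// float keeps 4 exponent and 3 mantissa bits, so the round trip below is lossy;
// `f8_roundtrip` is a hypothetical helper written only for this sketch.
fn f8_roundtrip(x: f32) -> (f32, u8) {
    let q = float8::F8E4M3::from_f32(x); // quantize to the 8-bit format (rounds)
    (q.to_f32(), q.to_bits())            // widen back to f32 and expose the raw byte encoding
}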
+ (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F8E4M3(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -1793,6 +2694,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::I16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } (Self::I64(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v); Ok(Self::I64(data)) @@ -1813,6 +2722,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } (Self::U8(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1821,6 +2734,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } + (Self::I16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } (Self::I64(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1841,6 +2762,50 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = 
unary_map(storage, layout, F8E4M3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } } } @@ -1954,8 +2919,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.powf(e)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -1979,8 +2950,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| elu(v, alpha)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -2023,6 +3000,15 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } } + Self::F8E4M3(storage) => { + if B::F8E4M3_VEC { + let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); + Ok(Self::F8E4M3(data)) + } else { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); Ok(Self::U8(data)) @@ -2031,6 +3017,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I16(storage) => { + let data = unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } Self::I64(storage) => { let data = unary_map(storage, layout, B::i64); Ok(Self::I64(data)) @@ -2085,6 +3079,22 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = if B::I16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i16, B::i16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i16) + }; + Ok(Self::I16(data)) + } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = if B::I32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i32) + }; + Ok(Self::I32(data)) + } (Self::I64(lhs), Self::I64(rhs)) => { let data = if B::I64_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) @@ -2128,6 +3138,12 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I64(src), Self::I64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2159,6 +3175,8 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => 
copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2188,6 +3206,8 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), } @@ -2220,7 +3240,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -2231,7 +3251,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; @@ -2272,8 +3292,9 @@ impl BackendStorage for CpuStorage { vec![0, k_size * c_out, 1], kernel_l.start_offset(), ); - self.matmul( + self.matmul_with_alpha( kernel, + None, ( b_size, /* m */ l_in, @@ -2322,7 +3343,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -2333,7 +3354,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) .transpose(1, 2)? 
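// Illustrative sketch, not part of this diff: how the reworked matmul entry points are
// called. `matmul_with_alpha(rhs, s, bmnk, lhs_l, rhs_l)` returns s * (lhs @ rhs) as fresh
// storage, and `s: Option<f64> = None` keeps the old `matmul` semantics (scale 1.0), which
// is why the im2col conv paths above pass `None`. `matmul_with_alpha_beta` instead writes
// the scaled product into an existing, contiguous `c` buffer so the GEMM output lands
// directly where an accumulator/bias already lives. `scaled_matmul_into` below is a
// hypothetical helper written only for this sketch.
fn scaled_matmul_into<B: crate::backend::BackendStorage>(
    lhs: &B,
    lhs_l: &crate::Layout,
    rhs: &B,
    rhs_l: &crate::Layout,
    c: &mut B,
    c_l: &crate::Layout,
) -> crate::Result<()> {
    // one batch of a (16 x 32) by (32 x 8) product, scaled by 0.5; bmnk is (b, m, n, k)
    lhs.matmul_with_alpha_beta(rhs, c, Some(0.5), (1, 16, 8, 32), lhs_l, rhs_l, c_l)
}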
@@ -2357,6 +3378,8 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), } @@ -2366,6 +3389,8 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), } @@ -2383,6 +3408,8 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I16(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), } @@ -2412,6 +3439,20 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -2423,14 +3464,28 @@ impl BackendStorage for CpuStorage { } } - fn matmul( + fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + MatMulWithBias(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, c, c_l, s) + } + + fn matmul_with_alpha( &self, rhs: &Self, + s: Option, bmnk: (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - MatMul(bmnk).map(self, lhs_l, rhs, rhs_l) + MatMulWithAlpha(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, s) } fn device(&self) -> &Self::Device { @@ -2477,13 +3532,17 @@ impl BackendDevice for CpuDevice { crate::bail!("cannot seed the CPU rng with set_seed") } + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { @@ -2504,6 +3563,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + 
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); @@ -2529,7 +3597,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) } DType::BF16 => { @@ -2550,6 +3618,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let normal = @@ -2588,6 +3665,16 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } DType::I64 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -2613,6 +3700,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F64(v) } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } }; Ok(storage) } @@ -2622,9 +3714,12 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![1i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]), DType::F32 => CpuStorage::F32(vec![1f32; elem_count]), DType::F64 => CpuStorage::F64(vec![1f64; elem_count]), }; @@ -2636,9 +3731,12 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), }; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4f7..495fcd660b 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,11 +10,14 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => 
Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), } } } @@ -26,11 +29,14 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), } } } @@ -48,6 +54,86 @@ pub trait Map2 { (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map3 { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + v3: &mut [T], + l3: &Layout, + s: Option, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + v1: &C, + l1: &Layout, + v2: &C, + l2: &Layout, + v3: &mut C, + l3: &Layout, + s: Option, + ) -> Result<()> { + let v3d = v3.dtype(); + match (v1, v2, v3) { + (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + _ => Err(Error::DTypeMismatchBinaryOp3 { + lhs: v1.dtype(), + rhs: v2.dtype(), + c: v3d, + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2Alpha { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + s: Option, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2, s)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -71,6 +157,7 @@ pub trait Map2U8 { (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, 
l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 0aa58cacde..1d8cb7d34b 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -3,8 +3,9 @@ use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use float8::F8E4M3; use half::{bf16, f16}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -30,6 +31,7 @@ pub struct CudaDevice { device: Arc, pub(crate) blas: Arc, curand: Arc>, + seed_value: Arc>, } impl std::fmt::Debug for CudaDevice { @@ -47,6 +49,10 @@ impl std::ops::Deref for CudaDevice { } impl CudaDevice { + pub fn cublas_handle(&self) -> &cudarc::cublas::CudaBlas { + &*self.blas + } + pub fn cuda_device(&self) -> Arc { self.device.clone() } @@ -75,6 +81,22 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i16", kernels::FILL)?; + let params = (&data, v as i16, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i32", kernels::FILL)?; + let params = (&data, v as i32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }.w()?; @@ -115,6 +137,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f8_e4m3", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -156,6 +186,7 @@ impl BackendDevice for CudaDevice { device, blas: Arc::new(blas), curand: Arc::new(Mutex::new(CudaRng(curand))), + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -164,9 +195,14 @@ impl BackendDevice for CudaDevice { // state will be identical and the same random numbers will be generated. 
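// Usage sketch (assumption: this is the intended use of the seed tracking added above).
// `Device::new_cuda` and `set_seed` are existing candle APIs; `get_current_seed` is the
// trait method introduced by this patch, backed on CUDA by `seed_value: Arc<RwLock<u64>>`.
fn seed_roundtrip() -> candle_core::Result<()> {
    use candle_core::Device;
    let device = Device::new_cuda(0)?;
    device.set_seed(42)?;
    // Reads back the value stored by set_seed; the CPU backend bails instead since
    // its thread-local rng cannot be (re)seeded.
    assert_eq!(device.get_current_seed()?, 42);
    Ok(())
}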
let mut curand = self.curand.lock().unwrap(); curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + *self.seed_value.write().unwrap() = seed; Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.device.ordinal(), @@ -188,6 +224,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -208,6 +252,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F64(data) + } }; Ok(CudaStorage { slice, @@ -221,13 +269,18 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I64 + | DType::I32 + | DType::I16 + | DType::F16 + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count) }.w()?; curand.0.fill_with_uniform(&mut data).w()?; @@ -265,13 +318,18 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? 
- } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand @@ -307,6 +365,14 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -327,6 +393,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -344,6 +414,14 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorageRef::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorageRef::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorageRef::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -364,6 +442,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorageRef::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -381,6 +463,14 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -401,6 +491,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -418,6 +512,14 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -438,6 +540,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 07bb1785dd..9e8e099b3c 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -7,6 +7,7 @@ use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, 
ValidAsZeroBits, }; +use float8::F8E4M3; use half::{bf16, f16}; #[cfg(feature = "cudnn")] @@ -47,11 +48,14 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I16(CudaSlice), + I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), + F8E4M3(CudaSlice), } struct Clone; @@ -363,11 +367,17 @@ impl<'a> Map1 for IndexSelect<'a> { CudaStorageSlice::U8(slice) => { ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) } + CudaStorageSlice::I16(slice) => { + ("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + CudaStorageSlice::I32(slice) => { + ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "index_select ids should be u8 or u32", + msg: "index_select ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -427,11 +437,17 @@ impl<'a> Map1 for Gather<'a> { ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) } CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => { + ("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr()) + } + CudaStorageSlice::I32(slice) => { + ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "gather ids should be u8/u32/i64", + msg: "gather ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -477,10 +493,12 @@ impl<'a> Map2InPlace for IndexAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "index-add ids should be u8/u32/i64", + msg: "index-add ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -525,10 +543,12 @@ impl<'a> Map2InPlace for ScatterAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "scatter-add ids should be u8/u32/i64", + msg: "scatter-add ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -867,12 +887,20 @@ impl<'a> Map2 for WhereCond<'a> { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u32") } + CudaStorageSlice::I16(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i16") + } + CudaStorageSlice::I32(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i32") + } CudaStorageSlice::I64(slice) => { let ptr = 
*slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_i64") } _ => Err(CudaError::UnexpectedDType { - msg: "where conditions should be u8/u32/i64", + msg: "where conditions should be u8/u32/i16/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -1026,6 +1054,8 @@ macro_rules! cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); +cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); @@ -1148,11 +1178,14 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, + CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, } } @@ -1174,11 +1207,14 @@ impl BackendStorage for CudaStorage { let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F8E4M3(inp) => *inp.slice(start_o..).device_ptr(), }; let inp = &inp; @@ -1197,6 +1233,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } + DType::I16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(out) + } + DType::I32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(out) + } DType::I64 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); @@ -1227,6 +1275,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(out) } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } }; Ok(Self { slice, @@ -1293,6 +1347,16 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } + CudaStorageSlice::I32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } CudaStorageSlice::I64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; @@ -1318,6 +1382,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } + CudaStorageSlice::F8E4M3(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + 
Ok(CpuStorage::F8E4M3(cpu_storage)) + } } } @@ -1367,7 +1436,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -1378,7 +1447,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; @@ -1422,8 +1491,9 @@ impl BackendStorage for CudaStorage { vec![0, k_size * c_out, 1], kernel_l.start_offset(), ); - self.matmul( + self.matmul_with_alpha( kernel, + None, ( b_size, /* m */ l_in, @@ -1481,7 +1551,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -1492,7 +1562,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? 
@@ -1559,6 +1629,8 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; @@ -1655,9 +1727,80 @@ impl BackendStorage for CudaStorage { Ok(acc) } - fn matmul( + fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match (&self.slice, &rhs.slice, &mut c.slice) { + ( + CudaStorageSlice::BF16(lhs), + CudaStorageSlice::BF16(rhs), + CudaStorageSlice::BF16(c), + ) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + bf16::from_f64(s.unwrap_or(1.0)), + bf16::ONE, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs), CudaStorageSlice::F16(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + f16::from_f64(s.unwrap_or(1.0)), + f16::ONE, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs), CudaStorageSlice::F32(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0) as f32, 1., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs), CudaStorageSlice::F64(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0), 1., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { self.device.blas.gemm_strided_batched(cfg, rhs, lhs, c) }.w()?; + } + _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, + }; + Ok(()) + } + + fn matmul_with_alpha( &self, rhs: &Self, + scale: Option, (b, m, n, k): (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, @@ -1668,7 +1811,13 @@ impl BackendStorage for CudaStorage { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config( + bf16::from_f64(scale.unwrap_or(1.)), + bf16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; @@ -1677,7 +1826,13 @@ impl BackendStorage for CudaStorage { (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { let lhs = 
&lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config( + f16::from_f64(scale.unwrap_or(1.)), + f16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; @@ -1686,7 +1841,7 @@ impl BackendStorage for CudaStorage { (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config(scale.unwrap_or(1.) as f32, 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; @@ -1695,7 +1850,7 @@ impl BackendStorage for CudaStorage { (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config(scale.unwrap_or(1.), 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { self.device @@ -1742,6 +1897,16 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), + (S::I16(s), S::I16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i16", + ), + (S::I32(s), S::I32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i32", + ), (S::I64(s), S::I64(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), @@ -1848,6 +2013,30 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? 
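// Semantics sketch for the two GEMM entry points used above (expressed with existing
// Tensor ops; the backend methods themselves operate on storages and layouts):
//   matmul_with_alpha(rhs, Some(s))             ~  s * (self @ rhs)        (beta = 0)
//   matmul_with_alpha_beta(rhs, &mut c, s, ..)  ~  c += s * (self @ rhs)   (beta = 1)
// Folding the scale into cuBLAS' alpha avoids a separate elementwise scaling pass.
fn scaled_matmul_reference(
    a: &candle_core::Tensor,
    b: &candle_core::Tensor,
    s: f64,
) -> candle_core::Result<candle_core::Tensor> {
    // Reference behaviour of the alpha-only path, with the extra elementwise kernel.
    a.matmul(b)?.affine(s, 0.)
}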
+ } + } (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { @@ -2079,3 +2268,163 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } + +pub struct KVConcat { + pub concat_dim: usize, +} +impl crate::CustomOp2 for KVConcat { + fn name(&self) -> &'static str { + "kvconcat" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + crate::bail!("no cpu support for kvconcat") + } + + fn cuda_fwd( + &self, + ltensor: &CudaStorage, + ltensor_l: &Layout, + rtensor: &CudaStorage, + rtensor_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len + let dev = <ensor.device; + let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); + let dims_l = ltensor_l.shape().dims(); + let dims_r = rtensor_l.shape().dims(); + let dim_size = dims_l.len(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + + let chunk_l = if dim_size > 3 { + dims_l[0] * dims_l[1] + } else { + dims_l[0] + }; + let chunk_r = if dim_size > 3 { + dims_r[0] * dims_r[1] + } else { + dims_r[0] + }; + let lstride = if dim_size > 3 { + dims_l[2] * dims_l[3] + } else { + dims_l[1] * dims_l[2] + }; + let rstride = if dim_size > 3 { + dims_r[2] * dims_r[3] + } else { + dims_r[1] * dims_r[2] + }; + + let slice = match (<ensor.slice, &rtensor.slice) { + (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? 
}; + let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, + }; + + let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); + if self.concat_dim == 0 { + lshape[0] += rtensor_l.shape().dims()[0]; + } else { + if dim_size > 3 { + lshape[2] += rtensor_l.shape().dims()[2]; + } else { + lshape[1] += rtensor_l.shape().dims()[1]; + } + } + + let device = dev.clone(); + Ok(( + CudaStorage { + slice: slice, + device, + }, + lshape.into(), + )) + } +} diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727ad..581d687aac 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,11 +19,14 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I16(s) => S::I16(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), }; Ok(out) } @@ -48,6 +51,7 @@ pub trait Map2 { (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -86,6 +90,9 @@ pub trait Map3 { (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { + S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) 
+ } _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) @@ -118,6 +125,7 @@ pub trait Map2InPlace { (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_s, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -136,11 +144,14 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I16(s) => self.f(s, d, l, S::I16)?, + S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, }; Ok(out) } @@ -165,6 +176,7 @@ pub trait Map2Any { (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 91e569372d..22721ce98e 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -142,6 +142,15 @@ impl Device { } } + /// Get the current seed for the device RNG. + pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, @@ -341,12 +350,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Metal(storage)) } } @@ -357,12 +366,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), Device::Cuda(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Metal(storage)) } } diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 7e6e3cf8f1..c975440aa9 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -2,6 +2,7 @@ /// This implementation should be in line with the PyTorch version. 
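// Usage sketch for the KVConcat custom op defined above (CUDA only; its cpu_fwd bails).
// The module path and the `apply_op2` call are assumptions based on candle's public
// CustomOp2 machinery; adjust the import if the op ends up re-exported elsewhere.
#[cfg(feature = "cuda")]
fn append_kv_cache(
    cache: &candle_core::Tensor,
    new_kv: &candle_core::Tensor,
) -> candle_core::Result<candle_core::Tensor> {
    use candle_core::cuda_backend::KVConcat;
    // concat_dim must be 0 or 2 (the sequence dimension), as asserted in cuda_fwd.
    cache.apply_op2(new_kv, KVConcat { concat_dim: 2 })
}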
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; impl Tensor { @@ -55,11 +56,14 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I16 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), DType::F32 => self.fmt_dt::(f), DType::F64 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), } } } @@ -463,6 +467,18 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I16 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::I64 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); @@ -497,6 +513,9 @@ impl std::fmt::Display for Tensor { writeln!(f)?; } } + DType::F8E4M3 => { + return write!(f, "F8E4M3 does not support display."); + } }; let device_str = match self.device().location() { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3a3..f40ec3f7e1 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,15 +1,22 @@ //! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; +use crate::cpu::kernels::VecOps; use crate::{CpuStorage, CpuStorageRef, Error, Result}; /// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // Floating-point 8 bits integer (4-bit exponent, 3-bit mantissa). + F8E4M3, // Unsigned 8 bits integer. U8, // Unsigned 32 bits integer. U32, + // Signed 16 bits integer. + I16, + // Signed 32 bits integer. + I32, // Signed 64 bits integer. I64, // Brain floating-point using half precision (16 bits). 
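// Sketch: with the I16/I32 variants added to DType above (and the matching WithDType
// impls later in this file), plain Rust integer arrays map directly onto the new dtypes.
fn i32_tensor_example() -> candle_core::Result<()> {
    use candle_core::{DType, Device, Tensor};
    let t = Tensor::new(&[1i32, 2, 3, 4], &Device::Cpu)?;
    assert_eq!(t.dtype(), DType::I32);
    assert_eq!(DType::I32.size_in_bytes(), 4);
    Ok(())
}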
@@ -39,11 +46,14 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), + "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), + "f8_e4m3" => Ok(Self::F8E4M3), _ => Err(DTypeParseError(s.to_string())), } } @@ -55,11 +65,14 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I16 => "i16", + Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", + Self::F8E4M3 => "f8_e4m3", } } @@ -67,7 +80,10 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, + Self::F8E4M3 => 1, Self::U32 => 4, + Self::I16 => 2, + Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, @@ -78,15 +94,15 @@ impl DType { pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, } } } @@ -165,21 +181,53 @@ macro_rules! with_dtype { } }; } +use float8::F8E4M3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v + .to_f64()); + +impl VecOps for F8E4M3 { + fn max(self, rhs: Self) -> Self { + F8E4M3::max(self, rhs) + } + fn min(self, rhs: Self) -> Self { + F8E4M3::min(self, rhs) + } +} pub trait IntDType: WithDType { fn is_true(&self) -> bool; fn as_usize(&self) -> usize; } +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + impl IntDType for i64 { fn is_true(&self) -> bool { *self != 0 @@ -213,3 +261,4 @@ impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} +impl FloatDType for F8E4M3 {} diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 68eef1efed..9fa1970b00 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -140,9 +140,23 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn matmul( + fn matmul_with_alpha_beta( &self, _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, _: (usize, usize, usize, 
usize), _: &Layout, _: &Layout, @@ -194,6 +208,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index a1c2394d49..2a3ea93c03 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -152,9 +152,23 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn matmul( + fn matmul_with_alpha_beta( &self, _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, _: (usize, usize, usize, usize), _: &Layout, _: &Layout, @@ -206,6 +220,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e7112e2e61..66f9fd4175 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,3 +1,8 @@ +use std::{ + convert::Infallible, + fmt::{Debug, Display}, +}; + use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] @@ -26,6 +31,14 @@ pub enum Error { op: &'static str, }, + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {rhs:?}")] + DTypeMismatchBinaryOp3 { + lhs: DType, + rhs: DType, + c: DType, + op: &'static str, + }, + #[error("unsupported dtype {0:?} for op {1}")] UnsupportedDTypeForOp(DType, &'static str), @@ -100,6 +113,14 @@ pub enum Error { op: &'static str, }, + #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {c:?}")] + DeviceMismatchBinaryOp3 { + lhs: DeviceLocation, + rhs: DeviceLocation, + c: DeviceLocation, + op: &'static str, + }, + // === Op Specific Errors === #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")] NarrowInvalidArgs { @@ -194,6 +215,13 @@ pub enum Error { #[error(transparent)] Wrapped(Box), + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + /// Adding path information to an error. #[error("path: {path:?} {inner}")] WithPath { @@ -215,14 +243,21 @@ pub enum Error { pub type Result = std::result::Result; impl Error { + /// Create a new error by wrapping another. pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } - pub fn msg(err: impl std::error::Error) -> Self { - Self::Msg(err.to_string()).bt() + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() } + /// Create a new error based on a debuggable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. 
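// Sketch: Error::msg is relaxed above to take any printable (Display) value instead of
// requiring std::error::Error, so ad-hoc messages no longer need a wrapper error type.
fn unsupported_op(op: &str) -> candle_core::Error {
    candle_core::Error::msg(format!("operation {op} is not supported on this backend"))
}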
pub fn debug(err: impl std::fmt::Debug) -> Self { Self::Msg(format!("{err:?}")).bt() } @@ -267,3 +302,86 @@ pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { (_, Err(e)) => Err(e), } } + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index d8d6253213..7edb81f2b5 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -69,6 +69,7 @@ pub mod streaming; mod strided_index; mod tensor; mod tensor_cat; +mod tensor_indexing; pub mod test_utils; pub mod utils; mod variable; @@ -80,14 +81,14 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use indexer::{IndexOp, TensorIndexer}; pub use layout::Layout; pub use shape::{Shape, D}; pub use storage::Storage; pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule}; pub use strided_index::{StridedBlocks, StridedIndex}; -pub use tensor::{Tensor, TensorId}; +pub use tensor::{from_storage_no_op, Tensor, TensorId}; pub use variable::Var; #[cfg(feature = "cuda")] diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 29b8995bc9..f46b4201a6 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -123,6 +123,8 @@ pub struct MetalDevice { pub(crate) seed: Arc>, /// Whether to use the MLX matmul kernels instead of the MFA ones. 
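// Usage sketch for the new Context trait (re-exported from candle_core above): it wraps
// an Err into Error::WrappedContext and a None into Error::Msg, mirroring anyhow::Context.
fn first_positive(values: &[f32]) -> candle_core::Result<f32> {
    use candle_core::Context;
    values
        .iter()
        .copied()
        .find(|v| *v > 0.0)
        .context("expected at least one positive value")
}

fn parse_dim(s: &str) -> candle_core::Result<usize> {
    use candle_core::Context;
    s.parse::<usize>()
        .with_context(|| format!("invalid dimension {s:?}"))
}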
pub(crate) use_mlx_mm: bool, + /// Value of the current seed + pub(crate) seed_value: Arc>, } impl std::fmt::Debug for MetalDevice { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 6f560c02ee..7fad400eae 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -96,6 +96,8 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), @@ -304,6 +306,16 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I16) => ("fast_sum_i16_strided", false, false), + (ReduceOp::Min, DType::I16) => ("fast_min_i16_strided", true, false), + (ReduceOp::Max, DType::I16) => ("fast_max_i16_strided", true, false), + (ReduceOp::ArgMin, DType::I16) => ("fast_argmin_i16_strided", true, true), + (ReduceOp::ArgMax, DType::I16) => ("fast_argmax_i16_strided", true, true), + (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), + (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), + (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), + (ReduceOp::ArgMin, DType::I32) => ("fast_argmin_i32_strided", true, true), + (ReduceOp::ArgMax, DType::I32) => ("fast_argmax_i32_strided", true, true), (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), @@ -363,21 +375,39 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I16) => "cast_u32_i16", + (DType::U32, DType::I32) => "cast_u32_i32", (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I16) => "cast_u8_i16", + (DType::U8, DType::I32) => "cast_u8_i32", (DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I16) => "cast_f32_i16", + (DType::F32, DType::I32) => "cast_f32_i32", (DType::F32, DType::I64) => "cast_f32_i64", (DType::F32, DType::U32) => "cast_f32_u32", (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I16, DType::BF16) => "cast_i16_bf16", + (DType::I16, DType::F16) => "cast_i16_f16", + (DType::I16, DType::F32) => "cast_i16_f32", + (DType::I16, DType::U32) => "cast_i16_u32", + (DType::I16, DType::U8) => "cast_i16_u8", + + (DType::I32, DType::BF16) => "cast_i32_bf16", + (DType::I32, DType::F16) => "cast_i32_f16", + (DType::I32, DType::F32) => "cast_i32_f32", + (DType::I32, DType::U32) => "cast_i32_u32", + (DType::I32, DType::U8) => "cast_i32_u8", + (DType::I64, DType::BF16) => "cast_i64_bf16", (DType::I64, DType::F16) => 
"cast_i64_f16", (DType::I64, DType::F32) => "cast_i64_f32", @@ -386,12 +416,16 @@ impl BackendStorage for MetalStorage { (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I16) => "cast_f16_i16", + (DType::F16, DType::I32) => "cast_f16_i32", (DType::F16, DType::I64) => "cast_f16_i64", (DType::F16, DType::U32) => "cast_f16_u32", (DType::F16, DType::U8) => "cast_f16_u8", (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I16) => "cast_bf16_i16", + (DType::BF16, DType::I32) => "cast_bf16_i32", (DType::BF16, DType::I64) => "cast_bf16_i64", (DType::BF16, DType::U32) => "cast_bf16_u32", (DType::BF16, DType::U8) => "cast_bf16_u8", @@ -439,15 +473,23 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16_strided", (DType::U32, DType::F16) => "cast_u32_f16_strided", (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I16) => "cast_u32_i16_strided", + (DType::U32, DType::I32) => "cast_u32_i32_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", (DType::U8, DType::BF16) => "cast_u8_bf16_strided", (DType::U8, DType::F16) => "cast_u8_f16_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I16) => "cast_u8_i16_strided", + (DType::U8, DType::I32) => "cast_u8_i32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::I16, DType::F32) => "cast_i16_f32_strided", + (DType::I32, DType::F32) => "cast_i32_f32_strided", + (left, right) => { crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") } @@ -539,6 +581,8 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous_tiled::sign::HALF, ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I16) => contiguous_tiled::sign::I16, + ("usign", DType::I32) => contiguous_tiled::sign::I32, ("usign", DType::I64) => contiguous_tiled::sign::I64, (name, dtype) => { crate::bail!( @@ -617,6 +661,8 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous::sign::HALF, ("usign", DType::F32) => contiguous::sign::FLOAT, ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I16) => contiguous::sign::I16, + ("usign", DType::I32) => contiguous::sign::I32, ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") @@ -748,6 +794,8 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", + (DType::U8, DType::I16) => "where_u8_i16", + (DType::U8, DType::I32) => "where_u8_i32", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", (DType::U8, DType::U8) => "where_u8_u8", @@ -829,7 +877,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. 
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; @@ -837,7 +885,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; @@ -888,8 +936,9 @@ impl BackendStorage for MetalStorage { vec![0, k_size * c_out, 1], k_layout.start_offset(), ); - self.matmul( + self.matmul_with_alpha( k, + None, (b_size, l_in, c_out * k_size, c_in), &layout.transpose(1, 2)?, &kernel_l_mm, @@ -1020,7 +1069,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; @@ -1028,7 +1077,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? @@ -1284,6 +1333,12 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I16, DType::F32) => "sa_i16_f32", + (DType::I16, DType::F16) => "sa_i16_f16", + (DType::I16, DType::BF16) => "sa_i16_bf16", + (DType::I32, DType::F32) => "sa_i32_f32", + (DType::I32, DType::F16) => "sa_i32_f16", + (DType::I32, DType::BF16) => "sa_i32_bf16", (DType::I64, DType::F32) => "sa_i64_f32", (DType::I64, DType::F16) => "sa_i64_f16", (DType::I64, DType::BF16) => "sa_i64_bf16", @@ -1332,6 +1387,14 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I16, DType::F32) => "is_i16_f32", + (DType::I16, DType::F16) => "is_i16_f16", + (DType::I16, DType::BF16) => "is_i16_bf16", + + (DType::I32, DType::F32) => "is_i32_f32", + (DType::I32, DType::F16) => "is_i32_f16", + (DType::I32, DType::BF16) => "is_i32_bf16", + (DType::I64, DType::F32) => "is_i64_f32", (DType::I64, DType::F16) => "is_i64_f16", (DType::I64, DType::BF16) => "is_i64_bf16", @@ -1377,9 +1440,27 @@ impl BackendStorage for MetalStorage { return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I16, DType::BF16) => "ia_i16_bf16", + (DType::I16, DType::F16) => "ia_i16_f16", + (DType::I16, DType::F32) => "ia_i16_f32", + (DType::I16, DType::I32) => "ia_i16_i32", + (DType::I16, DType::I64) => "ia_i16_i64", + (DType::I16, DType::U32) => "ia_i16_u32", + (DType::I16, DType::U8) => "ia_i16_u8", + + (DType::I32, DType::BF16) => "ia_i32_bf16", + (DType::I32, DType::F16) => "ia_i32_f16", + (DType::I32, DType::F32) => "ia_i32_f32", + (DType::I32, DType::I32) => "ia_i32_i32", + (DType::I32, DType::I64) => "ia_i32_i64", + (DType::I32, DType::U32) => "ia_i32_u32", + (DType::I32, DType::U8) => "ia_i32_u8", 
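// Sketch: with the `is_i16`/`is_i32` and `ia_i16`/`ia_i32` kernels registered above,
// index_select and index_add also accept i16/i32 index tensors (previously u8/u32/i64).
fn gather_rows(
    table: &candle_core::Tensor,
    row_ids: &[i32],
) -> candle_core::Result<candle_core::Tensor> {
    let ids = candle_core::Tensor::new(row_ids, table.device())?;
    table.index_select(&ids, 0)
}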
+ (DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::F16) => "ia_i64_f16", (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I16) => "ia_i64_i16", + (DType::I64, DType::I32) => "ia_i64_i32", (DType::I64, DType::I64) => "ia_i64_i64", (DType::I64, DType::U32) => "ia_i64_u32", (DType::I64, DType::U8) => "ia_i64_u8", @@ -1387,6 +1468,8 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "ia_u32_bf16", (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I16) => "ia_u32_i16", + (DType::U32, DType::I32) => "ia_u32_i32", (DType::U32, DType::I64) => "ia_u32_i64", (DType::U32, DType::U32) => "ia_u32_u32", (DType::U32, DType::U8) => "ia_u32_u8", @@ -1394,6 +1477,8 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "ia_u8_bf16", (DType::U8, DType::F16) => "ia_u8_f16", (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I16) => "ia_u8_i16", + (DType::U8, DType::I32) => "ia_u8_i32", (DType::U8, DType::I64) => "ia_u8_i64", (DType::U8, DType::U32) => "ia_u8_u32", (DType::U8, DType::U8) => "ia_u8_u8", @@ -1424,9 +1509,65 @@ impl BackendStorage for MetalStorage { Ok(acc) } - fn matmul( + fn matmul_with_alpha_beta( &self, rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + DType::BF16 => "bgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } + }; + + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + candle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &c.buffer, + s.unwrap_or(1.) as f32, + 1., + ) + .map_err(MetalError::from)?; + Ok(()) + } + + fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, (b, m, n, k): (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, @@ -1435,6 +1576,11 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer()?; command_buffer.set_label("matmul"); if self.dtype == DType::BF16 { + if s.unwrap_or(1.) != 1. { + return Err( + MetalError::Message(format!("mlx matmul doesn't support alpha {s:?}")).into(), + ); + } candle_metal_kernels::call_mlx_gemm( &self.device.device, &command_buffer, @@ -1451,6 +1597,11 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } else if self.device.use_mlx_mm { + if s.unwrap_or(1.) != 1. { + return Err( + MetalError::Message(format!("mlx matmul doesn't support alpha {s:?}")).into(), + ); + } let dtype = match self.dtype { DType::F32 => candle_metal_kernels::GemmDType::F32, DType::F16 => candle_metal_kernels::GemmDType::F16, @@ -1501,6 +1652,8 @@ impl BackendStorage for MetalStorage { rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &rhs.buffer, &buffer, + s.unwrap_or(1.) 
as f32, + 0., ) .map_err(MetalError::from)?; } @@ -1548,6 +1701,8 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::copy2d::FLOAT, DType::F16 => candle_metal_kernels::copy2d::HALF, DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I16 => candle_metal_kernels::copy2d::I16, + DType::I32 => candle_metal_kernels::copy2d::I32, DType::I64 => candle_metal_kernels::copy2d::I64, DType::U32 => candle_metal_kernels::copy2d::U32, DType::U8 => candle_metal_kernels::copy2d::U8, @@ -1594,6 +1749,8 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I16 => candle_metal_kernels::unary::strided::copy::I16, + DType::I32 => candle_metal_kernels::unary::strided::copy::I32, DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, DType::U8 => candle_metal_kernels::unary::strided::copy::U8, @@ -1685,6 +1842,28 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I16) => (contiguous::add::I16, self.dtype), + ("sub", DType::I16) => (contiguous::sub::I16, self.dtype), + ("mul", DType::I16) => (contiguous::mul::I16, self.dtype), + ("div", DType::I16) => (contiguous::div::I16, self.dtype), + ("eq", DType::I16) => (contiguous::eq::I16, DType::U8), + ("ne", DType::I16) => (contiguous::ne::I16, DType::U8), + ("le", DType::I16) => (contiguous::le::I16, DType::U8), + ("lt", DType::I16) => (contiguous::lt::I16, DType::U8), + ("ge", DType::I16) => (contiguous::ge::I16, DType::U8), + ("gt", DType::I16) => (contiguous::gt::I16, DType::U8), + + ("add", DType::I32) => (contiguous::add::I32, self.dtype), + ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), + ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), + ("div", DType::I32) => (contiguous::div::I32, self.dtype), + ("eq", DType::I32) => (contiguous::eq::I32, DType::U8), + ("ne", DType::I32) => (contiguous::ne::I32, DType::U8), + ("le", DType::I32) => (contiguous::le::I32, DType::U8), + ("lt", DType::I32) => (contiguous::lt::I32, DType::U8), + ("ge", DType::I32) => (contiguous::ge::I32, DType::U8), + ("gt", DType::I32) => (contiguous::gt::I32, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1778,6 +1957,32 @@ impl MetalStorage { ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I16) => (strided::add::I16, self.dtype), + ("bsub", DType::I16) => (strided::sub::I16, self.dtype), + ("bmul", DType::I16) => (strided::mul::I16, self.dtype), + ("bdiv", DType::I16) => (strided::div::I16, self.dtype), + ("bminimum", DType::I16) => (strided::min::I16, self.dtype), + ("bmaximum", DType::I16) => (strided::max::I16, self.dtype), + ("eq", DType::I16) => (strided::eq::I16, DType::U8), + ("ne", DType::I16) => (strided::ne::I16, DType::U8), + ("le", DType::I16) => (strided::le::I16, DType::U8), + ("lt", DType::I16) => (strided::lt::I16, DType::U8), + ("ge", DType::I16) => (strided::ge::I16, DType::U8), + ("gt", DType::I16) => (strided::gt::I16, DType::U8), + + ("badd", DType::I32) => (strided::add::I32, self.dtype), + ("bsub", DType::I32) => 
(strided::sub::I32, self.dtype), + ("bmul", DType::I32) => (strided::mul::I32, self.dtype), + ("bdiv", DType::I32) => (strided::div::I32, self.dtype), + ("bminimum", DType::I32) => (strided::min::I32, self.dtype), + ("bmaximum", DType::I32) => (strided::max::I32, self.dtype), + ("eq", DType::I32) => (strided::eq::I32, DType::U8), + ("ne", DType::I32) => (strided::ne::I32, DType::U8), + ("le", DType::I32) => (strided::le::I32, DType::U8), + ("lt", DType::I32) => (strided::lt::I32, DType::U8), + ("ge", DType::I32) => (strided::ge::I32, DType::U8), + ("gt", DType::I32) => (strided::gt::I32, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1882,6 +2087,7 @@ impl BackendDevice for MetalDevice { buffers: Arc::new(RwLock::new(HashMap::new())), kernels, seed, + seed_value: Arc::new(RwLock::new(299792458)), use_mlx_mm, }) } @@ -1925,6 +2131,8 @@ impl BackendDevice for MetalDevice { DType::F16 => "fill_f16", DType::BF16 => "fill_bf16", DType::F32 => "fill_f32", + DType::I32 => "fill_i32", + DType::I16 => "fill_i16", DType::F64 => { let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; return self.storage_from_cpu_storage(&cpu_storage); @@ -1955,6 +2163,8 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -1968,6 +2178,8 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -2070,9 +2282,15 @@ impl BackendDevice for MetalDevice { } seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); + *self.seed_value.write().unwrap() = seed as u64; + Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn synchronize(&self) -> Result<()> { self.wait_until_completed() } diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..28d5a63e90 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -27,11 +27,13 @@ //! 
``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::Path; +use std::slice; const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY"; const NPY_SUFFIX: &str = ".npy"; @@ -85,9 +87,12 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I16 => "i2", + DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", + DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -234,11 +239,28 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I16 => { + let mut data_t = vec![0i16; elem_count]; + reader.read_i16_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I64 => { let mut data_t = vec![0i64; elem_count]; reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::F8E4M3 => { + let mut data_t = vec![F8E4M3::ZERO; elem_count]; + let ptr = data_t.as_mut_ptr().cast::(); + let len = data_t.len(); + reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 49ba44be89..208977913a 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,5 +1,6 @@ #![allow(clippy::redundant_closure_call)] use crate::Tensor; +use float8::F8E4M3; use half::{bf16, f16}; use num_traits::float::Float; @@ -187,8 +188,11 @@ pub trait UnaryOpT { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; + fn f8e4m3(v1: F8E4M3) -> F8E4M3; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i16(v1: i16) -> i16; + fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; // There is no very good way to represent optional function in traits so we go for an explicit @@ -197,6 +201,8 @@ pub trait UnaryOpT { fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; @@ -211,8 +217,11 @@ pub trait BinaryOpT { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i16(v1: i16, v2: i16) -> i16; + fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; const BF16_VEC: bool = false; @@ -223,12 +232,18 @@ pub trait BinaryOpT { fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} const I64_VEC: bool = false; fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} + const I32_VEC: bool = false; + fn i32_vec(_xs1: &[i32], 
_xs2: &[i32], _ys: &mut [i32]) {} + const I16_VEC: bool = false; + fn i16_vec(_xs1: &[i16], _xs2: &[i16], _ys: &mut [i16]) {} } pub(crate) struct Add; @@ -280,6 +295,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + $e(v1, v2) + } + #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } @@ -288,6 +307,14 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn i16(v1: i16, v2: i16) -> i16 { + $e(v1, v2) + } + #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { + $e(v1, v2) + } + #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } @@ -360,6 +387,10 @@ macro_rules! unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn f32($a: f32) -> f32 { $e } @@ -379,6 +410,14 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } } }; @@ -404,6 +443,10 @@ macro_rules! unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } @@ -415,6 +458,14 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -495,6 +546,17 @@ impl UnaryOpT for Gelu { )) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f32(0.5) + * v + * (F8E4M3::ONE + + F8E4M3::tanh( + F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), + )) + } + #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } @@ -514,6 +576,14 @@ impl UnaryOpT for Gelu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -568,6 +638,10 @@ impl UnaryOpT for Erf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -587,6 +661,14 @@ impl UnaryOpT for Erf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } } /// Silu operation @@ -602,6 +684,10 @@ impl UnaryOpT for Silu { v / (f16::ONE + (-v).exp()) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v / (F8E4M3::ONE + (-v).exp()) + } + #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } @@ -621,6 +707,14 @@ impl UnaryOpT for Silu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -673,6 +767,10 @@ impl UnaryOpT for Abs { v.abs() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.abs() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } @@ -692,6 +790,14 @@ impl UnaryOpT for Abs { fn i64(v: i64) -> i64 { v.abs() } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v.abs() + } 
} impl UnaryOpT for Ceil { @@ -707,6 +813,10 @@ impl UnaryOpT for Ceil { v.ceil() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.ceil() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } @@ -726,6 +836,14 @@ impl UnaryOpT for Ceil { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for Floor { @@ -741,6 +859,10 @@ impl UnaryOpT for Floor { v.floor() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.floor() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } @@ -760,6 +882,14 @@ impl UnaryOpT for Floor { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for Round { @@ -775,6 +905,10 @@ impl UnaryOpT for Round { v.round() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.round() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.round() } @@ -794,6 +928,14 @@ impl UnaryOpT for Round { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for GeluErf { @@ -809,6 +951,10 @@ impl UnaryOpT for GeluErf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -828,6 +974,14 @@ impl UnaryOpT for GeluErf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } } impl UnaryOpT for Relu { @@ -843,6 +997,10 @@ impl UnaryOpT for Relu { v.max(f16::ZERO) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.max(F8E4M3::ZERO) + } + #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } @@ -862,6 +1020,14 @@ impl UnaryOpT for Relu { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } /// `BackpropOp` is a wrapper around `Option`. The main goal is to ensure that dependencies are @@ -941,6 +1107,11 @@ impl UnaryOpT for Sign { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) + - F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) + } + #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) } @@ -960,4 +1131,12 @@ impl UnaryOpT for Sign { fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } + #[inline(always)] + fn i16(v: i16) -> i16 { + (v > 0) as i16 - (v < 0) as i16 + } } diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 3c24c0e546..7e1ca83835 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -423,6 +423,7 @@ impl QCudaStorage { match self.dtype { GgmlDType::F32 => deq::(&buffer, block_len, &mut out)?, GgmlDType::F16 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::BF16 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q4_0 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q4_1 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q5_0 => deq::(&buffer, block_len, &mut out)?, @@ -471,6 +472,31 @@ impl QCudaStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Run the quantization on cpu. 
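+ // The incoming `src` is first quantized into a temporary CPU `QStorage`, and the resulting
+ // quantized blocks are then copied into a freshly allocated, row-padded CUDA buffer (see below).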
+ let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len).w()? }; + self.device + .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len())) + .w()?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.data.len } @@ -497,6 +523,12 @@ impl QCudaStorage { self.dequantize_matmul(self_shape, storage, layout) } } + + pub fn data(&self) -> Result> { + self.device + .dtoh_sync_copy(&self.data.inner.slice(..self.data.len)) + .w() + } } impl QCudaStorage { @@ -560,7 +592,7 @@ impl QCudaStorage { let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { let data_f32 = self.dequantize(n * k)?; let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; - storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)? + storage.matmul_with_alpha(&data_f32, None, (b, m, n, k), layout, &rhs_l)? } else { let storage = storage.as_cuda_slice::()?; let storage = match layout.contiguous_offsets() { diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index ca7b812084..23a9e05bc2 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -32,6 +32,10 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } @@ -44,6 +48,10 @@ impl QCudaStorage { ) -> Result<(CudaStorage, crate::Shape)> { Err(Error::NotCompiledWithCudaSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithCudaSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index 520d0ed49a..c5c8db9282 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -28,6 +28,10 @@ impl QMetalStorage { Err(Error::NotCompiledWithMetalSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } @@ -40,6 +44,10 @@ impl QMetalStorage { ) -> Result<(MetalStorage, crate::Shape)> { Err(Error::NotCompiledWithMetalSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithMetalSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 99200bbd06..ea5ec02578 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -153,6 +153,7 @@ pub fn qtensor_from_ggml( match ggml_dtype { GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::BF16 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::Q4_0 => { from_raw_data::(raw_data, size_in_bytes, dims, device) } diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 6210ac1e9f..2e92921954 100644 --- 
a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -5,7 +5,7 @@ use super::utils::{ use super::GgmlDType; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::f16; +use half::{bf16, f16}; use rayon::prelude::*; // Default to QK_K 256 rather than 64. @@ -1963,3 +1963,47 @@ impl GgmlType for f16 { Ok(()) } } + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac2..038ba7b531 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -55,6 +55,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); half::f16::to_float(&vec, &mut out)?; } + GgmlDType::BF16 => { + let vec: Vec = read_to_vec(&buffer, block_len); + half::bf16::to_float(&vec, &mut out)?; + } GgmlDType::Q4_0 => { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; @@ -126,6 +130,22 @@ impl QMetalStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Quantization only happens on CPU for now. 
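+ // As with the CUDA path, `src` is quantized into a temporary CPU `QStorage`; the quantized
+ // bytes are then uploaded into a fresh Metal buffer via `new_buffer_with_data` (see below).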
+ let elem_count = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.buffer.length() as usize } @@ -186,6 +206,22 @@ impl QMetalStorage { let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); Ok((dst_storage, dst_shape)) } + + pub fn data(&self) -> Result> { + use metal::NSUInteger; + + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + Ok(read_to_vec::(&buffer, self.buffer.length() as usize)) + } } pub fn load_quantized( @@ -225,6 +261,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, } } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d852d50410..7f8dbfcf2a 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -27,7 +27,7 @@ pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; pub mod utils; -use half::f16; +use half::{bf16, f16}; pub use k_quants::GgmlType; @@ -101,7 +101,19 @@ impl QStorage { } (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, - _ => crate::bail!("Invalid dequantize storage locations do not match"), + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_onto(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + _ => crate::bail!("Invalid quantize source storage locations: not on cpu"), } Ok(()) } @@ -122,9 +134,8 @@ impl QStorage { let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; Ok(Cow::from(data)) } - QStorage::Metal(_) | QStorage::Cuda(_) => { - crate::bail!("not implemented"); - } + QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)), + QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)), } } } @@ -133,6 +144,7 @@ impl QStorage { pub enum GgmlDType { F32, F16, + BF16, Q4_0, Q4_1, Q5_0, @@ -164,6 +176,8 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + 30 => Self::BF16, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -185,6 +199,8 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + Self::BF16 => 30, } } @@ -205,6 +221,7 
@@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), } } /// The type size for blocks in bytes. @@ -212,7 +229,7 @@ impl GgmlDType { use k_quants::*; match self { Self::F32 => 4, - Self::F16 => 2, + Self::F16 | Self::BF16 => 2, Self::Q4_0 => std::mem::size_of::(), Self::Q4_1 => std::mem::size_of::(), Self::Q5_0 => std::mem::size_of::(), @@ -233,7 +250,7 @@ impl GgmlDType { pub fn block_size(&self) -> usize { match self { Self::F32 => 1, - Self::F16 => 1, + Self::F16 | Self::BF16 => 1, Self::Q4_0 => k_quants::QK4_0, Self::Q4_1 => k_quants::QK4_1, Self::Q5_0 => k_quants::QK5_0, @@ -341,6 +358,34 @@ impl QTensor { }) } + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_onto(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + pub fn dtype(&self) -> GgmlDType { self.storage.dtype() } @@ -421,7 +466,7 @@ thread_local! 
{ impl QMatMul { pub fn from_arc(qtensor: std::sync::Arc) -> Result { let dequantize = match qtensor.dtype() { - GgmlDType::F32 | GgmlDType::F16 => true, + GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true, _ => DEQUANTIZE_ALL.with(|b| *b), }; let t = if dequantize { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 5ea1f192b3..52df166313 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,4 +1,5 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use float8::F8E4M3; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; @@ -11,10 +12,13 @@ impl From for st::Dtype { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::I64 => st::Dtype::I64, + DType::I16 => st::Dtype::I16, + DType::I32 => st::Dtype::I32, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, } } } @@ -30,6 +34,7 @@ impl TryFrom for DType { st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -187,11 +192,14 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I16 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), } } } @@ -204,10 +212,8 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), - st::Dtype::I32 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } + st::Dtype::I16 => convert_::(view, device), + st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), @@ -223,11 +229,14 @@ fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 614a37fe65..c7236e7f5f 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -65,11 +65,14 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => 
self.asort(vs, layout), crate::CpuStorage::F16(vs) => self.asort(vs, layout), crate::CpuStorage::F32(vs) => self.asort(vs, layout), crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), }; let sort_indexes = crate::CpuStorage::U32(sort_indexes); Ok((sort_indexes, layout.shape().into())) @@ -149,6 +152,8 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", DType::I64 => "asort_asc_i64", + DType::I32 => "asort_asc_i32", + DType::I16 => "asort_asc_i16", } } else { match storage.dtype() { @@ -159,6 +164,8 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", DType::I64 => "asort_desc_i64", + DType::I32 => "asort_desc_i32", + DType::I16 => "asort_desc_i16", } } }; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 8a0637e304..8ff1cbf82a 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -705,26 +705,62 @@ impl Storage { } } - pub(crate) fn matmul( + #[allow(clippy::too_many_arguments)] + pub(crate) fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_layout: &Layout, + rhs_layout: &Layout, + c_layout: &Layout, + ) -> Result<()> { + self.same_device(rhs, "matmul_with_alpha_beta")?; + self.same_dtype(rhs, "matmul_with_alpha_beta")?; + self.same_device(c, "matmul_with_alpha_beta")?; + self.same_dtype(c, "matmul_with_alpha_beta")?; + match (self, rhs, c) { + (Self::Cpu(lhs), Self::Cpu(rhs), Self::Cpu(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Cuda(lhs), Self::Cuda(rhs), Self::Cuda(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Metal(lhs), Self::Metal(rhs), Self::Metal(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (lhs, rhs, c) => Err(Error::DeviceMismatchBinaryOp3 { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + c: c.device().location(), + op: "matmul_with_alpha_beta", + } + .bt()), + } + } + + pub(crate) fn matmul_with_alpha( &self, rhs: &Self, + s: Option, bmnk: (usize, usize, usize, usize), lhs_layout: &Layout, rhs_layout: &Layout, ) -> Result { - self.same_device(rhs, "matmul")?; - self.same_dtype(rhs, "matmul")?; + self.same_device(rhs, "matmul_with_alpha")?; + self.same_dtype(rhs, "matmul_with_alpha")?; match (self, rhs) { (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (Self::Metal(lhs), Self::Metal(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Metal(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 7dd24abf9b..37f23dca27 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -176,6 +176,22 @@ pub(crate) fn from_storage>( Tensor(Arc::new(tensor_)) } +/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. 
This has a BackpropOp:none(). +pub fn from_storage_no_op>(storage: Storage, shape: S, is_variable: bool) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: Layout::contiguous(shape), + op: BackpropOp::none(), + is_variable, + dtype, + device, + }; + Tensor(Arc::new(tensor_)) +} + impl Tensor { pub(crate) fn ones_impl>( shape: S, @@ -256,6 +272,51 @@ impl Tensor { Tensor::zeros(self.shape(), self.dtype(), self.device()) } + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. + pub(crate) unsafe fn empty_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.alloc_uninit(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with uninitialized memory. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? }; + /// // a == b + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty>(shape: S, dtype: DType, device: &Device) -> Result { + Self::empty_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other + /// tensor. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = unsafe { a.empty_like()? }; + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty_like(&self) -> Result { + Tensor::empty(self.shape(), self.dtype(), self.device()) + } + pub(crate) fn rand_impl, T: crate::FloatDType>( lo: T, up: T, @@ -1269,8 +1330,9 @@ impl Tensor { .bt())? } - let storage = self.storage().matmul( + let storage = self.storage().matmul_with_alpha( &rhs.storage(), + None, (batching, m, n, k), self.layout(), rhs.layout(), @@ -1301,6 +1363,172 @@ impl Tensor { } } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled + /// and then added to the output tensor, the bias tensor `c`. + /// + /// If `scale` is None, then the output is as follows: + /// `c := c + axb` + /// + /// Else: + /// `c := c + scale * (axb)` + /// + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. However, this also means + /// there is an allocation saved as the output is in `c`. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `c` - A tensor with dimensions `b1, b2, ..., bi, m, n`, into which the result is accumulated and added to. 
+ /// * `scale` - Factor to multiply `self` x `rhs` by + pub fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + scale: Option, + ) -> Result<()> { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let exp_c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if exp_c_shape.elem_count() == 0 || k == 0 { + bail!("Expected `c` to have more than one element, got 0."); + } + if exp_c_shape != c.shape().clone() { + Err(Error::UnexpectedShape { + msg: "`c` has an unexpected shape.".to_string(), + expected: exp_c_shape, + got: c.shape().clone(), + })? + } + + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul_with_alpha_beta", + } + .bt())? + } + + self.storage().matmul_with_alpha_beta( + &rhs.storage(), + &mut c.storage_mut(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + c.layout(), + ) + } + + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled. + /// + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. + /// + /// The output is as follows: + /// `scale * (axb)` + /// + /// + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `scale` - Factor to multiply `self` x `rhs` by. + pub fn matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if c_shape.elem_count() == 0 || k == 0 { + return Tensor::zeros(c_shape, self.dtype(), self.device()); + } + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul_with_alpha", + } + .bt())? + } + + let storage = self.storage().matmul_with_alpha( + &rhs.storage(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, rhs, Op::Matmul); + Ok(from_storage(storage, c_shape, op, false)) + } + + /// Matrix-multiplication with broadcasting support and fused scaling. + /// + /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as + /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has + /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`. 
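A brief usage sketch for the fused-scaling matmul entry points added in this hunk (editorial note, not part of the patch; the shapes and the `0.5` scale are illustrative):

use candle_core::{DType, Device, Tensor};
// (inside a function returning candle_core::Result<()>)
let a = Tensor::randn(0f32, 1f32, (2, 3, 4), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1f32, (2, 4, 5), &Device::Cpu)?;
// `scale * (a x b)` computed in a single GEMM call, no gradient tracking on the scaling.
let y = a.matmul_with_alpha(&b, Some(0.5))?;
// `c := c + scale * (a x b)`, accumulated in place into an existing output buffer.
let mut c = Tensor::zeros((2, 3, 5), DType::F32, &Device::Cpu)?;
a.matmul_with_alpha_beta(&b, &mut c, Some(0.5))?;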
+ pub fn broadcast_matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let lhs = self; + let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?; + let l_broadcast = l_shape != *lhs.shape(); + let r_broadcast = r_shape != *rhs.shape(); + // TODO: Avoid concretising the broadcasted matrixes via contiguous. + match (l_broadcast, r_broadcast) { + (true, true) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale), + (false, true) => { + lhs.matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale) + } + (true, false) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(rhs, scale), + (false, false) => lhs.matmul_with_alpha(rhs, scale), + } + } + /// Returns a tensor with the same shape as the input tensor, the values are taken from /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the /// input tensor is equal to zero. @@ -1349,244 +1577,6 @@ impl Tensor { self.index_select(ids, 0) } - pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "scatter-add")?; - let source_dims = source.dims(); - let self_dims = self.dims(); - let mismatch = if source_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (self, src)", - lhs: self.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - if indexes.dims() != source.dims() { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (indexes, src)", - lhs: indexes.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - let storage = self.storage().scatter_add( - self.layout(), - &indexes.storage(), - indexes.layout(), - &source.storage(), - source.layout(), - dim, - )?; - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { - Op::ScatterAdd(t1, t2, t3, dim) - }); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. - pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { - let dim = dim.to_index(self.shape(), "slice-scatter")?; - if dim == 0 { - self.slice_scatter0(src, start) - } else { - // TODO: Maybe we want to add a more efficient implementation at some point. - self.transpose(0, dim)? - .slice_scatter0(&src.transpose(0, dim)?, start)? - .transpose(0, dim) - } - } - - /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. - pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result { - if self.dtype() != src.dtype() { - Err(Error::DTypeMismatchBinaryOp { - lhs: self.dtype(), - rhs: src.dtype(), - op: "slice-scatter", - } - .bt())? - } - if self.device().location() != src.device.location() { - Err(Error::DeviceMismatchBinaryOp { - lhs: self.device().location(), - rhs: src.device().location(), - op: "slice-scatter", - } - .bt())? - } - if self.rank() != src.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: self.rank(), - got: src.rank(), - shape: src.shape().clone(), - } - .bt())? 
- } - let shape_ok = - self.dims() - .iter() - .zip(src.dims().iter()) - .enumerate() - .all(|(dim_idx, (&d1, &d2))| { - if 0 == dim_idx { - d2 + start <= d1 - } else { - d1 == d2 - } - }); - if !shape_ok { - Err(Error::ShapeMismatchBinaryOp { - op: "slice-scatter (self, src)", - lhs: self.shape().clone(), - rhs: src.shape().clone(), - } - .bt())? - } - let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; - self.storage() - .copy_strided_src(&mut storage, 0, self.layout())?; - let offset = start * src.dims()[1..].iter().product::(); - src.storage() - .copy_strided_src(&mut storage, offset, src.layout())?; - let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Accumulate element from `source` at indexes `indexes` and add them to `self`. - pub fn index_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "index-add")?; - let source_dims = source.dims(); - let self_dims = self.dims(); - let mismatch = if source_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "index-add (self, source)", - lhs: self.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - // The number of element in indexes must match the dimension on which the add is - // performed on the source tensor (and the index values from `indexes` are taken from - // the target tensor self) - let indexes_len = indexes.dims1()?; - if source_dims[dim] != indexes_len { - Err(Error::ShapeMismatchBinaryOp { - op: "index-add (ids, source))", - lhs: indexes.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - let storage = self.storage().index_add( - self.layout(), - &indexes.storage(), - indexes.layout(), - &source.storage(), - source.layout(), - dim, - )?; - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { - Op::IndexAdd(t1, t2, t3, dim) - }); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Gather values across the target dimension. - /// - /// # Arguments - /// - /// * `self` - The input tensor. - /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` - /// but can have a different number of elements on the target dimension. - /// * `dim` - the target dimension. - /// - /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on - /// dimension `dim` by the values in `indexes`. - pub fn gather(&self, indexes: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "gather")?; - let self_dims = self.dims(); - let indexes_dims = indexes.dims(); - let mismatch = if indexes_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "gather", - lhs: self.shape().clone(), - rhs: indexes.shape().clone(), - } - .bt())? 
- } - let storage = - self.storage() - .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); - Ok(from_storage(storage, indexes.shape(), op, false)) - } - - /// Select values for the input tensor at the target indexes across the specified dimension. - /// - /// The `indexes` is argument is an int tensor with a single dimension. - /// The output has the same number of dimension as the `self` input. The target dimension of - /// the output has length the length of `indexes` and the values are taken from `self` using - /// the index from `indexes`. Other dimensions have the same number of elements as the input - /// tensor. - pub fn index_select(&self, indexes: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "index-select")?; - let indexes_len = match indexes.dims() { - [l] => *l, - _ => Err(Error::ShapeMismatchBinaryOp { - lhs: self.shape().clone(), - rhs: indexes.shape().clone(), - op: "index-select", - } - .bt())?, - }; - let storage = self.storage().index_select( - &indexes.storage(), - self.layout(), - indexes.layout(), - dim, - )?; - let mut dims = self.dims().to_vec(); - dims[dim] = indexes_len; - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); - Ok(from_storage(storage, dims, op, false)) - } - /// Returns an iterator over position of the elements in the storage when ranging over the /// index tuples in lexicographic order. pub fn strided_index(&self) -> crate::StridedIndex { @@ -2461,62 +2451,6 @@ impl Tensor { } } - /// Returns a copy of `self` where the values within `ranges` have been replaced with the - /// content of `src`. - pub fn slice_assign>( - &self, - ranges: &[D], - src: &Tensor, - ) -> Result { - let src_dims = src.dims(); - let self_dims = self.dims(); - if self_dims.len() != src_dims.len() { - bail!( - "slice-assign requires input with the same rank {} <> {}", - self_dims.len(), - src_dims.len() - ) - } - if self_dims.len() != ranges.len() { - bail!( - "slice-assign requires input with the same rank as there are ranges {} <> {}", - self_dims.len(), - ranges.len() - ) - } - let mut src = src.clone(); - let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; - for (i, range) in ranges.iter().enumerate() { - let start_included = match range.start_bound() { - std::ops::Bound::Unbounded => 0, - std::ops::Bound::Included(v) => *v, - std::ops::Bound::Excluded(v) => *v + 1, - }; - let end_excluded = match range.end_bound() { - std::ops::Bound::Unbounded => self_dims[i], - std::ops::Bound::Included(v) => *v + 1, - std::ops::Bound::Excluded(v) => *v, - }; - if end_excluded <= start_included { - bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") - } - if self_dims[i] < end_excluded { - bail!( - "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", - self_dims[i] - ) - } - if end_excluded - start_included != src_dims[i] { - bail!( - "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] - ) - } - src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; - mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? - } - mask.where_cond(/* on_true= */ &src, /* on_false= */ self) - } - /// Returns log(sum(exp(tensor), dim)). 
pub fn log_sum_exp(&self, sum_dims: D) -> Result { let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?; @@ -2543,6 +2477,49 @@ impl Tensor { pub fn broadcast_pow(&self, rhs: &Tensor) -> Result { rhs.broadcast_mul(&self.log()?)?.exp() } + + /// Returns a view which contains all slices of size `size` from the `self` tensor in the dimension + /// `dim`, stepped by `step`. + pub fn unfold(&self, dim: D, size: usize, step: usize) -> Result { + // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804 + let mut sizes = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); + + let dim = dim.to_index(self.shape(), "unfold")?; + + let max_len = if self.dims().is_empty() { + 1 + } else { + sizes[dim] + }; + if size > max_len { + bail!( + "unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}" + ) + } + sizes.push(size); + strides.push(if self.dims().is_empty() { + 1 + } else { + strides[dim] + }); + + if !self.dims().is_empty() { + sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize; + strides[dim] *= step; + } + + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(sizes.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } } macro_rules! bin_trait { diff --git a/candle-core/src/tensor_indexing.rs b/candle-core/src/tensor_indexing.rs new file mode 100644 index 0000000000..140876456b --- /dev/null +++ b/candle-core/src/tensor_indexing.rs @@ -0,0 +1,379 @@ +use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::{ + bail, + op::{BackpropOp, Op}, + shape::Dim, + tensor::from_storage, + DType, Error, Result, Tensor, +}; + +/// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects. +pub trait RangeBound { + fn start_bound(&self) -> std::ops::Bound; + fn end_bound(&self) -> std::ops::Bound; +} + +macro_rules! range_bound { + ($name:ident) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; + // Use the marker to designate no generics + ($name:ident, $marker:expr) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; + // Use the marker to designate no generics + ($name:ty) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; +} + +range_bound!(Range); +range_bound!(RangeFrom); +range_bound!(RangeFull, ()); +range_bound!(RangeInclusive); +range_bound!(RangeTo); +range_bound!(RangeToInclusive); +range_bound!((std::ops::Bound, std::ops::Bound)); + +impl RangeBound for usize { + fn end_bound(&self) -> std::ops::Bound { + std::ops::Bound::Excluded(self + 1) + } + fn start_bound(&self) -> std::ops::Bound { + std::ops::Bound::Included(*self) + } +} + +impl Tensor { + /// Returns a copy of `self` where the values within `ranges` have been replaced with the + /// content of `src`. This is analogous to slice assignment in `torch`.
+ /// + /// # Example + /// ```rust + /// use candle_core::{Device, Tensor}; + /// + /// let dev = Device::Cpu; + /// let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + /// let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + /// let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + /// assert_eq!( + /// out.to_vec2::()?, + /// &[ + /// [0, 1, 2, 100, 101], + /// [5, 6, 7, 102, 103], + /// [10, 11, 12, 104, 105], + /// [15, 16, 17, 18, 19] + /// ] + /// ); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn slice_assign(&self, ranges: &[&dyn RangeBound], src: &Tensor) -> Result { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => v, + std::ops::Bound::Excluded(v) => v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => v + 1, + std::ops::Bound::Excluded(v) => v, + }; + if end_excluded <= start_included { + bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") + } + if self_dims[i] < end_excluded { + bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } + + pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "scatter-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (self, src)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + if indexes.dims() != source.dims() { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (indexes, src)", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().scatter_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::ScatterAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. 
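+    /// `src` must have the same rank and dtype as `self`; along `dim` it must fit starting at
+    /// offset `start` (`start + src_dim <= self_dim`), and it must match `self` exactly on every
+    /// other dimension.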
+ pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { + let dim = dim.to_index(self.shape(), "slice-scatter")?; + if dim == 0 { + self.slice_scatter0(src, start) + } else { + // TODO: Maybe we want to add a more efficient implementation at some point. + self.transpose(0, dim)? + .slice_scatter0(&src.transpose(0, dim)?, start)? + .transpose(0, dim) + } + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. + pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result { + if self.dtype() != src.dtype() { + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: src.dtype(), + op: "slice-scatter", + } + .bt())? + } + if self.device().location() != src.device().location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: self.device().location(), + rhs: src.device().location(), + op: "slice-scatter", + } + .bt())? + } + if self.rank() != src.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: self.rank(), + got: src.rank(), + shape: src.shape().clone(), + } + .bt())? + } + let shape_ok = + self.dims() + .iter() + .zip(src.dims().iter()) + .enumerate() + .all(|(dim_idx, (&d1, &d2))| { + if 0 == dim_idx { + d2 + start <= d1 + } else { + d1 == d2 + } + }); + if !shape_ok { + Err(Error::ShapeMismatchBinaryOp { + op: "slice-scatter (self, src)", + lhs: self.shape().clone(), + rhs: src.shape().clone(), + } + .bt())? + } + let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let offset = start * src.dims()[1..].iter().product::(); + src.storage() + .copy_strided_src(&mut storage, offset, src.layout())?; + let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Accumulate element from `source` at indexes `indexes` and add them to `self`. + pub fn index_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (self, source)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + // The number of element in indexes must match the dimension on which the add is + // performed on the source tensor (and the index values from `indexes` are taken from + // the target tensor self) + let indexes_len = indexes.dims1()?; + if source_dims[dim] != indexes_len { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (ids, source))", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().index_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::IndexAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Gather values across the target dimension. + /// + /// # Arguments + /// + /// * `self` - The input tensor. 
+ /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` + /// but can have a different number of elements on the target dimension. + /// * `dim` - the target dimension. + /// + /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on + /// dimension `dim` by the values in `indexes`. + pub fn gather(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "gather")?; + let self_dims = self.dims(); + let indexes_dims = indexes.dims(); + let mismatch = if indexes_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "gather", + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + } + .bt())? + } + let storage = + self.storage() + .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); + Ok(from_storage(storage, indexes.shape(), op, false)) + } + + /// Select values for the input tensor at the target indexes across the specified dimension. + /// + /// The `indexes` is argument is an int tensor with a single dimension. + /// The output has the same number of dimension as the `self` input. The target dimension of + /// the output has length the length of `indexes` and the values are taken from `self` using + /// the index from `indexes`. Other dimensions have the same number of elements as the input + /// tensor. + pub fn index_select(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-select")?; + let indexes_len = match indexes.dims() { + [l] => *l, + _ => Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + op: "index-select", + } + .bt())?, + }; + let storage = self.storage().index_select( + &indexes.storage(), + self.layout(), + indexes.layout(), + dim, + )?; + let mut dims = self.dims().to_vec(); + dims[dim] = indexes_len; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); + Ok(from_storage(storage, dims, op, false)) + } +} diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs index 047205a31f..417d54a41f 100644 --- a/candle-core/tests/indexing_tests.rs +++ b/candle-core/tests/indexing_tests.rs @@ -93,28 +93,123 @@ fn index_3d() -> Result<()> { } #[test] -fn slice_assign() -> Result<()> { +fn slice_assign_range() -> Result<()> { let dev = Device::Cpu; let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; - let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?; - let out = tensor.slice_assign(&[1..4, 3..5], &src)?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..4), &(3..5)], &src)?; assert_eq!( out.to_vec2::()?, &[ [0, 1, 2, 3, 4], - [5, 6, 7, 0, 1], - [10, 11, 12, 2, 3], - [15, 16, 17, 4, 5] + [5, 6, 7, 100, 101], + [10, 11, 12, 102, 103], + [15, 16, 17, 104, 105] ] ); - let out = tensor.slice_assign(&[0..3, 0..2], &src)?; + let out = tensor.slice_assign(&[&(0..3), &(0..2)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [100, 101, 2, 3, 4], + [102, 103, 7, 8, 9], + [104, 105, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to() -> Result<()> { + let dev = Device::Cpu; + + let tensor 
= Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_from() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..), &(0..2)], &src)?; assert_eq!( out.to_vec2::()?, &[ [0, 1, 2, 3, 4], - [2, 3, 7, 8, 9], - [4, 5, 12, 13, 14], + [100, 101, 7, 8, 9], + [102, 103, 12, 13, 14], + [104, 105, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to_incl() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..=2), &(1..3)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 100, 101, 3, 4], + [5, 102, 103, 8, 9], + [10, 104, 105, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_full() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 4) + 100, &dev)?.reshape((4, 2))?; + let out = tensor.slice_assign(&[&(..), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 106, 107] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_exact() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, 2 + 100, &dev)?.reshape((1, 2))?; + let out = tensor.slice_assign(&[&0, &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], [15, 16, 17, 18, 19] ] ); diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index c1c16401a8..edca8e1561 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -109,7 +109,53 @@ fn mm_layout(device: &Device) -> Result<()> { Ok(()) } +fn matmul_alpha_beta(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_alpha_beta(&b, &mut c, None)?; + assert_eq!(c.to_vec2::()?, &[[8.0f32, 11.0], [16.0, 23.0]]); + + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?; + assert_eq!(c.to_vec2::()?, &[[15.0f32, 21.0], [31.0, 45.0]]); + Ok(()) +} + +fn matmul_alpha(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul_with_alpha(&b, Some(2.))?; + 
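+    // matmul_with_alpha computes alpha * (a @ b); with a = b = [[1, 2], [3, 4]] we have
+    // a @ b = [[7, 10], [15, 22]], so alpha = 2 gives [[14, 20], [30, 44]].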
assert_eq!(c.to_vec2::()?, &[[14.0f32, 20.0], [30.0, 44.0]]); + Ok(()) +} + test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + matmul_alpha_beta, + matmul_alpha_beta_cpu, + matmul_alpha_beta_gpu, + matmul_alpha_beta_metal +); +test_device!( + matmul_alpha, + matmul_alpha_cpu, + matmul_alpha_gpu, + matmul_alpha_metal +); test_device!( matmul_bf16, matmul_bf16_cpu, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e0cea15c61..ce7ae14720 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -17,6 +17,14 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::U32, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], ); + assert_eq!( + Tensor::ones((2, 3), DType::I16, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::I32, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); assert_eq!( Tensor::ones((2, 3), DType::I64, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], @@ -848,7 +856,7 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - for dtype in [DType::U8, DType::U32, DType::I64] { + for dtype in [DType::U8, DType::U32, DType::I16, DType::I32, DType::I64] { let ids = ids.to_dtype(dtype)?; let hs = t.index_select(&ids, 1)?; assert_eq!( @@ -1406,3 +1414,15 @@ fn pow() -> Result<()> { ); Ok(()) } + +#[test] +fn unfold() -> Result<()> { + let x = Tensor::arange(0i64, 3 * 2, &Device::Cpu)?.reshape((3, 2))?; + let unfolded = x.unfold(0, 2, 1)?; + dbg!(&unfolded); + assert_eq!( + unfolded.to_vec3::()?, + vec![[[0i64, 2], [1, 3]], [[2, 4], [3, 5]]] + ); + Ok(()) +} diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs index 4a0a345d17..b8fa01a51a 100644 --- a/candle-examples/examples/mamba-minimal/model.rs +++ b/candle-examples/examples/mamba-minimal/model.rs @@ -2,7 +2,7 @@ /// https://github.com/johnma2006/mamba-minimal/blob/master/model.py /// Simple, minimal implementation of Mamba in one file of PyTorch. 
use candle::{IndexOp, Module, Result, Tensor, D}; -use candle_nn::{RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, RmsNorm, VarBuilder}; use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear}; @@ -144,12 +144,12 @@ impl Module for MambaBlock { #[derive(Clone, Debug)] pub struct ResidualBlock { mixer: MambaBlock, - norm: RmsNorm, + norm: RmsNorm, } impl ResidualBlock { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let norm = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm"))?; let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?; Ok(Self { mixer, norm }) } @@ -166,7 +166,7 @@ impl Module for ResidualBlock { pub struct Model { embedding: candle_nn::Embedding, layers: Vec, - norm_f: RmsNorm, + norm_f: RmsNorm, lm_head: Linear, } @@ -179,7 +179,7 @@ impl Model { let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let norm_f = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm_f"))?; let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); Ok(Self { embedding, diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs index d9615c43b8..d505fc7c48 100644 --- a/candle-examples/examples/mobileclip/main.rs +++ b/candle-examples/examples/mobileclip/main.rs @@ -60,6 +60,7 @@ fn load_images>( image_size: usize, ) -> anyhow::Result { let mut images = vec![]; + for path in paths { let tensor = candle_examples::imagenet::load_image_with_std_mean( path, @@ -69,7 +70,9 @@ fn load_images>( )?; images.push(tensor); } + let images = Tensor::stack(&images, 0)?; + Ok(images) } @@ -77,17 +80,24 @@ pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let model_name = args.which.model_name(); + let api = hf_hub::api::sync::Api::new()?; let api = api.model(model_name); + let model_file = if args.use_pth { api.get("open_clip_pytorch_model.bin")? } else { api.get("open_clip_model.safetensors")? }; + let tokenizer = api.get("tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let config = &args.which.config(); + let device = candle_examples::device(args.cpu)?; + let vec_imgs = match args.images { Some(imgs) => imgs, None => vec![ @@ -95,7 +105,9 @@ pub fn main() -> anyhow::Result<()> { "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), ], }; + let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?; + let vb = if args.use_pth { VarBuilder::from_pth(&model_file, DType::F32, &device)? 
} else { @@ -103,15 +115,22 @@ pub fn main() -> anyhow::Result<()> { }; let model = mobileclip::MobileClipModel::new(vb, config)?; + let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?; + let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; + let softmax_image = softmax(&logits_per_image, 1)?; + let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; + println!("softmax_image_vec: {:?}", softmax_image_vec); + let probability_vec = softmax_image_vec .iter() .map(|v| v * 100.0) .collect::>(); + let probability_per_image = probability_vec.len() / vec_imgs.len(); for (i, img) in vec_imgs.iter().enumerate() { @@ -152,6 +171,7 @@ pub fn tokenize_sequences( }; let mut tokens = vec![]; + for seq in vec_seq.clone() { let encoding = tokenizer.encode(seq, true).map_err(E::msg)?; tokens.push(encoding.get_ids().to_vec()); @@ -165,6 +185,8 @@ pub fn tokenize_sequences( token_vec.extend(vec![pad_id; len_diff]); } } + let input_ids = Tensor::new(tokens, device)?; + Ok((input_ids, vec_seq)) } diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index f567ce2d36..9ab024c20f 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -15,7 +15,6 @@ use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_examples::token_output_stream::TokenOutputStream; use candle_transformers::models::quantized_llama::ModelWeights as Phi3b; use candle_transformers::models::quantized_phi::ModelWeights as Phi2; -use candle_transformers::models::quantized_phi3::ModelWeights as Phi3; const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; @@ -23,8 +22,6 @@ const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. " enum Which { #[value(name = "phi-2")] Phi2, - #[value(name = "phi-3")] - Phi3, /// Alternative implementation of phi-3, based on llama. #[value(name = "phi-3b")] Phi3b, @@ -103,7 +100,7 @@ impl Args { let api = hf_hub::api::sync::Api::new()?; let repo = match self.which { Which::Phi2 => "microsoft/phi-2", - Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct", + Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct", }; let api = api.model(repo.to_string()); api.get("tokenizer.json")? 
@@ -118,11 +115,6 @@ impl Args { None => { let (repo, filename, revision) = match self.which { Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"), - Which::Phi3 => ( - "microsoft/Phi-3-mini-4k-instruct-gguf", - "Phi-3-mini-4k-instruct-q4.gguf", - "main", - ), Which::Phi3b => ( "microsoft/Phi-3-mini-4k-instruct-gguf", "Phi-3-mini-4k-instruct-q4.gguf", @@ -156,7 +148,6 @@ fn format_size(size_in_bytes: usize) -> String { enum Model { Phi2(Phi2), - Phi3(Phi3), Phi3b(Phi3b), } @@ -164,7 +155,6 @@ impl Model { fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result { match self { Self::Phi2(m) => m.forward(xs, pos), - Self::Phi3(m) => m.forward(xs, pos), Self::Phi3b(m) => m.forward(xs, pos), } } @@ -216,12 +206,6 @@ fn main() -> anyhow::Result<()> { ); match args.which { Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?), - Which::Phi3 => Model::Phi3(Phi3::from_gguf( - args.use_flash_attn, - model, - &mut file, - &device, - )?), Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?), } }; diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 53fec5deab..7edd81fb48 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -4,6 +4,8 @@ use anyhow::{Context, Result}; use std::path::PathBuf; +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + const KERNEL_FILES: [&str; 33] = [ "kernels/flash_api.cu", "kernels/flash_fwd_hdim128_fp16_sm80.cu", @@ -72,7 +74,7 @@ fn main() -> Result<()> { }; let kernels = KERNEL_FILES.iter().collect(); - let builder = bindgen_cuda::Builder::default() + let mut builder = bindgen_cuda::Builder::default() .kernel_paths(kernels) .out_dir(build_dir.clone()) .arg("-std=c++17") @@ -87,13 +89,30 @@ fn main() -> Result<()> { .arg("--use_fast_math") .arg("--verbose"); + // https://github.com/EricLBuehler/mistral.rs/issues/286 + // https://github.com/huggingface/candle-flash-attn-v1/pull/2 + if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + builder = builder.arg("--compiler-options"); + builder = builder.arg(cuda_nvcc_flags_env); + } + let out_file = build_dir.join("libflashattention.a"); builder.build_lib(out_file); println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=dylib=cudart"); - println!("cargo:rustc-link-lib=dylib=stdc++"); + // https://github.com/denoland/rusty_v8/blob/20b2989186d1ecdf4c291d0706ff9eb1baaf2cfd/build.rs#L602 + let target = std::env::var("TARGET").unwrap(); + if target.contains("msvc") { + // nothing to link to + } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if target.contains("android") { + println!("cargo:rustc-link-lib=dylib=c++_shared"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } Ok(()) } diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 4ca41b0a16..ca5f2b255d 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -45,6 +45,7 @@ extern "C" void run_mha( uint32_t d, uint32_t d_rounded, float softmax_scale, + float softcap, uint32_t seqlen_q, uint32_t seqlen_k, @@ -99,8 +100,16 @@ extern "C" void run_mha( params.d_rounded = d_rounded; // Set the different scale values. 
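+    // Softcapping rescales the attention logits to softcap * tanh(softmax_scale * qk / softcap)
+    // (mirroring the fa_acausal reference in the tests), so the scale factors below are folded accordingly.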
- params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + }else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } params.p_dropout = 1.; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 9e5449d736..29918c87c9 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -172,7 +172,11 @@ template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } }); } diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index ca65520be5..fe565beae6 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -34,6 +34,7 @@ extern "C" { d: u32, d_rounded: u32, softmax_scale: f32, + softcap: f32, seqlen_q: u32, seqlen_k: u32, diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index f171a9868f..5d991f0075 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -8,6 +8,7 @@ use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, + pub softcap: Option, pub alibi_slopes: Option, pub window_size_left: Option, pub window_size_right: Option, @@ -193,6 +194,7 @@ impl FlashAttn { /* d */ head_size as u32, /* d_rounded */ head_size_rounded as u32, /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), /* seqlen_q */ seqlen_q as u32, /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, @@ -262,12 +264,25 @@ pub fn flash_attn( v: &Tensor, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_softcap(q, k, v, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn`], but with softcap support +pub fn flash_attn_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: None, window_size_left, window_size_right, @@ -302,9 +317,31 @@ pub fn flash_attn_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_windowed_softcap( + q, + k, + v, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_windowed`], but with softcap support. 
+pub fn flash_attn_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: None, window_size_left, window_size_right, @@ -333,12 +370,26 @@ pub fn flash_attn_alibi( alibi_slopes: &Tensor, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_alibi_softcap(q, k, v, alibi_slopes, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn_alibi`], but with softcap support. +pub fn flash_attn_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: Some(alibi_slopes.clone()), window_size_left, window_size_right, @@ -378,6 +429,7 @@ pub fn flash_attn_alibi_windowed( ) -> Result { let op = FlashAttn { softmax_scale, + softcap: None, alibi_slopes: Some(alibi_slopes.clone()), window_size_left, window_size_right, @@ -387,6 +439,7 @@ pub fn flash_attn_alibi_windowed( struct FlashAttnVarLen { pub softmax_scale: f32, + pub softcap: Option, pub max_seqlen_q: usize, pub max_seqlen_k: usize, pub seqlens_q: Tensor, @@ -434,9 +487,9 @@ impl FlashAttnVarLen { None => candle::bail!("seqlens_k has to be contiguous"), }; - let q = q.as_cuda_slice::()?; - let k = k.as_cuda_slice::()?; - let v = v.as_cuda_slice::()?; + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; let q = q.slice(q_l.start_offset()..); let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -548,7 +601,7 @@ impl FlashAttnVarLen { let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let dst = unsafe { dev.alloc::(elem_count) }.w()?; let softmax_lse = dev .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) .w()?; @@ -605,6 +658,7 @@ impl FlashAttnVarLen { /* d */ head_size as u32, /* d_rounded */ head_size_rounded as u32, /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), /* seqlen_q */ self.max_seqlen_q as u32, /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, @@ -686,12 +740,40 @@ pub fn flash_attn_varlen( max_seqlen_k: usize, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_varlen_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen`], but with softcap support. 
+pub fn flash_attn_varlen_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -742,9 +824,39 @@ pub fn flash_attn_varlen_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_varlen_windowed_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_windowed`], but with softcap support. +pub fn flash_attn_varlen_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -789,12 +901,42 @@ pub fn flash_attn_varlen_alibi( max_seqlen_k: usize, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_varlen_alibi_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi`], but with softcap support +pub fn flash_attn_varlen_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -847,9 +989,41 @@ pub fn flash_attn_varlen_alibi_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_varlen_alibi_windowed_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi_windowed`], but with softcap support. 
+pub fn flash_attn_varlen_alibi_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs index 250added04..fd51152ee6 100644 --- a/candle-flash-attn/tests/flash_attn_tests.rs +++ b/candle-flash-attn/tests/flash_attn_tests.rs @@ -15,12 +15,23 @@ fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { Ok(t) } -fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result { +fn fa_acausal( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, +) -> Result { let in_dtype = q.dtype(); let q = q.to_dtype(DType::F32)?; let k = k.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?; - let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + let mut att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + if let Some(softcap) = softcap { + att = (att / softcap as f64)?; + att = att.tanh()?; + att = (att * softcap as f64)?; + } let att = candle_nn::ops::softmax(&att, D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?; @@ -37,7 +48,7 @@ fn flash_attn_acausal() -> Result<()> { let v = (&q / 50.)?; let q = (&q / 30.)?; - let ys1 = fa_acausal(&q, &k, &v, 0.5)?; + let ys1 = fa_acausal(&q, &k, &v, 0.5, None)?; let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; let ys2 = { let q = q.transpose(1, 2)?; @@ -133,3 +144,84 @@ fn flash_attn_varlen() -> Result<()> { ); Ok(()) } + +#[test] +fn flash_attn_acausal_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let ys1 = fa_acausal(&q, &k, &v, 0.5, Some(30.))?; + let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + let ys2 = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + candle_flash_attn::flash_attn_softcap(&q, &k, &v, 0.5, Some(30.), false)?.transpose(1, 2)? + }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + + assert_eq!(ys1.dims(), &[3, 2, 8]); + assert_eq!(ys2.dims(), &[3, 2, 8]); + assert!(diff.to_vec0::()?.abs() < 1e-5); + Ok(()) +} + +#[test] +fn flash_attn_varlen_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + candle_flash_attn::flash_attn_varlen_softcap( + &q, + &k, + &v, + &seqlens_q, + &seqlens_k, + 32, + 32, + 0.5, + Some(30.), + false, + )? + .transpose(0, 1)? 
+ }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5078, 0.5278, 0.5479, 0.5679] + ], + [ + [0.7549, 0.7749, 0.7949, 0.8149, 0.835, 0.855, 0.875, 0.895], + [0.7607, 0.7808, 0.8008, 0.8208, 0.8408, 0.8608, 0.8809, 0.9009] + ] + ] + ); + Ok(()) +} diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 540d0819f5..ef75dffd36 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -1,7 +1,7 @@ #include "cuda_utils.cuh" #include -#define AFFINE_OP(TYPENAME, FN_NAME) \ +#define AFFINE_OP(TYPENAME, FN_NAME, AFFINE) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ const size_t num_dims, \ @@ -16,28 +16,34 @@ extern "C" __global__ void FN_NAME( \ if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? inp[i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ else { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ TYPENAME x = inp ? inp[strided_i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ } \ #if __CUDA_ARCH__ >= 800 -AFFINE_OP(__nv_bfloat16, affine_bf16) +AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add))) #endif #if __CUDA_ARCH__ >= 530 -AFFINE_OP(__half, affine_f16) +AFFINE_OP(__half, affine_f16, x * mul + add) #endif -AFFINE_OP(float, affine_f32) -AFFINE_OP(double, affine_f64) -AFFINE_OP(uint8_t, affine_u8) -AFFINE_OP(uint32_t, affine_u32) -AFFINE_OP(int64_t, affine_i64) +AFFINE_OP(float, affine_f32, x * mul + add) +AFFINE_OP(double, affine_f64, x * mul + add) +AFFINE_OP(uint8_t, affine_u8, x * mul + add) +AFFINE_OP(uint32_t, affine_u32, x * mul + add) +AFFINE_OP(int16_t, affine_i16, x * mul + add) +AFFINE_OP(int32_t, affine_i32, x * mul + add) +AFFINE_OP(int64_t, affine_i64, x * mul + add) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index d44e3b20ee..7bda3e463e 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -14,6 +14,21 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +BINARY_OP(__nv_fp8_e4m3, badd_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) + F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bdiv_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) / F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmul_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bsub_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) - F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmaximum_f8_e4m3, maxg(x, y)) +BINARY_OP(__nv_fp8_e4m3, bminimum_f8_e4m3, ming(x, y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, 
eq_f8_e4m3, F8E4M3_TO_FLOAT(x) == F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ne_f8_e4m3, F8E4M3_TO_FLOAT(x) != F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y)) #endif #if __CUDA_ARCH__ >= 530 @@ -35,65 +50,89 @@ BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int16_t, badd_i16, x + y); +BINARY_OP(int32_t, badd_i32, x + y); BINARY_OP(int64_t, badd_i64, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int16_t, bdiv_i16, x / y); +BINARY_OP(int32_t, bdiv_i32, x / y); BINARY_OP(int64_t, bdiv_i64, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int16_t, bmul_i16, x * y); +BINARY_OP(int32_t, bmul_i32, x * y); BINARY_OP(int64_t, bmul_i64, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int16_t, bsub_i16, x - y); +BINARY_OP(int32_t, bsub_i32, x - y); BINARY_OP(int64_t, bsub_i64, x - y); BINARY_OP(float, bminimum_f32, ming(x, y)); BINARY_OP(double, bminimum_f64, ming(x, y)); BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int16_t, bminimum_i16, ming(x, y)); +BINARY_OP(int32_t, bminimum_i32, ming(x, y)); BINARY_OP(int64_t, bminimum_i64, ming(x, y)); BINARY_OP(float, bmaximum_f32, maxg(x, y)); BINARY_OP(double, bmaximum_f64, maxg(x, y)); BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int16_t, bmaximum_i16, maxg(x, y)); +BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int16_t, uint8_t, eq_i16, x == y) +BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int16_t, uint8_t, ne_i16, x != y) +BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int16_t, uint8_t, lt_i16, x < y) +BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int16_t, uint8_t, le_i16, x <= y) +BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) BINARY_OP_OUT(int64_t, 
uint8_t, le_i64, x <= y) BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y) +BINARY_OP_OUT(int16_t, uint8_t, gt_i16, x > y) +BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y) BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y) BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y) +BINARY_OP_OUT(int16_t, uint8_t, ge_i16, x >= y) +BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y) BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 90f5e7ba48..7176825b8d 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -83,6 +83,8 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) #else #include #if CUDA_VERSION >= 11000 @@ -94,6 +96,22 @@ CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) + +CAST_OP(__nv_fp8_e4m3, float, cast_f8_e4m3_f32) +CAST_OP(float, __nv_fp8_e4m3, cast_f32_f8_e4m3) +CAST_THROUGH_OP(__nv_fp8_e4m3, uint8_t, float, cast_f8_e4m3_u8) +CAST_THROUGH_OP(__nv_fp8_e4m3, __half, float, cast_f8_e4m3_f16) +CAST_THROUGH_OP(__nv_fp8_e4m3, double, float, cast_f8_e4m3_f64) +CAST_THROUGH_OP(__half, __nv_fp8_e4m3, float, cast_f16_f8_e4m3) +CAST_THROUGH_OP(double, __nv_fp8_e4m3, float, cast_f64_f8_e4m3) +CAST_THROUGH_OP(uint8_t, __nv_fp8_e4m3, float, cast_u8_f8_e4m3) +CAST_THROUGH_OP(int32_t, __nv_fp8_e4m3, float, cast_i32_f8_e4m3) +CAST_THROUGH_OP(__nv_fp8_e4m3, int32_t, float, cast_f8_e4m3_i32) +CAST_THROUGH_OP(__nv_fp8_e4m3, __nv_bfloat16, float, cast_f8_e4m3_bf16) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) #endif #endif @@ -108,34 +126,62 @@ CAST_OP(uint8_t, __half, cast_u8_f16 ) CAST_OP(uint32_t, __half, cast_u32_f16) CAST_OP(float, __half, cast_f32_f16) CAST_OP(double, __half, cast_f64_f16) +CAST_OP(int32_t, __half, cast_i32_f16 ) +CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32) #endif CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, int64_t, cast_u32_i64 ) +CAST_OP(uint32_t, int32_t, cast_u32_i32 ) +CAST_OP(uint32_t, int16_t, cast_u32_i16 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int16_t, cast_u8_i16 ) +CAST_OP(uint8_t, int32_t, cast_u8_i32 ) CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) CAST_OP(uint8_t, double, cast_u8_f64) CAST_OP(int64_t, uint32_t, cast_i64_u32) CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int16_t, cast_i64_i16 ) +CAST_OP(int64_t, int32_t, cast_i64_i32 ) CAST_OP(int64_t, int64_t, 
cast_i64_i64 ) CAST_OP(int64_t, float, cast_i64_f32) CAST_OP(int64_t, double, cast_i64_f64) +CAST_OP(int32_t, uint32_t, cast_i32_u32) +CAST_OP(int32_t, uint8_t, cast_i32_u8 ) +CAST_OP(int32_t, int64_t, cast_i32_i64 ) +CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, int16_t, cast_i32_i16 ) +CAST_OP(int32_t, float, cast_i32_f32) +CAST_OP(int32_t, double, cast_i32_f64) + +CAST_OP(int16_t, uint32_t, cast_i16_u32) +CAST_OP(int16_t, uint8_t, cast_i16_u8 ) +CAST_OP(int16_t, int64_t, cast_i16_i64 ) +CAST_OP(int16_t, int32_t, cast_i16_i32 ) +CAST_OP(int16_t, int16_t, cast_i16_i16 ) +CAST_OP(int16_t, float, cast_i16_f32) +CAST_OP(int16_t, double, cast_i16_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int16_t, cast_f32_i16 ) +CAST_OP(float, int32_t, cast_f32_i32 ) CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int16_t, cast_f64_i16 ) +CAST_OP(double, int32_t, cast_f64_i32 ) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh index d0791749bb..1e4cf215c1 100644 --- a/candle-kernels/src/compatibility.cuh +++ b/candle-kernels/src/compatibility.cuh @@ -1,5 +1,6 @@ #include "cuda_fp16.h" #include "cuda_bf16.h" +#include "cuda_fp8.h" // Table showing which features are supported on which compute capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fa834faa3a..6ca6fd7c2b 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -702,6 +702,18 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) IM2COL_OP(__nv_bfloat16, im2col_bf16) IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) + +// NOTE: No conv ops for f8 +// CONV1D_OP(__nv_bfloat16, float, conv1d_f8_e5m) +// CONV2D_OP(__nv_fp8_e4m3, float, conv2d_f8_e5m) +// CONVT1D_OP(__nv_fp8_e4m3, float, conv_transpose1d_f8_e5m) +// CONVT2D_OP(__nv_fp8_e4m3, float, conv_transpose2d_f8_e5m) +// AVG_POOL2D_OP(__nv_fp8_e4m3, float, avg_pool2d_f8_e5m) +// MAX_POOL2D_OP(__nv_fp8_e4m3, max_pool2d_f8_e5m) +// UPSAMPLE_NEAREST2D_OP(__nv_fp8_e4m3, upsample_nearest2d_f8_e5m) +// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m) +// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m) +// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 2673b8aaf1..da8a1fe1c1 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -115,6 +115,35 @@ __device__ void chunk_sum( } } +__device__ __forceinline__ int GetBlockNum(void) { + return (gridDim.x * gridDim.y * gridDim.z); +} + +__device__ __forceinline__ int GetBlockIdx(void) { + return (blockIdx.z * (gridDim.x * gridDim.y) + blockIdx.y * gridDim.x + + blockIdx.x); +} + +__device__ __forceinline__ int GetThreadNumEachBlock(void) { + return (blockDim.x * blockDim.y * blockDim.z); +} + +__device__ __forceinline__ int GetThreadNum(void) { + return GetBlockNum() * GetThreadNumEachBlock(); +} + +__device__ __forceinline__ int GetThreadIdxInBlock(void) { + return threadIdx.z * (blockDim.x * blockDim.y) + + threadIdx.y * blockDim.x + threadIdx.x; +} + 
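+// Flattened global thread index: block index * threads per block + thread index within the block.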
+__device__ __forceinline__ int GetThreadIdx(void) { + int blockIdx = GetBlockIdx(); + int threadNumEachBlock = GetThreadNumEachBlock(); + + return blockIdx * threadNumEachBlock + GetThreadIdxInBlock(); +} + __device__ __forceinline__ bool isnang(float a) { return isnan(a); } __device__ __forceinline__ bool isnang(double a) { return isnan(a); } __device__ __forceinline__ float recipg(float a) { return 1.0 / a; } @@ -152,6 +181,10 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); } __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } +__device__ __forceinline__ int16_t ming(int16_t a, int16_t b) { return min(a, b); } +__device__ __forceinline__ int16_t maxg(int16_t a, int16_t b) { return max(a, b); } +__device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } +__device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } __device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); } __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); } @@ -198,4 +231,27 @@ __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } __device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); } __device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); } + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +__device__ __forceinline__ __nv_fp8_e4m3 powg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(powf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ bool isnang(__nv_fp8_e4m3 a) { return isnanf(F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 sqrtg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sqrtf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 cosg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(cosf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 sing(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sinf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 recipg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(1. 
/ F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 maxg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fmaxf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 tanhg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(tanhf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 erfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(erff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ceilg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(ceilf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 floorg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(floorf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 roundg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(roundf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 normcdfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(normcdff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ming(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fminf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 logg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(logf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(expf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } + + #endif diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index ca448d989f..eeea8d4cd4 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -9,6 +9,8 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { } extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i16(int16_t *buf, int16_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } @@ -34,6 +36,8 @@ COPY2D_OP(float, copy2d_f32) COPY2D_OP(double, copy2d_f64) COPY2D_OP(uint8_t, copy2d_u8) COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int16_t, copy2d_i16) +COPY2D_OP(int32_t, copy2d_i32) COPY2D_OP(int64_t, copy2d_i64) #if __CUDA_ARCH__ >= 530 @@ -43,6 +47,11 @@ COPY2D_OP(__half, copy2d_f16) #if __CUDA_ARCH__ >= 800 #include +#include + extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } COPY2D_OP(__nv_bfloat16, copy2d_bf16) + +extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3) #endif diff --git a/candle-kernels/src/fused_rms_norm.cu b/candle-kernels/src/fused_rms_norm.cu new file mode 100644 index 0000000000..f012e002ad --- /dev/null +++ 
b/candle-kernels/src/fused_rms_norm.cu @@ -0,0 +1,82 @@ +#include "cuda_fp16.h" +#include + +#define WARP_SIZE 32 + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) + val += VLLM_SHFL_XOR_SYNC(val, mask); + return val; +} + +__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) { + return warp_size - 1; +} + +__inline__ __device__ constexpr int _calculateWidShift(int warp_size) { + return 5 + (warp_size >> 6); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[WARP_SIZE]; + constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE); + constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE); + int lane = threadIdx.x & LANE_MASK; + int wid = threadIdx.x >> WID_SHIFT; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +#define RMS_NORM_OP(FN_NAME, TYPENAME)\ +extern "C" __global__ void FN_NAME(\ + TYPENAME* __restrict__ out,\ + const TYPENAME* __restrict__ input,\ + const TYPENAME* __restrict__ weight,\ + const float epsilon,\ + const int num_tokens,\ + const int hidden_size) {\ + __shared__ float s_variance;\ + float variance = 0.0f;\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + const float x = (float) input[blockIdx.x * hidden_size + idx];\ + variance += x * x;\ + }\ + variance = blockReduceSum(variance);\ + if (threadIdx.x == 0) {\ + s_variance = rsqrtf(variance / hidden_size + epsilon);\ + }\ + __syncthreads();\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + float x = (float) input[blockIdx.x * hidden_size + idx];\ + out[blockIdx.x * hidden_size + idx] = ((TYPENAME) (x * s_variance)) * weight[idx];\ + }\ +}\ + +RMS_NORM_OP(rms_norm_f32, float) +RMS_NORM_OP(rms_norm_f16, __half) + +#if __CUDA_ARCH__ >= 800 +#include +RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16) +#endif \ No newline at end of file diff --git a/candle-kernels/src/fused_rope.cu b/candle-kernels/src/fused_rope.cu new file mode 100644 index 0000000000..9f7873cca7 --- /dev/null +++ b/candle-kernels/src/fused_rope.cu @@ -0,0 +1,231 @@ +#include "cuda_fp16.h" + +#ifndef USE_ROCM + #define LDG(arg) __ldg(arg) +#else + #define LDG(arg) *arg +#endif + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = LDG(cos_ptr + x_index); + sin = LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
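+    // Adjacent (even, odd) elements form one rotation pair, so cos/sin are indexed per pair (x_index / 2).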
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = LDG(cos_ptr + x_index / 2); + sin = LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} + +extern "C" __global__ void rotary_embedding_kernel_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
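As a plain-Rust illustration (not code from this PR), the per-pair rotation applied by `apply_token_rotary_embedding` above works as follows: for rotation offset `i`, the NeoX layout pairs element `i` with element `i + embed_dim`, while the GPT-J/interleaved layout pairs `2*i` with `2*i + 1`; both use `cos[i]`/`sin[i]` and apply the same 2-D rotation. The f16 kernel continues below.

```rust
// Rotate the first `rot_dim = 2 * embed_dim` elements of one attention head.
fn rotate_head(head: &mut [f32], cos: &[f32], sin: &[f32], neox: bool) {
    let embed_dim = cos.len();
    for i in 0..embed_dim {
        let (xi, yi) = if neox { (i, i + embed_dim) } else { (2 * i, 2 * i + 1) };
        let (c, s) = (cos[i], sin[i]);
        let (x, y) = (head[xi], head[yi]);
        head[xi] = x * c - y * s;
        head[yi] = y * c + x * s;
    }
}

fn main() {
    let mut head = vec![1.0, 0.0, 0.0, 1.0];
    // embed_dim = 2, NeoX pairing: (0, 2) and (1, 3).
    rotate_head(&mut head, &[1.0, 0.0], &[0.0, 1.0], true);
    println!("{head:?}");
}
```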
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + + + + +extern "C" __global__ void rotary_embedding_kernel_neox_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +#if __CUDA_ARCH__ >= 800 +#include +extern "C" __global__ void rotary_embedding_kernel_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
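The kernels above read `cos_sin_cache` as a `[max_position, rot_dim]` table where each row stores `embed_dim` cosines followed by `embed_dim` sines (`cos_ptr = row`, `sin_ptr = row + rot_dim / 2`). A hedged host-side sketch of building such a cache is below; the inverse-frequency schedule is the usual base-theta RoPE choice and is an assumption here, not something this diff prescribes. The bf16 NeoX kernel continues after the sketch.

```rust
// Build a flat [max_position, rot_dim] cos/sin cache in the layout the
// rotary_embedding_kernel_* entry points above expect.
fn build_cos_sin_cache(max_position: usize, rot_dim: usize, theta: f32) -> Vec<f32> {
    let embed_dim = rot_dim / 2;
    // Assumed schedule: inv_freq[i] = theta^(-2i / rot_dim).
    let inv_freq: Vec<f32> = (0..embed_dim)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / rot_dim as f32))
        .collect();
    let mut cache = vec![0f32; max_position * rot_dim];
    for pos in 0..max_position {
        let row = &mut cache[pos * rot_dim..(pos + 1) * rot_dim];
        for (i, f) in inv_freq.iter().enumerate() {
            let angle = pos as f32 * f;
            row[i] = angle.cos();             // cosines first ...
            row[embed_dim + i] = angle.sin(); // ... then sines.
        }
    }
    cache
}

fn main() {
    let cache = build_cos_sin_cache(4, 8, 10_000.0);
    println!("{} entries", cache.len());
}
```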
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} +#endif \ No newline at end of file diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d13..52846a04bf 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -99,6 +99,57 @@ __device__ void index_add( } } +#if __CUDA_ARCH__ >= 800 +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void scatter_add_f8( + const I *ids, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} + +template +__device__ void index_add_f8( + const I *ids, + const size_t ids_dim_size, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} +#endif + #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const INDEX_TYPENAME *ids, \ @@ -111,6 +162,18 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define IA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add_f8(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + template __device__ void scatter_add( const I *ids, @@ -145,46 +208,114 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define SA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add_f8(ids, inp, out, left_size, 
src_dim_size, dst_dim_size, right_size); } \ + #if __CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int16_t, is_i16_bf16) +IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int16_t, gather_i16_bf16) +GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int16_t, ia_i16_bf16) +IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int16_t, sa_i16_bf16) +SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) + +IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint32_t, is_u32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint8_t, is_u8_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int16_t, gather_i16_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int32_t, gather_i32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int64_t, gather_i64_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint32_t, gather_u32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint8_t, gather_u8_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int16_t, ia_i16_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int32_t, ia_i32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int64_t, ia_i64_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint32_t, ia_u32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint8_t, ia_u8_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int16_t, sa_i16_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3) #endif #if __CUDA_ARCH__ >= 530 +IS_OP(__half, int16_t, is_i16_f16) +IS_OP(__half, int32_t, is_i32_f16) IS_OP(__half, int64_t, is_i64_f16) IS_OP(__half, uint32_t, is_u32_f16) IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int16_t, gather_i16_f16) +GATHER_OP(__half, int32_t, gather_i32_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int16_t, ia_i16_f16) +IA_OP(__half, int32_t, ia_i32_f16) IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int16_t, sa_i16_f16) +SA_OP(__half, int32_t, sa_i32_f16) SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif +IS_OP(float, int16_t, is_i16_f32) +IS_OP(double, int16_t, is_i16_f64) +IS_OP(uint8_t, int16_t, is_i16_u8) +IS_OP(uint32_t, int16_t, is_i16_u32) +IS_OP(int16_t, int16_t, is_i16_i16) +IS_OP(int32_t, int16_t, is_i16_i32) +IS_OP(int64_t, int16_t, is_i16_i64) + +IS_OP(float, int32_t, is_i32_f32) +IS_OP(double, int32_t, is_i32_f64) +IS_OP(uint8_t, int32_t, is_i32_u8) +IS_OP(uint32_t, int32_t, is_i32_u32) +IS_OP(int16_t, int32_t, is_i32_i16) +IS_OP(int32_t, int32_t, is_i32_i32) +IS_OP(int64_t, int32_t, is_i32_i64) + IS_OP(float, int64_t, is_i64_f32) IS_OP(double, int64_t, is_i64_f64) IS_OP(uint8_t, int64_t, is_i64_u8) IS_OP(uint32_t, int64_t, is_i64_u32) IS_OP(int64_t, int64_t, is_i64_i64) +IS_OP(int32_t, int64_t, is_i64_i32) +IS_OP(int16_t, int64_t, is_i64_i16) 
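A CPU reference (illustration only) of the scatter-add indexing used by `scatter_add_f8` above and the `SA_OP` kernels: for every `(pre, j, post)`, `out[pre, ids[pre, j, post], post] += inp[pre, j, post]` over row-major flat buffers. The fp8 variant does the same, except each element is widened to f32 for the accumulation and narrowed back afterwards. The gather/index-select instantiations continue below.

```rust
fn scatter_add(
    ids: &[usize],
    inp: &[f32],
    out: &mut [f32],
    left_size: usize,
    src_dim_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    for pre in 0..left_size {
        for j in 0..src_dim_size {
            for post in 0..right_size {
                let src_i = (pre * src_dim_size + j) * right_size + post;
                let idx = ids[src_i];
                assert!(idx < dst_dim_size);
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}

fn main() {
    // One row of 4 source elements scattered into a destination of length 3.
    let ids = [0usize, 2, 2, 1];
    let inp = [1.0, 2.0, 3.0, 4.0];
    let mut out = [0.0f32; 3];
    scatter_add(&ids, &inp, &mut out, 1, 4, 3, 1);
    assert_eq!(out, [1.0, 4.0, 5.0]);
}
```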
IS_OP(float, uint32_t, is_u32_f32) IS_OP(double, uint32_t, is_u32_f64) IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int16_t, uint32_t, is_u32_i16) +IS_OP(int32_t, uint32_t, is_u32_i32) IS_OP(int64_t, uint32_t, is_u32_i64) IS_OP(uint32_t, uint32_t, is_u32_u32) @@ -192,17 +323,39 @@ IS_OP(float, uint8_t, is_u8_f32) IS_OP(double, uint8_t, is_u8_f64) IS_OP(uint8_t, uint8_t, is_u8_u8) IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int16_t, uint8_t, is_u8_i16) +IS_OP(int32_t, uint8_t, is_u8_i32) IS_OP(int64_t, uint8_t, is_u8_i64) +GATHER_OP(float, int16_t, gather_i16_f32) +GATHER_OP(double, int16_t, gather_i16_f64) +GATHER_OP(uint8_t, int16_t, gather_i16_u8) +GATHER_OP(uint32_t, int16_t, gather_i16_u32) +GATHER_OP(int16_t, int16_t, gather_i16_i16) +GATHER_OP(int32_t, int16_t, gather_i16_i32) +GATHER_OP(int64_t, int16_t, gather_i16_i64) + +GATHER_OP(float, int32_t, gather_i32_f32) +GATHER_OP(double, int32_t, gather_i32_f64) +GATHER_OP(uint8_t, int32_t, gather_i32_u8) +GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int16_t, int32_t, gather_i32_i16) +GATHER_OP(int32_t, int32_t, gather_i32_i32) +GATHER_OP(int64_t, int32_t, gather_i32_i64) + GATHER_OP(float, int64_t, gather_i64_f32) GATHER_OP(double, int64_t, gather_i64_f64) GATHER_OP(uint8_t, int64_t, gather_i64_u8) GATHER_OP(uint32_t, int64_t, gather_i64_u32) GATHER_OP(int64_t, int64_t, gather_i64_i64) +GATHER_OP(int32_t, int64_t, gather_i64_i32) +GATHER_OP(int16_t, int64_t, gather_i64_i16) GATHER_OP(float, uint32_t, gather_u32_f32) GATHER_OP(double, uint32_t, gather_u32_f64) GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int16_t, uint32_t, gather_u32_i16) +GATHER_OP(int32_t, uint32_t, gather_u32_i32) GATHER_OP(int64_t, uint32_t, gather_u32_i64) GATHER_OP(uint32_t, uint32_t, gather_u32_u32) @@ -210,17 +363,35 @@ GATHER_OP(float, uint8_t, gather_u8_f32) GATHER_OP(double, uint8_t, gather_u8_f64) GATHER_OP(uint8_t, uint8_t, gather_u8_u8) GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int16_t, uint8_t, gather_u8_i16) +GATHER_OP(int32_t, uint8_t, gather_u8_i32) GATHER_OP(int64_t, uint8_t, gather_u8_i64) +IA_OP(float, int16_t, ia_i16_f32) +IA_OP(double, int16_t, ia_i16_f64) +IA_OP(uint8_t, int16_t, ia_i16_u8) +IA_OP(int16_t, int16_t, ia_i16_i16) +IA_OP(uint16_t, int16_t, ia_i16_u16) + +IA_OP(float, int32_t, ia_i32_f32) +IA_OP(double, int32_t, ia_i32_f64) +IA_OP(uint8_t, int32_t, ia_i32_u8) +IA_OP(int32_t, int32_t, ia_i32_i32) +IA_OP(uint32_t, int32_t, ia_i32_u32) + IA_OP(float, int64_t, ia_i64_f32) IA_OP(double, int64_t, ia_i64_f64) IA_OP(uint8_t, int64_t, ia_i64_u8) IA_OP(int64_t, int64_t, ia_i64_i64) IA_OP(uint32_t, int64_t, ia_i64_u32) +IA_OP(int32_t, int64_t, ia_i64_i32) +IA_OP(int16_t, int64_t, ia_i64_i16) IA_OP(float, uint32_t, ia_u32_f32) IA_OP(double, uint32_t, ia_u32_f64) IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int16_t, uint32_t, ia_u32_i16) +IA_OP(int32_t, uint32_t, ia_u32_i32) IA_OP(int64_t, uint32_t, ia_u32_i64) IA_OP(uint32_t, uint32_t, ia_u32_u32) @@ -228,17 +399,37 @@ IA_OP(float, uint8_t, ia_u8_f32) IA_OP(double, uint8_t, ia_u8_f64) IA_OP(uint8_t, uint8_t, ia_u8_u8) IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int16_t, uint8_t, ia_u8_i16) +IA_OP(int32_t, uint8_t, ia_u8_i32) IA_OP(int64_t, uint8_t, ia_u8_i64) +SA_OP(float, int16_t, sa_i16_f32) +SA_OP(double, int16_t, sa_i16_f64) +SA_OP(uint8_t, int16_t, sa_i16_u8) +SA_OP(int16_t, int16_t, sa_i16_i16) +SA_OP(int32_t, int16_t, sa_i16_i32) +SA_OP(uint32_t, int16_t, sa_i16_u32) + +SA_OP(float, int32_t, sa_i32_f32) +SA_OP(double, int32_t, sa_i32_f64) +SA_OP(uint8_t, int32_t, 
sa_i32_u8) +SA_OP(int16_t, int32_t, sa_i32_i16) +SA_OP(int32_t, int32_t, sa_i32_i32) +SA_OP(uint32_t, int32_t, sa_i32_u32) + SA_OP(float, int64_t, sa_i64_f32) SA_OP(double, int64_t, sa_i64_f64) SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int16_t, int64_t, sa_i64_i16) +SA_OP(int32_t, int64_t, sa_i64_i32) SA_OP(int64_t, int64_t, sa_i64_i64) SA_OP(uint32_t, int64_t, sa_i64_u32) SA_OP(float, uint32_t, sa_u32_f32) SA_OP(double, uint32_t, sa_u32_f64) SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int16_t, uint32_t, sa_u32_i16) +SA_OP(int32_t, uint32_t, sa_u32_i32) SA_OP(int64_t, uint32_t, sa_u32_i64) SA_OP(uint32_t, uint32_t, sa_u32_u32) @@ -246,4 +437,6 @@ SA_OP(float, uint8_t, sa_u8_f32) SA_OP(double, uint8_t, sa_u8_f64) SA_OP(uint8_t, uint8_t, sa_u8_u8) SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int16_t, uint8_t, sa_u8_i16) +SA_OP(int32_t, uint8_t, sa_u8_i32) SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu new file mode 100644 index 0000000000..2bbd6c53a0 --- /dev/null +++ b/candle-kernels/src/kvconcat.cu @@ -0,0 +1,54 @@ +#include "cuda_utils.cuh" +#include + +template +__device__ __forceinline__ void kvconcat_dim0_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + size_t idx = GetThreadIdx(); + if (idx < chunk_l * lstride) { + out[idx] = ltensor[idx]; + } else { + out[idx] = rtensor[idx - chunk_l * lstride]; + } +} +template +__device__ __forceinline__ void kvconcat_dim2_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + int thread_id = GetThreadIdx(); + int out_stride = lstride + rstride; + int idx = thread_id / out_stride; + int j = thread_id % out_stride; + T* pLeft = ltensor + idx * lstride; + T* pRight = rtensor + idx * rstride; + T* pOut = out + idx * out_stride; + if (idx < chunk_l) { + if (j < lstride) + pOut[j] = pLeft[j]; + else + pOut[j] = pRight[j - lstride]; + } +} + +#define KVCONCAT_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME(TYPENAME *ltensor, TYPENAME* rtensor, TYPENAME *out, const size_t concat_dim,\ + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) {\ + if (concat_dim == 2)\ + kvconcat_dim2_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + else if (concat_dim == 0) {\ + if (blockIdx.x == 0 && threadIdx.x ==0) \ + kvconcat_dim0_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + }\ +}\ + +KVCONCAT_OP(uint8_t, kvconcat_u8) +KVCONCAT_OP(double, kvconcat_f64) +KVCONCAT_OP(float, kvconcat_f32) + +#if __CUDA_ARCH__ >= 530 +KVCONCAT_OP(__half, kvconcat_f16) +#endif + +#if __CUDA_ARCH__ >= 800 +KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16) +KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3) +#endif \ No newline at end of file diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 1c73d6b774..0bb490ca1c 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -3,7 +3,10 @@ pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); +pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), 
"/fused_rope.ptx")); pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); +pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index aaac24a146..f42cad471e 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -580,6 +580,14 @@ LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) + +// NOTE: No reduce ops for f8 +// SUM_OP(__nv_fp8_e4m3, sum_fp8_e4m3) +// SOFTMAX_OP(__nv_fp8_e4m3, float, softmax_fp8_e4m3) +// RMSNORM_OP(__nv_fp8_e4m3, rmsnorm_fp8_e4m3) +// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3) +// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3) +// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 @@ -606,5 +614,7 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int16_t, fast_min_i16, fast_max_i16, fast_argmin_i16, fast_argmax_i16, fast_sum_i16) +FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 08f1f9fc29..7db1b20ec5 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -75,6 +75,9 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \ #if __CUDA_ARCH__ >= 800 ASORT_OP(__nv_bfloat16, bf16) + +// NOTE: No sort ops for f8 +// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 @@ -85,4 +88,6 @@ ASORT_OP(float, f32) ASORT_OP(double, f64) ASORT_OP(uint8_t, u8) ASORT_OP(uint32_t, u32) +ASORT_OP(int16_t, i16) +ASORT_OP(int32_t, i32) ASORT_OP(int64_t, i64) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index aaa8a881fb..c426640b39 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -33,17 +33,41 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int16_t, where_i16_bf16) +WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) + +WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int16_t, where_i16_f16) +WHERE_OP(__half, int32_t, where_i32_f16) 
WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) #endif +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(double, int16_t, where_i16_f64) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int64_t, int16_t, where_i16_i64) + +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(double, int32_t, where_i32_f64) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int32_t, int32_t, where_i32_i64) + WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(double, int64_t, where_i64_f64) WHERE_OP(uint8_t, int64_t, where_i64_u8) @@ -54,10 +78,14 @@ WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(double, uint32_t, where_u32_f64) WHERE_OP(uint8_t, uint32_t, where_u32_u8) WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int16_t, uint32_t, where_u32_i16) +WHERE_OP(int32_t, uint32_t, where_u32_i32) WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(double, uint8_t, where_u8_f64) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int16_t, uint8_t, where_u8_i16) +WHERE_OP(int32_t, uint8_t, where_u8_i32) WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c82a88375d..ba899e643c 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -122,6 +122,33 @@ UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +UNARY_OP(__nv_fp8_e4m3, ucopy_fp8_e4m3, x) +UNARY_OP(__nv_fp8_e4m3, uneg_fp8_e4m3, __nv_fp8_e4m3(-F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, urecip_fp8_e4m3, recipg(x)) +UNARY_OP(__nv_fp8_e4m3, uexp_fp8_e4m3, expg(x)) +UNARY_OP(__nv_fp8_e4m3, ulog_fp8_e4m3, logg(x)) +UNARY_OP(__nv_fp8_e4m3, usin_fp8_e4m3, sing(x)) +UNARY_OP(__nv_fp8_e4m3, ucos_fp8_e4m3, cosg(x)) +UNARY_OP(__nv_fp8_e4m3, utanh_fp8_e4m3, tanhg(x)) +UNARY_OP(__nv_fp8_e4m3, uerf_fp8_e4m3, erfg(x)) +UNARY_OP(__nv_fp8_e4m3, uceil_fp8_e4m3, ceilg(x)) +UNARY_OP(__nv_fp8_e4m3, ufloor_fp8_e4m3, floorg(x)) +UNARY_OP(__nv_fp8_e4m3, uround_fp8_e4m3, roundg(x)) +UNARY_OP(__nv_fp8_e4m3, unormcdf_fp8_e4m3, normcdfg(x)) +UNARY_OP(__nv_fp8_e4m3, uabs_fp8_e4m3, absg(x)) +UNARY_OP(__nv_fp8_e4m3, usqr_fp8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x)*F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, usqrt_fp8_e4m3, sqrtg(x)) +UNARY_OP(__nv_fp8_e4m3, ugelu_fp8_e4m3, __nv_fp8_e4m3(gelu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, ugelu_erf_fp8_e4m3, __nv_fp8_e4m3(gelu_erf_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, urelu_fp8_e4m3, __nv_fp8_e4m3(relu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, uelu_fp8_e4m3, __nv_fp8_e4m3(elu_fwd(F8E4M3_TO_FLOAT(x), F8E4M3_TO_FLOAT(param)))) +UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param)) +UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x)))) #endif #if __CUDA_ARCH__ >= 530 @@ -153,6 +180,8 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) UNARY_OP(uint8_t, 
ucopy_u8, x) UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int16_t, ucopy_i16, x) +UNARY_OP(int32_t, ucopy_i32, x) UNARY_OP(int64_t, ucopy_i64, x) UNARY_OP(float, ucopy_f32, x) UNARY_OP(double, ucopy_f64, x) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index e83498e40d..4c558c2cdb 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -58,13 +58,17 @@ kernel void FN_NAME_STRIDED( \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, int16_t, NAME##_i16, NAME##_i16_strided); \ +BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, uint8_t, NAME##_i16, NAME##_i16_strided); \ +BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); #define INT64_BINARY_OP(NAME, FN) \ BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 2af3fdceb0..5a8324bf11 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -76,6 +76,8 @@ kernel void FN_NAME_STRIDED( \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) +CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) +CAST(cast_u32_i16, cast_u32_i16_strided, uint32_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) #endif @@ -87,6 +89,8 @@ CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t) +CAST(cast_u8_i16, cast_u8_i16_strided, uint8_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) #endif @@ -98,6 +102,8 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i16, cast_f16_i16_strided, half, int16_t) +CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t) CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) @@ -107,15 +113,41 @@ CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) +CAST(cast_i64_i16, cast_i64_i16_strided, int64_t, int16_t) CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) #if 
defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) #endif +// i32 +CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) +CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) +CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) +CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_i16, cast_i32_i16_strided, int32_t, int16_t) +CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int64_t, bfloat, float) +#endif + +// i16 +CAST(cast_i16_f32, cast_i16_f32_strided, int16_t, float) +CAST(cast_i16_u8, cast_i16_u8_strided, int16_t, uint8_t) +CAST(cast_i16_u32, cast_i16_u32_strided, int16_t, uint32_t) +CAST(cast_i16_i32, cast_i16_i32_strided, int16_t, int32_t) +CAST(cast_i16_i64, cast_i16_i64_strided, int16_t, int64_t) +CAST(cast_i16_f16, cast_i16_f16_strided, int16_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i16_bf16, cast_i16_bf16_strided, int16_t, bfloat, float) +#endif + // f32 CAST(cast_f32_f16, cast_f32_f16_strided, float, half) CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i16, cast_f32_i16_strided, float, int16_t) +CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) @@ -124,6 +156,8 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) // bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i16, cast_bf16_i16_strided, bfloat, int16_t) +CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal index 35c3fe7ab2..7e99a8525d 100644 --- a/candle-metal-kernels/src/fill.metal +++ b/candle-metal-kernels/src/fill.metal @@ -33,6 +33,8 @@ FILL_OPS(u32, uint) FILL_OPS(i64, long) FILL_OPS(f16, half) FILL_OPS(f32, float) +FILL_OPS(i32, int) +FILL_OPS(i16, short) #if __METAL_VERSION__ >= 310 FILL_OPS(bf16, bfloat) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 9eee97ca0a..f01d4795d8 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -193,6 +193,18 @@ INDEX_OP(is_i64_f16, int64_t, half) INDEX_OP(is_i64_bf16, int64_t, bfloat) #endif +INDEX_OP(is_i32_f32, int32_t, float) +INDEX_OP(is_i32_f16, int32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i32_bf16, int32_t, bfloat) +#endif + +INDEX_OP(is_i16_f32, int16_t, float) +INDEX_OP(is_i16_f16, int16_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i16_bf16, int16_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -213,9 +225,13 @@ GATHER_OP(gather_u32_bf16, uint, bfloat) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i16_f32, int16_t, float) +SCATTER_ADD_OP(sa_i32_f32, int32_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i16_f16, int16_t, half) 
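A small reference sketch (illustration only, assuming the `half` crate is available) of the difference between the `CAST` and `CAST_THROUGH` entry points added above: direct casts convert element-wise between the two dtypes, while `*_THROUGH` kernels hop via an intermediate type (float) when no direct conversion exists, e.g. `cast_i16_bf16`. The remaining scatter-add instantiations continue below.

```rust
use half::bf16;

// Direct cast: i16 -> f32, mirroring CAST(cast_i16_f32, ...).
fn cast_i16_f32(src: &[i16]) -> Vec<f32> {
    src.iter().map(|&v| v as f32).collect()
}

// Cast through float: i16 -> f32 -> bf16, mirroring
// CAST_THROUGH(cast_i16_bf16, ..., int16_t, bfloat, float).
fn cast_i16_bf16(src: &[i16]) -> Vec<bf16> {
    src.iter().map(|&v| bf16::from_f32(v as f32)).collect()
}

fn main() {
    assert_eq!(cast_i16_f32(&[1, -2]), vec![1.0, -2.0]);
    assert_eq!(cast_i16_bf16(&[3]), vec![bf16::from_f32(3.0)]);
}
```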
+SCATTER_ADD_OP(sa_i32_f16, int32_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) @@ -226,6 +242,8 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) // i64 INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i16, int64_t, int16_t) +INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) @@ -233,9 +251,35 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) #endif +// i32 +INDEX_ADD_OP(ia_i32_f16, int32_t, half) +INDEX_ADD_OP(ia_i32_f32, int32_t, float) +INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) +INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t) +INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t) +INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) +#endif + +// i16 +INDEX_ADD_OP(ia_i16_f16, int16_t, half) +INDEX_ADD_OP(ia_i16_f32, int16_t, float) +INDEX_ADD_OP(ia_i16_i16, int16_t, int16_t) +INDEX_ADD_OP(ia_i16_i32, int16_t, int32_t) +INDEX_ADD_OP(ia_i16_i64, int16_t, int64_t) +INDEX_ADD_OP(ia_i16_u32, int16_t, uint32_t) +INDEX_ADD_OP(ia_i16_u8, int16_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i16_bf16, int16_t, bfloat) +#endif + + // u32 INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i16, uint32_t, int16_t) +INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) @@ -246,6 +290,8 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) // u8 INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i16, uint8_t, int16_t) +INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a270bb2888..cfec91988d 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -50,6 +50,8 @@ pub mod copy2d { pub const HALF: Kernel = Kernel("copy2d_f16"); pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); + pub const I32: Kernel = Kernel("copy2d_i32"); + pub const I16: Kernel = Kernel("copy2d_i16"); pub const U32: Kernel = Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } @@ -66,6 +68,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } @@ -76,6 +80,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); + pub const I32: Kernel = Kernel("copy_i32"); + pub const I16: Kernel = Kernel("copy_i16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -90,6 +96,8 @@ macro_rules! 
ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } @@ -100,6 +108,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_tiled"); pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const I32: Kernel = Kernel("copy_i32_tiled"); + pub const I16: Kernel = Kernel("copy_i16_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } @@ -114,6 +124,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } @@ -124,6 +136,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); + pub const I32: Kernel = Kernel("copy_i32_strided"); + pub const I16: Kernel = Kernel("copy_i16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } @@ -1471,6 +1485,8 @@ pub fn call_gemm( rhs_offset: usize, rhs_buffer: &Buffer, output: &Buffer, + alpha: f32, + beta: f32, ) -> Result<(), MetalKernelError> { assert!(rhs_stride.len() >= 2); assert!(lhs_stride.len() >= 2); @@ -1505,8 +1521,6 @@ pub fn call_gemm( })?; }; let d_trans = false; - let alpha = 1.0f32; - let beta = 0.0f32; let batched = b > 1; let fused_activation = false; let fused_bias = false; @@ -1843,6 +1857,7 @@ pub enum GgmlDType { Q8K, F16, F32, + BF16, } #[allow(clippy::too_many_arguments)] @@ -1920,7 +1935,7 @@ pub fn call_quantized_matmul_mv_t( let align = 2; (nth0, nth1, align) } - GgmlDType::F16 | GgmlDType::Q8K => { + GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; @@ -1958,6 +1973,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", + GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", }; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index fef6ac54f8..162b7a2d19 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -1495,8 +1495,203 @@ kernel void kernel_mul_mv_f16_f32( kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } +#if defined(__HAVE_BFLOAT__) +void kernel_mul_mv_bf16_f32_1row_impl( + device const char * src0, + device const char * src1, + device float * dst, + 
constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const bfloat* x = (device const bfloat*) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const bfloat4* x4 = (device const bfloat4*) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_bf16_f32_1row")]] +kernel void kernel_mul_mv_bf16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_bf16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} +#endif + +#define N_BF16_F32 4 + +#if defined(__HAVE_BFLOAT__) +void kernel_mul_mv_bf16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_BF16_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const bfloat * x = (device const bfloat *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_BF16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) 
(src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const bfloat4 * x4 = (device const bfloat4 *)x; + for (int row = 0; row < N_BF16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_bf16_f32")]] +kernel void kernel_mul_mv_bf16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_bf16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} +#endif + +#if defined(__HAVE_BFLOAT__) // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( +kernel void kernel_mul_mv_bf16_f32_l4( device const char * src0, device const char * src1, device float * dst, @@ -1528,7 +1723,7 @@ kernel void kernel_mul_mv_f16_f32_l4( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half4 * x4 = (device const half4 *) (src0 + offset0); + device const bfloat4 * x4 = (device const bfloat4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); @@ -1544,6 +1739,7 @@ kernel void kernel_mul_mv_f16_f32_l4( } } } +#endif kernel void kernel_alibi_f32( device const float * src0, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index e009ca1d6a..56ef56f7e0 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -602,6 +602,18 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif +REDUCE(x + y, fast_sum_i32_strided, int32_t, 0) +REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) +ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) +ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) + +REDUCE(x + y, fast_sum_i16_strided, int16_t, 0) +REDUCE(MIN(x, y), fast_min_i16_strided, int16_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i16_strided, int16_t, INT_MIN) +ARGMIN(fast_argmin_i16_strided, int16_t, INT_MAX) +ARGMAX(fast_argmax_i16_strided, int16_t, INT_MIN) + #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x + y, fast_sum_bf16_strided, half, 0) diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal index 
d71ab82234..9f001d8fb6 100644 --- a/candle-metal-kernels/src/sort.metal +++ b/candle-metal-kernels/src/sort.metal @@ -88,6 +88,8 @@ ARGSORT(float, f32) ARGSORT(half, f16) ARGSORT(uint8_t, u8) ARGSORT(uint32_t, u32) +ARGSORT(int32_t, i32) +ARGSORT(int16_t, i16) #if __METAL_VERSION__ >= 220 ARGSORT(int64_t, i64) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index fe04f2378f..98aacd0036 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -75,11 +75,40 @@ WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(uint8_t, int64_t, where_i64_u8) WHERE_OP(uint32_t, int64_t, where_i64_u32) WHERE_OP(int64_t, int64_t, where_i64_i64) +WHERE_OP(int64_t, int32_t, where_i64_i32) +WHERE_OP(int64_t, int16_t, where_i64_i16) #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, int64_t, where_i64_bf16) #endif #endif +WHERE_OP(int64_t, uint8_t, where_u8_i32) +WHERE_OP(int64_t, uint32_t, where_u32_i32) + +WHERE_OP(half, int32_t, where_i32_f16) +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int64_t, int32_t, where_i32_i64) +WHERE_OP(int32_t, int32_t, where_i32_i32) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int32_t, where_i32_bf16) +#endif + +WHERE_OP(int64_t, uint8_t, where_u8_i16) +WHERE_OP(int64_t, uint32_t, where_u32_i16) + +WHERE_OP(half, int16_t, where_i16_f16) +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int64_t, int16_t, where_i16_i64) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int16_t, where_i16_bf16) +#endif + #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) WHERE_OP(bfloat, uint32_t, where_u32_bf16) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index f37ab5bb9c..01b5a9184e 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1088,6 +1088,8 @@ fn run_gemm( rhs_offset, &rhs, &output, + 1., + 0., ) .unwrap(); command_buffer.commit(); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index e3a18cfe91..ab4342ec3d 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -173,6 +173,12 @@ UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) #endif +UNARY(id, int32_t, copy_i32, copy_i32_strided) +COPY2D(copy2d_i32, int32_t) + +UNARY(id, int16_t, copy_i16, copy_i16_strided) +COPY2D(copy2d_i16, int16_t) + #if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 9f0d56bdea..570edb48be 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -21,6 +21,7 @@ safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } candle-metal-kernels = { workspace = true, optional = true } +candle-flash-attn = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -34,6 +35,7 @@ accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] +flash-attn = ["cuda", "dep:candle-flash-attn"] [[bench]] name = "bench_main" diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs 
index 4db1d35c0a..727479b5c9 100644 --- a/candle-nn/benches/bench_main.rs +++ b/candle-nn/benches/bench_main.rs @@ -1,4 +1,9 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches); +criterion_main!( + benchmarks::layer_norm::benches, + benchmarks::conv::benches, + benchmarks::attention::benches_fast, + benchmarks::attention::benches_naive +); diff --git a/candle-nn/benches/benchmarks/attention.rs b/candle-nn/benches/benchmarks/attention.rs new file mode 100644 index 0000000000..8aa479d319 --- /dev/null +++ b/candle-nn/benches/benchmarks/attention.rs @@ -0,0 +1,111 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Tensor}; +use candle_nn::scaled_dot_product_attention; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run_attention(q: &Tensor, k: &Tensor, v: &Tensor, m: &Tensor, s: f64) { + let att = (q + .contiguous() + .unwrap() + .matmul(&k.t().unwrap().contiguous().unwrap()) + .unwrap() + / s) + .unwrap(); + + let att = att.broadcast_add(m).unwrap(); + + let att = candle_nn::ops::softmax_last_dim(&att).unwrap(); + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous().unwrap()).unwrap(); +} + +fn run_bench_naive(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_naive")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_attention( + black_box(&q), + black_box(&k), + black_box(&v), + black_box(&m), + 0.3, + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_naive(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_naive(c, &device); + } +} + +fn run_bench_fast(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_fast")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + let _ = scaled_dot_product_attention( + black_box(&q), + black_box(&k), + black_box(&v), + 0.3, + Some(black_box(&m)), + false, + seq, + ) + .unwrap(); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_fast(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_fast(c, &device); + } +} + 
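The benchmarks above pass an all-zero mask so both paths do the same amount of work; in real use the mask is an additive bias (0.0 on visible positions, negative infinity on masked ones), matching the contract documented in `attention.rs` below. A plain-Rust sketch of building such a causal bias for one `(seq, seq)` slice follows; in practice it would then be broadcast to `(batch, heads, seq, seq)`. The criterion group registrations continue after the sketch.

```rust
// Additive causal mask: query i may not attend to any key j > i.
fn causal_mask(seq: usize) -> Vec<f32> {
    let mut m = vec![0f32; seq * seq];
    for i in 0..seq {
        for j in (i + 1)..seq {
            m[i * seq + j] = f32::NEG_INFINITY;
        }
    }
    m
}

fn main() {
    let m = causal_mask(4);
    assert_eq!(m[1], f32::NEG_INFINITY); // row 0 cannot see column 1
    assert_eq!(m[3 * 4 + 2], 0.0);       // row 3 can see column 2
}
```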
+criterion_group!(benches_naive, criterion_benchmark_naive); +criterion_group!(benches_fast, criterion_benchmark_fast); diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index 30a6ab6a2b..8c60df2ee5 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod attention; pub(crate) mod conv; pub(crate) mod layer_norm; @@ -15,7 +16,10 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize()?); + { + use candle::cuda::WrapErr; + return Ok(device.synchronize().w()?); + } #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs new file mode 100644 index 0000000000..5b8de4388d --- /dev/null +++ b/candle-nn/src/attention.rs @@ -0,0 +1,63 @@ +use candle::{Result, Tensor}; + +#[cfg(feature = "flash-attn")] +pub fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("Compile with '--features flash-attn'") +} + +/// Computes softmax(QK^T * scale + M)V, where `scale` is typically `1/sqrt(d_k)`. `M` is the attention mask, and is a bias added before the softmax (0 for unmasked, -inf for masked). +/// +/// The attention implementation is automatically accelerated and dispatched as follows: +/// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel +/// 2) Otherwise, use SDPA with fusion of softmax scale and attention bias application +/// +/// Note that there may be minute differences in output because floating point operations are not associative. +#[allow(unused_variables, clippy::too_many_arguments)] +pub fn scaled_dot_product_attention( + q: &Tensor, + k: &Tensor, + v: &Tensor, + scale: f64, + mask: Option<&Tensor>, + use_flash_attn: bool, + seq_len: usize, +) -> Result<Tensor> { + if use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + return flash_attn(&q, &k, &v, scale as f32, seq_len > 1)?.transpose(1, 2); + } + + let att = match mask { + Some(mask) => { + let (b, n, s, _h) = q.dims4()?; + let mut mask_and_output = mask.broadcast_as((b, n, s, s))?.contiguous()?; + q.contiguous()?.matmul_with_alpha_beta( + &k.t()?.contiguous()?, + &mut mask_and_output, + Some(scale), + )?; + mask_and_output + } + None => q + .contiguous()? + .matmul_with_alpha(&k.t()?.contiguous()?, Some(scale))?, + }; + + let att = crate::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?) +} diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index b7dd61cba1..d38e64e582 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -28,6 +28,21 @@ //! ``` //! //!
[`Layer Normalization`]: https://arxiv.org/abs/1607.06450 + +use std::marker::PhantomData; + +#[cfg(feature = "cuda")] +use candle::cuda_backend::{ + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, + kernel_name, kernels, CudaDType, WrapErr, +}; + +#[cfg(feature = "cuda")] +use candle::{ + backend::BackendStorage, from_storage_no_op, CudaDevice, CudaStorage, Device, Storage, + WithDType, +}; + use candle::{DType, Module, Result, Tensor, D}; #[derive(Debug, Clone, Copy, PartialEq)] @@ -63,7 +78,7 @@ impl From for LayerNormConfig { #[derive(Clone, Debug)] pub struct LayerNorm { weight: Tensor, - bias: Option, + bias: Tensor, remove_mean: bool, eps: f64, } @@ -72,7 +87,7 @@ impl LayerNorm { pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { Self { weight, - bias: Some(bias), + bias, remove_mean: true, eps, } @@ -80,8 +95,8 @@ impl LayerNorm { pub fn new_no_bias(weight: Tensor, eps: f64) -> Self { Self { - weight, - bias: None, + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), remove_mean: true, eps, } @@ -89,8 +104,8 @@ impl LayerNorm { pub fn rms_norm(weight: Tensor, eps: f64) -> Self { Self { - weight, - bias: None, + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), remove_mean: false, eps, } @@ -100,17 +115,15 @@ impl LayerNorm { &self.weight } - pub fn bias(&self) -> Option<&Tensor> { - self.bias.as_ref() + pub fn bias(&self) -> &Tensor { + &self.bias } } impl Module for LayerNorm { fn forward(&self, x: &Tensor) -> Result { if x.is_contiguous() && self.remove_mean { - if let Some(bias) = self.bias.as_ref() { - return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32); - } + return crate::ops::layer_norm(x, &self.weight, &self.bias, self.eps as f32); } let x_dtype = x.dtype(); let internal_dtype = match x_dtype { @@ -128,10 +141,7 @@ impl Module for LayerNorm { let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } + x.broadcast_add(&self.bias) } } @@ -148,47 +158,182 @@ pub fn layer_norm>( None }; Ok(LayerNorm { - weight, - bias, + weight: weight.clone(), + bias: bias.unwrap_or(Tensor::zeros_like(&weight)?), remove_mean: config.remove_mean, eps: config.eps, }) } +// This whole non quantized/quantized RmsNorm is a hack. It seems like quantized works without this impl, but it is slower. +#[derive(Clone, Debug)] +pub struct RmsNormQuantized; +#[derive(Clone, Debug)] +pub struct RmsNormNonQuantized; + /// RmsNorm is a specialized version of the LayerNorm module. 
#[derive(Clone, Debug)] -pub struct RmsNorm(LayerNorm); +pub struct RmsNorm { + inner: LayerNorm, + _ghost: PhantomData, +} + +impl RmsNorm { + pub fn new(weight: Tensor, eps: f64) -> Self { + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } +} -impl RmsNorm { +impl RmsNorm { pub fn new(weight: Tensor, eps: f64) -> Self { - Self(LayerNorm::rms_norm(weight, eps)) + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } + + #[cfg(feature = "cuda")] + fn dtype_execute_rmsnorm( + &self, + dev: &CudaDevice, + eps_converter: F, + x_storage: &CudaStorage, + weight_storage: &CudaStorage, + x: &Tensor, + ) -> Result + where + F: FnOnce(f64) -> T, + { + assert!(x.layout().is_contiguous()); + let hidden_size = *x.dims().last().unwrap(); + let elem_count = x.elem_count(); + let num_tokens = elem_count / hidden_size; + let out = unsafe { dev.alloc::(elem_count) }.w()?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (u32::min(hidden_size as u32, 1024), 1, 1), + shared_mem_bytes: 0, + }; + + let func = dev.get_or_load_func(&kernel_name::("rms_norm"), kernels::FUSED_RMS_NORM)?; + + let params = ( + &out, + x_storage.as_cuda_slice::()?, + weight_storage.as_cuda_slice::()?, + eps_converter(self.inner.eps), + num_tokens as i32, + hidden_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(from_storage_no_op( + Storage::Cuda(CudaStorage::wrap_cuda_slice(out, dev.clone())), + x.shape(), + false, + )) } + #[cfg(feature = "cuda")] + fn fused_rmsnorm(&self, x: &Tensor, dev: &CudaDevice) -> Result { + match ( + &*x.storage_and_layout().0, + &*self.inner.weight().storage_and_layout().0, + ) { + (Storage::Cuda(x_storage), Storage::Cuda(weight_storage)) => { + match (x_storage.dtype(), weight_storage.dtype()) { + (DType::BF16, DType::BF16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::bf16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F16, DType::F16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::f16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F32, DType::F32) => self.dtype_execute_rmsnorm::( + dev, + |x| x as f32, + &x_storage, + &weight_storage, + x, + ), + _ => candle::bail!("DType mismatch in fused rmsnorm."), + } + } + _ => unreachable!(), + } + } +} + +impl RmsNorm { pub fn into_inner(self) -> LayerNorm { - self.0 + self.inner + } + pub fn inner(&self) -> &LayerNorm { + &self.inner } +} - /// Faster variant of the forward kernel, this can only be used on contiguous tensors though. 
- pub fn forward_diff(&self, xs: &Tensor) -> Result { - self.0.forward(xs) +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + self.inner.forward(xs) } } -impl Module for RmsNorm { +impl Module for RmsNorm { fn forward(&self, xs: &Tensor) -> Result { - if xs.is_contiguous() { - crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32) - } else { - self.0.forward(xs) + #[cfg(feature = "cuda")] + match (xs.dtype(), xs.device()) { + (DType::BF16, Device::Cuda(dev)) + | (DType::F32, Device::Cuda(dev)) + | (DType::F16, Device::Cuda(dev)) => return self.fused_rmsnorm(xs, &dev), + _ => return self.inner.forward(xs), + } + #[cfg(not(feature = "cuda"))] + { + self.inner.forward(xs) } } } -pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result { +pub fn rms_norm_non_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { + let config = LayerNormConfig { + eps, + remove_mean: false, + affine: false, + }; + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) +} + +pub fn rms_norm_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { let config = LayerNormConfig { eps, remove_mean: false, affine: false, }; - Ok(RmsNorm(layer_norm(size, config, vb)?)) + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) } diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index fcac58308c..037304e0a8 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -1,4 +1,5 @@ pub mod activation; +pub mod attention; pub mod batch_norm; pub mod conv; pub mod embedding; @@ -13,12 +14,14 @@ pub mod loss; pub mod ops; pub mod optim; pub mod rnn; +pub mod rope; pub mod rotary_emb; pub mod sequential; pub mod var_builder; pub mod var_map; pub use activation::{prelu, Activation, PReLU}; +pub use attention::scaled_dot_product_attention; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use conv::{ conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias, @@ -29,11 +32,14 @@ pub use embedding::{embedding, Embedding}; pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; -pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; +pub use layer_norm::{ + layer_norm, rms_norm_non_quant, rms_norm_quant, LayerNorm, LayerNormConfig, RmsNorm, +}; pub use linear::{linear, linear_b, linear_no_bias, Linear}; -pub use ops::Dropout; +pub use ops::{kvconcat, Dropout}; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; +pub use rope::RotaryEmbedding; pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 9a360c472c..beb771aaf9 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -927,6 +927,33 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result { } } +#[cfg(feature = "cuda")] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result { + if !ltensor.device().is_cuda() { + return Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous(); + } + use candle::cuda_backend::KVConcat; + let op = KVConcat { concat_dim }; + //inputs for kvconcat must be contiguous tensors + if ltensor.is_contiguous() && rtensor.is_contiguous() { + ltensor.apply_op2(&rtensor, op) + } else if ltensor.is_contiguous() { + ltensor.apply_op2(&rtensor.contiguous()?, op) + } else if 
rtensor.is_contiguous() { + let ltensor = ltensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } else { + let ltensor = ltensor.contiguous()?; + let rtensor = rtensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } +} + +#[cfg(not(feature = "cuda"))] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result { + Tensor::cat(&[ltensor, rtensor], concat_dim as usize)?.contiguous() +} + #[derive(Clone, Debug)] pub struct Identity; diff --git a/candle-nn/src/rope.rs b/candle-nn/src/rope.rs new file mode 100644 index 0000000000..2f3af072af --- /dev/null +++ b/candle-nn/src/rope.rs @@ -0,0 +1,328 @@ +use std::iter::zip; + +#[allow(unused_imports)] +use candle::{ + backend::BackendStorage, CudaDevice, CudaStorage, DType, Device, IndexOp, Module, Result, + Storage, Tensor, WithDType, D, +}; + +#[cfg(feature = "cuda")] +use candle::cuda_backend::{ + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, + kernel_name, kernels, CudaDType, +}; + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RotaryEmbedding { + cos: Tensor, + sin: Tensor, + head_size: usize, + cache: Tensor, + is_gpt_neox: bool, +} + +impl RotaryEmbedding { + pub fn new( + base: f32, + head_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? + }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? + .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + pub fn new_partial( + base: f32, + head_dim: usize, + rot_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..rot_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? + }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? 
+ .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + #[cfg(feature = "cuda")] + fn execute_dtype( + &self, + dev: &CudaDevice, + q_storage: &CudaStorage, + k_storage: &CudaStorage, + q: &Tensor, + k: &Tensor, + cache_storage: &CudaStorage, + pos_storage: &CudaStorage, + ) -> Result<()> { + use candle::cuda_backend::WrapErr; + + let num_tokens = q.dim(0)?; + let rot_dim = self.cache.dim(1)?; + let num_heads = q.dim(1)?; + let num_kv_heads = k.dim(1)?; + let q_stride = q.stride()[0]; + let k_stride = k.stride()[0]; + + let func = dev.get_or_load_func( + &if self.is_gpt_neox { + kernel_name::("rotary_embedding_kernel_neox") + } else { + kernel_name::("rotary_embedding_kernel") + }, + kernels::FUSED_ROPE, + )?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (512.min((num_heads * rot_dim / 2) as u32), 1, 1), + shared_mem_bytes: 0, + }; + + let params = ( + pos_storage.as_cuda_slice::()?, + q_storage.as_cuda_slice::()?, + k_storage.as_cuda_slice::()?, + cache_storage.as_cuda_slice::()?, + rot_dim as i32, + q_stride as i64, + k_stride as i64, + num_heads as i32, + num_kv_heads as i32, + self.head_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn fused_rope( + &self, + dev: &CudaDevice, + positions: &Tensor, + q: &Tensor, + k: &Tensor, + ) -> Result<()> { + let cache_type = self.cache.dtype(); + match ( + &*q.storage_and_layout().0, + &*k.storage_and_layout().0, + &*self.cache.storage_and_layout().0, + &*positions.storage_and_layout().0, + ) { + ( + Storage::Cuda(q_storage), + Storage::Cuda(k_storage), + Storage::Cuda(cache_storage), + Storage::Cuda(pos_storage), + ) => { + return match (q.dtype(), k.dtype(), cache_type) { + (DType::BF16, DType::BF16, DType::BF16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F16, DType::F16, DType::F16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F32, DType::F32, DType::F32) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F64, DType::F64, DType::F64) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + _ => candle::bail!( + "DType mismatch in fused RotaryEmbedding q={:?}, k={:?}, cache={:?}", + q.dtype(), + k.dtype(), + cache_type + ), + } + } + _ => unreachable!(), + }; + } + + /// This may modify the tensors in place! + #[allow(unused_variables)] + pub fn forward( + &self, + positions: &[usize], + positions_kernel: &Tensor, + q: &mut Tensor, + k: &mut Tensor, + b_sz: usize, + ) -> Result<()> { + match (q.device(), k.device()) { + #[cfg(feature = "cuda")] + (Device::Cuda(dev), Device::Cuda(_)) => { + self.fused_rope(dev, positions_kernel, &*q, &*k)?; + } + + _ => { + *q = self.apply_rotary_emb(&*q, positions, b_sz)?; + *k = self.apply_rotary_emb(&*k, positions, b_sz)?; + } + }; + Ok(()) + } + + fn apply_rotary_emb( + &self, + x: &Tensor, + seqlen_offsets: &[usize], + b_sz: usize, + ) -> Result { + let (b_sz_seq_len, h, n_embd) = x.dims3()?; + let x = x + .reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))? 
+ .transpose(1, 2)?; + + fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) + } + let (b_sz, n_head, seq_len, _n_embd) = x.dims4()?; + if self.is_gpt_neox { + let mut embeds = Vec::new(); + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let x_b = x.i(b)?.unsqueeze(0)?; + let embed = (x_b.broadcast_mul(&cos)? + rotate_half(&x_b)?.broadcast_mul(&sin)?)?; + embeds.push(embed); + } + Tensor::cat(&embeds, 0) + } else { + let mut ropes = Vec::new(); + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let cos = cos.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. + // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x_b = x.i(b)?.unsqueeze(0)?; + let x_b = x_b.reshape((1, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x_b.narrow(D::Minus1, 0, 1)?; + let x1 = x_b.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + ropes.push(rope); + } + Tensor::cat(&ropes, 0) + } + } +} diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 00669468d6..dfd4977042 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -34,7 +34,8 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; struct TensorData { - backend: B, + backend: Arc, + pub dtype: DType, pub device: Device, } @@ -95,7 +96,8 @@ impl<'a> Backend for Box { impl<'a, B: Backend> VarBuilderArgs<'a, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { - backend, + backend: Arc::new(backend), + dtype, device: dev.clone(), }; Self { @@ -213,6 +215,31 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { .backend .get(s.into(), &path, hints, dtype, &self.data.device) } + + /// Set the device of the VarBuilder. + pub fn set_device(self, device: Device) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype: self.data.dtype, + device, + }), + ..self + } + } + + /// Set the dtype of the VarBuilder. 
+ pub fn set_dtype(self, dtype: DType) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype, + device: self.data.device.clone(), + }), + dtype, + ..self + } + } } struct Zeros; @@ -474,7 +501,11 @@ impl<'a> VarBuilder<'a> { dtype: DType, device: Device, ) -> Self { - let data = TensorData { backend, device }; + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; Self { data: Arc::new(data), path: vec![], @@ -578,7 +609,11 @@ impl<'a> VarBuilder<'a> { let path = self.path.clone(); let backend = Rename::new(self, renamer); let backend: Box = Box::new(backend); - let data = TensorData { backend, device }; + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; Self { data: Arc::new(data), dtype, diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index de3e1010ac..43d7a628ca 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -362,7 +362,7 @@ fn simple_eval_( // HACK: current implementation of broadcast_pow cannot handle negative base, // so we use powf where we can, which *does* correctly handle negative base. if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::())() { - let output = input0.powf(exp)?; + let output = input0.powf(exp as f64)?; values.insert(node.output[0].clone(), output); } else { let output = input0.broadcast_pow(input1)?; @@ -643,7 +643,7 @@ fn simple_eval_( let mask = indices.lt(&zeros)?; mask.to_dtype(indices.dtype())? .broadcast_mul(&max)? - .add(indices)? + .add(&indices)? }; // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices @@ -717,6 +717,8 @@ fn simple_eval_( let output = match start.dtype() { DType::U8 => arange_step!(u8), DType::U32 => arange_step!(u32), + DType::I16 => arange_step!(i16), + DType::I32 => arange_step!(i32), DType::I64 => arange_step!(i64), DType::BF16 => arange_step!(f32), DType::F16 => arange_step!(f32), @@ -1463,7 +1465,7 @@ fn simple_eval_( let input = get(&node.input[0])?; let dt = input.dtype(); match dt { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I64 | DType::I16 | DType::I32 => { bail!( "unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str() diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 8800133429..bfed9eb48b 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,6 +19,7 @@ candle = { workspace = true } candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 0da2c70028..ab7f07d985 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::redundant_closure_call)] +use float8::F8E4M3; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -151,6 +152,8 @@ macro_rules! 
pydtype { }; } +pydtype!(i16, |v| v); +pydtype!(i32, |v| v); pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); @@ -158,6 +161,7 @@ pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +pydtype!(F8E4M3, f32::from); fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result { let dim = t.dim(dim)?; @@ -200,11 +204,14 @@ trait MapDType { match t.dtype() { DType::U8 => self.f::(t), DType::U32 => self.f::(t), + DType::I16 => self.f::(t), + DType::I32 => self.f::(t), DType::I64 => self.f::(t), DType::BF16 => self.f::(t), DType::F16 => self.f::(t), DType::F32 => self.f::(t), DType::F64 => self.f::(t), + DType::F8E4M3 => self.f::(t), } } } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b146..94d3f51fd9 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -29,6 +29,6 @@ tracing = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] -flash-attn = ["cuda", "dep:candle-flash-attn"] +flash-attn = ["cuda", "dep:candle-flash-attn", "candle-nn/flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] metal = ["candle/metal", "candle-nn/metal"] diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs index aa28f52333..534ed3c964 100644 --- a/candle-transformers/src/models/based.rs +++ b/candle-transformers/src/models/based.rs @@ -8,8 +8,8 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ - conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig, - Func, Linear, RmsNorm, VarBuilder, + conv1d_no_bias, layer_norm::RmsNormNonQuantized, linear, linear_no_bias, ops::softmax_last_dim, + rms_norm_non_quant, Conv1d, Conv1dConfig, Func, Linear, RmsNorm, VarBuilder, }; use std::sync::Arc; @@ -460,16 +460,16 @@ impl SequenceMixer { #[derive(Debug, Clone)] struct DecoderLayer { mlp: MLP, - norm1: RmsNorm, - norm2: RmsNorm, + norm1: RmsNorm, + norm2: RmsNorm, mixer: SequenceMixer, } impl DecoderLayer { fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result { let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; - let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; + let norm1 = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; + let norm2 = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; let l_attn = cfg.alt_mixer_layers.contains(&layer_idx); let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx); @@ -510,7 +510,7 @@ impl DecoderLayer { pub struct Model { embed_tokens: super::with_tracing::Embedding, layers: Vec, - norm: RmsNorm, + norm: RmsNorm, lm_head: Linear, sliding_window: usize, device: Device, @@ -529,7 +529,7 @@ impl Model { let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; + let norm = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; Ok(Self { embed_tokens, layers, diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs index 8f6284a8e6..f122caee22 100644 --- a/candle-transformers/src/models/beit.rs +++ b/candle-transformers/src/models/beit.rs @@ -79,34 +79,34 @@ impl Attention { .contiguous()?; 
let relative_coords = relative_coords.slice_assign( - &[0..w_area, 0..w_area, 0..1], + &[&(0..w_area), &(0..w_area), &(0..1)], &(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?, )?; let relative_coords = relative_coords.slice_assign( - &[0..w_area, 0..w_area, 1..2], + &[&(0..w_area), &(0..w_area), &(1..2)], &(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?, )?; let relative_coords = relative_coords.slice_assign( - &[0..w_area, 0..w_area, 0..1], + &[&(0..w_area), &(0..w_area), &(0..1)], &(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?, )?; Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)? - .slice_assign(&[1.., 1..], &relative_coords.sum(2)?)? + .slice_assign(&[&(1..), &(1..)], &relative_coords.sum(2)?)? .slice_assign( - &[0..1, 0..(w_area + 1)], + &[&(0..1), &(0..(w_area + 1))], &(Tensor::ones((1, w_area + 1), DType::I64, device)? * ((num_relative_distance - 3) as f64))? .to_dtype(DType::I64)?, )? .slice_assign( - &[0..(w_area + 1), 0..1], + &[&(0..(w_area + 1)), &(0..1)], &(Tensor::ones((w_area + 1, 1), DType::I64, device)? * ((num_relative_distance - 2) as f64))? .to_dtype(DType::I64)?, )? .slice_assign( - &[0..1, 0..1], + &[&(0..1), &(0..1)], &(Tensor::ones((1, 1), DType::I64, device)? * ((num_relative_distance - 1) as f64))? .to_dtype(DType::I64)?, diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index 0686b34ef3..da093a7c17 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -374,7 +374,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -388,7 +388,7 @@ impl Block { )? }; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -460,7 +460,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), diff --git a/candle-transformers/src/models/codegeex4_9b.rs b/candle-transformers/src/models/codegeex4_9b.rs index aaa99fd96d..ae4b629601 100644 --- a/candle-transformers/src/models/codegeex4_9b.rs +++ b/candle-transformers/src/models/codegeex4_9b.rs @@ -384,7 +384,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -398,7 +398,7 @@ impl Block { )? 
}; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -470,7 +470,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index 8eae8bb200..8199874276 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -11,13 +11,13 @@ use candle_nn::{ BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, }; -#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct Config { - pub exp_ratio: usize, - pub in_channels: usize, - pub blocks: [usize; 4], - pub attn: bool, - pub lkc_use_act: bool, + exp_ratio: usize, + in_channels: usize, + blocks: [usize; 4], + attn: bool, + lkc_use_act: bool, } impl Config { @@ -495,6 +495,7 @@ fn fastvit_model(cfg: &Config, nclasses: Option, vb: VarBuilder) -> Resul .apply(&stage3)? .apply(&stage4)? .apply(&final_conv)?; + match &cls { None => Ok(xs), Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls), diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs index 17b4eb2532..26cbd98e3f 100644 --- a/candle-transformers/src/models/flux/model.rs +++ b/candle-transformers/src/models/flux/model.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm, VarBuilder}; // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 #[derive(Debug, Clone)] @@ -195,16 +195,16 @@ impl candle::Module for MlpEmbedder { #[derive(Debug, Clone)] pub struct QkNorm { - query_norm: RmsNorm, - key_norm: RmsNorm, + query_norm: RmsNorm, + key_norm: RmsNorm, } impl QkNorm { fn new(dim: usize, vb: VarBuilder) -> Result { let query_norm = vb.get(dim, "query_norm.scale")?; - let query_norm = RmsNorm::new(query_norm, 1e-6); + let query_norm = RmsNorm::::new(query_norm, 1e-6); let key_norm = vb.get(dim, "key_norm.scale")?; - let key_norm = RmsNorm::new(key_norm, 1e-6); + let key_norm = RmsNorm::::new(key_norm, 1e-6); Ok(Self { query_norm, key_norm, diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs index 0efeeab573..a06897540b 100644 --- a/candle-transformers/src/models/flux/quantized_model.rs +++ b/candle-transformers/src/models/flux/quantized_model.rs @@ -2,6 +2,7 @@ use super::model::{attention, timestep_embedding, Config, EmbedNd}; use crate::quantized_nn::{linear, linear_b, Linear}; use crate::quantized_var_builder::VarBuilder; use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::layer_norm::RmsNormNonQuantized; use candle_nn::{LayerNorm, RmsNorm}; fn layer_norm(dim: usize, vb: VarBuilder) -> Result { @@ -34,16 +35,16 @@ impl candle::Module for MlpEmbedder { #[derive(Debug, Clone)] pub struct QkNorm { - query_norm: RmsNorm, - key_norm: RmsNorm, + query_norm: RmsNorm, + key_norm: RmsNorm, } impl QkNorm { fn new(dim: usize, vb: VarBuilder) -> Result { let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?; - let query_norm = RmsNorm::new(query_norm, 1e-6); + let 
query_norm = RmsNorm::::new(query_norm, 1e-6); let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?; - let key_norm = RmsNorm::new(key_norm, 1e-6); + let key_norm = RmsNorm::::new(key_norm, 1e-6); Ok(Self { query_norm, key_norm, diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index 3b436eaa6d..00ead338d0 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -383,7 +383,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -397,7 +397,7 @@ impl Block { )? }; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -469,7 +469,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 923a270646..91d298e97f 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -1,6 +1,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::layer_norm::RmsNormNonQuantized; use candle_nn::linear_no_bias as linear; -use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use candle_nn::{embedding, rms_norm_non_quant, Embedding, Linear, Module, RmsNorm, VarBuilder}; use std::collections::HashMap; #[derive(Debug, Clone)] @@ -282,14 +283,19 @@ impl Mlp { #[derive(Debug, Clone)] struct Block { - rms_1: RmsNorm, + rms_1: RmsNorm, attn: CausalSelfAttention, - rms_2: RmsNorm, + rms_2: RmsNorm, mlp: Mlp, } impl Block { - fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + fn new( + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + ) -> Self { Self { rms_1, attn, @@ -316,9 +322,9 @@ impl Block { fn load(vb: VarBuilder, cfg: &Config) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?; - let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; + let input_layernorm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = - rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; + rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; Ok(Self::new( input_layernorm, attn, @@ -332,7 +338,7 @@ impl Block { pub struct Llama { wte: Embedding, blocks: Vec, - ln_f: RmsNorm, + ln_f: RmsNorm, lm_head: Linear, pub config: Config, } @@ -352,7 +358,7 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; - let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; + let ln_f = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap()) .collect(); diff --git a/candle-transformers/src/models/mamba.rs 
b/candle-transformers/src/models/mamba.rs index a75ee87a6e..c8d9bf1e2e 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -2,7 +2,7 @@ /// This is based on: https://github.com/LaurentMazare/mamba.rs use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, RmsNorm, VarBuilder}; const D_CONV: usize = 4; const D_STATE: usize = 16; @@ -155,12 +155,12 @@ impl MambaBlock { #[derive(Clone, Debug)] pub struct ResidualBlock { mixer: MambaBlock, - norm: RmsNorm, + norm: RmsNorm, } impl ResidualBlock { pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result { - let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let norm = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm"))?; let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?; Ok(Self { mixer, norm }) } @@ -175,7 +175,7 @@ impl ResidualBlock { pub struct Model { embedding: candle_nn::Embedding, layers: Vec, - norm_f: RmsNorm, + norm_f: RmsNorm, lm_head: Linear, dtype: DType, } @@ -189,7 +189,7 @@ impl Model { let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let norm_f = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm_f"))?; let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); Ok(Self { embedding, diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 43de594f9d..ec382711cd 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; +use candle_nn::{embedding, linear_b, rms_norm_non_quant, Embedding, Linear, RmsNorm, VarBuilder}; // Equivalent to torch.repeat_interleave pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result { @@ -328,6 +328,8 @@ pub mod tokenizers { } pub mod gpt { + use candle_nn::layer_norm::RmsNormNonQuantized; + use super::*; #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] @@ -350,7 +352,7 @@ pub mod gpt { } enum Norm { - RMSNorm(candle_nn::RmsNorm), + RMSNorm(candle_nn::RmsNorm), LayerNorm(candle_nn::LayerNorm), } @@ -400,7 +402,7 @@ pub mod gpt { fn new(cfg: &Config, vb: VarBuilder) -> Result { match cfg.norm_type { NormType::RMSNorm => { - let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?; + let rms_norm = candle_nn::rms_norm_non_quant(cfg.n_embd, cfg.rmsnorm_eps, vb)?; Ok(Self::RMSNorm(rms_norm)) } NormType::LayerNorm => { @@ -666,6 +668,8 @@ pub mod gpt { } pub mod transformer { + use candle_nn::layer_norm::RmsNormNonQuantized; + use super::*; #[derive(Debug, Clone, serde::Deserialize)] @@ -833,8 +837,8 @@ pub mod transformer { struct Block { attention: Attention, feed_forward: FeedForward, - ffn_norm: RmsNorm, - attention_norm: RmsNorm, + ffn_norm: RmsNorm, + attention_norm: RmsNorm, span: tracing::Span, } @@ -842,8 +846,9 @@ pub mod transformer { fn new(cfg: &Config, vb: VarBuilder) -> Result { let attention = Attention::new(cfg, vb.pp("attention"))?; let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?; - let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, 
vb.pp("ffn_norm"))?; - let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?; + let ffn_norm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?; + let attention_norm = + rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?; Ok(Self { attention, feed_forward, @@ -871,7 +876,7 @@ pub mod transformer { pos_embeddings: Embedding, speaker_cond_pos: Linear, layers: Vec, - norm: RmsNorm, + norm: RmsNorm, output: Linear, spk_cond_mask: Tensor, span: tracing::Span, @@ -893,7 +898,7 @@ pub mod transformer { let layer = Block::new(cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?; + let norm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("norm"))?; let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?; let dtype = vb.dtype(); let spk_cond_mask = Tensor::cat( diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index e8f7a7c4b8..4a0cddf190 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -1,7 +1,7 @@ use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::{scaled_dot_product_attention, Activation, VarBuilder}; use std::sync::Arc; fn default_num_attention_heads() -> usize { @@ -176,22 +176,6 @@ impl Module for MLP { } } -#[cfg(feature = "flash-attn")] -fn flash_attn( - q: &Tensor, - k: &Tensor, - v: &Tensor, - softmax_scale: f32, - causal: bool, -) -> Result { - candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) -} - -#[cfg(not(feature = "flash-attn"))] -fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { - unimplemented!("compile with '--features flash-attn'") -} - #[derive(Debug, Clone)] struct Attention { q_proj: Linear, @@ -274,24 +258,17 @@ impl Attention { let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; - let attn_output = if self.use_flash_attn { - // flash-attn expects (b_sz, seq_len, nheads, head_dim) - let q = query_states.transpose(1, 2)?; - let k = key_states.transpose(1, 2)?; - let v = value_states.transpose(1, 2)?; - let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); - flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? - } else { - let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; - - let attn_weights = match attention_mask { - None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, - }; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - attn_weights.matmul(&value_states)? - }; + let scale = 1. / (self.head_dim as f64).sqrt(); + let attn_output = scaled_dot_product_attention( + &query_states, + &key_states, + &value_states, + scale, + attention_mask, + self.use_flash_attn, + q_len, + )?; + attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.num_heads * self.head_dim))? 
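The mistral.rs change above replaces the hand-rolled masked/flash attention branches with the `candle_nn::scaled_dot_product_attention` helper introduced in candle-nn/src/attention.rs. As a rough usage sketch only (the shapes, dtype, device and random inputs below are illustrative assumptions, not part of the patch), a standalone call looks like this:

use candle::{DType, Device, Result, Tensor};
use candle_nn::scaled_dot_product_attention;

fn sdpa_example() -> Result<Tensor> {
    // The callers above use the (batch, heads, seq, head_dim) layout; reuse it here.
    let device = Device::Cpu;
    let (b, heads, seq, head_dim) = (1, 8, 16, 64);
    let q = Tensor::randn(0f32, 1., (b, heads, seq, head_dim), &device)?;
    let k = Tensor::randn(0f32, 1., (b, heads, seq, head_dim), &device)?;
    let v = Tensor::randn(0f32, 1., (b, heads, seq, head_dim), &device)?;
    // The mask is an additive bias: 0.0 for visible positions, f32::NEG_INFINITY for masked ones.
    let mask = Tensor::zeros((b, heads, seq, seq), DType::F32, &device)?;
    let scale = 1. / (head_dim as f64).sqrt();
    // use_flash_attn = false takes the fused matmul_with_alpha_beta + softmax path.
    scaled_dot_product_attention(&q, &k, &v, scale, Some(&mask), false, seq)
}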
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs index 45a5dbad9f..4953d835b5 100644 --- a/candle-transformers/src/models/mobileclip.rs +++ b/candle-transformers/src/models/mobileclip.rs @@ -22,6 +22,7 @@ impl MobileClipConfig { pub fn s1() -> Self { let text_config = text_model::Config::vit_base_patch32(); let vision_config = fastvit::Config::mci1(); + Self { text_config, vision_config, @@ -31,6 +32,7 @@ impl MobileClipConfig { pub fn s2() -> Self { let text_config = text_model::Config::vit_base_patch32(); let vision_config = fastvit::Config::mci2(); + Self { text_config, vision_config, @@ -43,10 +45,12 @@ impl MobileClipModel { pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result { let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?; let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?; + let text_projection = vs.get( (c.text_config.embed_dim, c.text_config.projection_dim), "text.text_projection", )?; + let logit_scale = vs.get(&[], "logit_scale")?; Ok(Self { text_model, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 80cd4f810c..88d9f0307e 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -63,7 +63,6 @@ pub mod quantized_mixformer; pub mod quantized_moondream; pub mod quantized_mpt; pub mod quantized_phi; -pub mod quantized_phi3; pub mod quantized_qwen2; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs index 20d8f08231..4e3875b27b 100644 --- a/candle-transformers/src/models/pixtral/vision_model.rs +++ b/candle-transformers/src/models/pixtral/vision_model.rs @@ -1,5 +1,7 @@ use candle::{DType, Module, Result, Tensor, D}; -use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder}; +use candle_nn::{ + layer_norm::RmsNormNonQuantized, linear_b, rms_norm_non_quant, Linear, RmsNorm, VarBuilder, +}; fn default_act() -> candle_nn::Activation { candle_nn::Activation::Gelu @@ -165,18 +167,18 @@ impl Module for Mlp { #[derive(Debug, Clone)] struct AttentionLayer { - attention_norm: RmsNorm, + attention_norm: RmsNorm, feed_forward: Mlp, attention: Attention, - ffn_norm: RmsNorm, + ffn_norm: RmsNorm, } impl AttentionLayer { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?; + let attention_norm = rms_norm_non_quant(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?; let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?; let attention = Attention::new(cfg, vb.pp("attention"))?; - let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?; + let ffn_norm = rms_norm_non_quant(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?; Ok(Self { attention_norm, feed_forward, @@ -283,7 +285,7 @@ impl RotaryEmbedding { #[derive(Debug, Clone)] pub struct Model { patch_conv: candle_nn::Conv2d, - ln_pre: RmsNorm, + ln_pre: RmsNorm, transformer: Transformer, patch_positional_embedding: RotaryEmbedding, } @@ -301,7 +303,7 @@ impl Model { conv2d_cfg, vb.pp("patch_conv"), )?; - let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?; + let ln_pre = rms_norm_non_quant(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?; let transformer = Transformer::new(cfg, vb.pp("transformer"))?; let patch_positional_embedding = RotaryEmbedding::new(cfg, 
vb.pp("patch_positional_embedding"))?; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 6b326fbe92..544bf8a456 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -4,7 +4,7 @@ use crate::quantized_nn::RmsNorm; use candle::quantized::QTensor; use candle::quantized::{ggml_file, gguf_file}; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{Embedding, Module}; +use candle_nn::{scaled_dot_product_attention, Embedding, Module}; pub const MAX_SEQ_LEN: usize = 4096; @@ -138,19 +138,12 @@ struct LayerWeights { head_dim: usize, cos: Tensor, sin: Tensor, - neg_inf: Tensor, kv_cache: Option<(Tensor, Tensor)>, span_attn: tracing::Span, span_rot: tracing::Span, span_mlp: tracing::Span, } -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { - let shape = mask.shape(); - let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; - Ok(m) -} - impl LayerWeights { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { let _enter = self.span_rot.enter(); @@ -209,17 +202,10 @@ impl LayerWeights { let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let att = match mask { - None => att, - Some(mask) => { - let mask = mask.broadcast_as(att.shape())?; - masked_fill(&att, &mask, &self.neg_inf)? - } - }; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; + let scale = 1. / (self.head_dim as f64).sqrt(); + + let y = scaled_dot_product_attention(&q, &k, &v, scale, mask, false, seq_len)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = self.attention_wo.forward(&y)?; Ok(y) @@ -260,7 +246,6 @@ impl ModelWeights { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; - let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; @@ -300,7 +285,6 @@ impl ModelWeights { head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, cos: cos.clone(), sin: sin.clone(), - neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, @@ -349,7 +333,6 @@ impl ModelWeights { .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; - let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; @@ -420,7 +403,6 @@ impl ModelWeights { head_dim: embedding_length / head_count, cos: cos.clone(), sin: sin.clone(), - neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, @@ -445,9 +427,11 @@ impl ModelWeights { Ok(mask.clone()) } else { let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .flat_map(|i| (0..t).map(move |j| if i < j { f32::NEG_INFINITY } else { 0.0 })) .collect(); - let mask = Tensor::from_slice(&mask, (t, t), device)?; + let mask = Tensor::from_slice(&mask, (t, t), device)? 
+ .expand((1, 1, t, t))? + .to_dtype(DType::F32)?; self.masks.insert(t, mask.clone()); Ok(mask) } diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs deleted file mode 100644 index 257ad98379..0000000000 --- a/candle-transformers/src/models/quantized_phi3.rs +++ /dev/null @@ -1,322 +0,0 @@ -use std::collections::HashMap; - -use candle::quantized::gguf_file; -use candle::quantized::QTensor; -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm}; - -#[derive(Debug, Clone)] -struct QLinear { - inner: candle::quantized::QMatMul, - span: tracing::Span, -} - -impl QLinear { - fn new( - ct: &gguf_file::Content, - r: &mut R, - name: &str, - device: &Device, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - let w = ct.tensor(r, &format!("{name}.weight"), device)?; - let inner = candle::quantized::QMatMul::from_qtensor(w)?; - Ok(Self { inner, span }) - } -} - -impl Module for QLinear { - fn forward(&self, xs: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -#[derive(Debug, Clone)] -struct Mlp { - ffn_up: QLinear, - ffn_down: QLinear, - i_size: usize, -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result { - let up_states = xs.apply(&self.ffn_up)?; - let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; - let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; - let up_states = (up_states * gate.silu()?)?; - up_states.apply(&self.ffn_down) - } -} - -fn rms_norm(w: QTensor, eps: f64) -> Result { - let w = w.dequantize(&w.device())?; - let rms = RmsNorm::new(w, eps); - Ok(rms) -} - -#[derive(Debug, Clone)] -struct LayerWeights { - attn_qkv: QLinear, - attn_output: QLinear, - attn_norm: RmsNorm, - ffn_norm: RmsNorm, - mlp: Mlp, - n_head: usize, - n_kv_head: usize, - head_dim: usize, - cos: Tensor, - sin: Tensor, - neg_inf: Tensor, - kv_cache: KvCache, - use_flash_attn: bool, - span_attn: tracing::Span, - span_rot: tracing::Span, -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { - let shape = mask.shape(); - let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; - Ok(m) -} - -impl LayerWeights { - fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result { - let _enter = self.span_rot.enter(); - let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?; - let cos = self.cos.narrow(0, index_pos, seq_len)?; - let sin = self.sin.narrow(0, index_pos, seq_len)?; - candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin) - } - - fn forward_attn( - &mut self, - x: &Tensor, - mask: Option<&Tensor>, - index_pos: usize, - ) -> Result { - let _enter = self.span_attn.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; - let qkv = self.attn_qkv.forward(x)?; - - let query_pos = self.n_head * self.head_dim; - let q = qkv.narrow(D::Minus1, 0, query_pos)?; - let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?; - let v = qkv.narrow( - D::Minus1, - query_pos + self.n_kv_head * self.head_dim, - self.n_kv_head * self.head_dim, - )?; - - let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? 
- .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?; - let k = self.apply_rotary_emb(&k, index_pos)?; - - let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; - - let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; - let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; - - let y = if self.use_flash_attn { - // flash-attn expects (b_sz, seq_len, nheads, head_dim) - let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?; - let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?; - let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?; - let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); - flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)? - .to_dtype(DType::F32)? - .transpose(1, 2)? - } else { - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let att = match mask { - None => att, - Some(mask) => { - let mask = mask.broadcast_as(att.shape())?; - masked_fill(&att, &mask, &self.neg_inf)? - } - }; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - att.matmul(&v)? - }; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attn_output.forward(&y)?; - Ok(y) - } -} - -#[cfg(feature = "flash-attn")] -fn flash_attn( - q: &Tensor, - k: &Tensor, - v: &Tensor, - softmax_scale: f32, - causal: bool, -) -> Result { - candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) -} - -#[cfg(not(feature = "flash-attn"))] -fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { - unimplemented!("compile with '--features flash-attn'") -} - -#[derive(Debug, Clone)] -pub struct ModelWeights { - tok_embeddings: Embedding, - layers: Vec, - output_norm: RmsNorm, - output: QLinear, - masks: HashMap, - span: tracing::Span, - span_output: tracing::Span, -} - -fn precomput_freqs_cis( - head_dim: usize, - max_seq_len: usize, - freq_base: f32, - device: &Device, -) -> Result<(Tensor, Tensor)> { - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, max_seq_len as u32, device)? - .to_dtype(DType::F32)? - .reshape((max_seq_len, 1))? - .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - Ok((cos, sin)) -} - -impl ModelWeights { - pub fn from_gguf( - use_flash_attn: bool, - ct: gguf_file::Content, - reader: &mut R, - device: &Device, - ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - - // Parameter extraction from metadata. - let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("phi3.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; - let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize; - let head_dim = embedding_length / head_count; - let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; - let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; - let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? 
-        let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
-
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
-        let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
-        let output = QLinear::new(&ct, reader, "output", device)?;
-
-        let mut layers = Vec::with_capacity(block_count);
-        for layer_idx in 0..block_count {
-            let prefix = format!("blk.{layer_idx}");
-            let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
-            let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
-            let mlp = Mlp {
-                ffn_up,
-                ffn_down,
-                i_size,
-            };
-            let attn_norm = rms_norm(
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
-                rms_eps,
-            )?;
-            let ffn_norm = rms_norm(
-                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
-                rms_eps,
-            )?;
-            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
-            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
-            let kv_cache = KvCache::new(2, max_seq_len);
-            layers.push(LayerWeights {
-                attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
-                attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
-                attn_norm,
-                ffn_norm,
-                mlp,
-                n_head: head_count,
-                n_kv_head: head_count_kv,
-                head_dim,
-                cos: cos.clone(),
-                sin: sin.clone(),
-                neg_inf: neg_inf.clone(),
-                kv_cache,
-                use_flash_attn,
-                span_attn,
-                span_rot,
-            })
-        }
-        let span = tracing::span!(tracing::Level::TRACE, "model");
-        let span_output = tracing::span!(tracing::Level::TRACE, "output");
-        Ok(Self {
-            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
-            layers,
-            output_norm,
-            output,
-            masks: HashMap::new(),
-            span,
-            span_output,
-        })
-    }
-
-    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
-        if let Some(mask) = self.masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), device)?;
-            self.masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
-
-    pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = xs.dims2()?;
-        let mask = if seq_len == 1 {
-            None
-        } else {
-            Some(self.mask(seq_len, xs.device())?)
-        };
-        let _enter = self.span.enter();
-        let mut xs = self.tok_embeddings.forward(xs)?;
-        for layer in self.layers.iter_mut() {
-            let residual = &xs;
-            let ys = xs.apply(&layer.attn_norm)?;
-            let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?;
-            let ys = (ys + residual)?;
-            let residual = &ys;
-            let ys = ys.apply(&layer.ffn_norm)?;
-            let ys = layer.mlp.forward(&ys)?;
-            xs = (ys + residual)?
-        }
-        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
-        let _enter = self.span_output.enter();
-        self.output.forward(&xs)
-    }
-}
diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs
index 010643c8d2..7c8dad510e 100644
--- a/candle-transformers/src/models/vgg.rs
+++ b/candle-transformers/src/models/vgg.rs
@@ -54,17 +54,17 @@ impl ModuleT for Vgg<'_> {
 fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
-        .map(|&(in_c, out_c, name)| {
+        .map(|(in_c, out_c, name)| {
             candle_nn::conv2d(
-                in_c,
-                out_c,
+                *in_c,
+                *out_c,
                 3,
                 candle_nn::Conv2dConfig {
                     stride: 1,
                     padding: 1,
                     ..Default::default()
                 },
-                vb.pp(name),
+                vb.pp(*name),
             )
         })
         .collect::<Result<Vec<_>>>()?;
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index f4706c7e95..29bbf637e6 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -1,5 +1,5 @@
 use candle::{Module, Result, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{layer_norm::RmsNormNonQuantized, VarBuilder};
 
 #[derive(Debug, Clone)]
 pub struct Embedding {
@@ -170,20 +170,20 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
 
 #[derive(Debug, Clone)]
 pub struct RmsNorm {
-    inner: candle_nn::RmsNorm,
+    inner: candle_nn::RmsNorm<RmsNormNonQuantized>,
     span: tracing::Span,
 }
 
 impl RmsNorm {
     pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
         Ok(Self { inner, span })
     }
 
     pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        self.inner.forward_diff(x)
+        self.inner.forward(x)
     }
 }
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 01f5910aea..dae6cb4fdc 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,6 +1,7 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::layer_norm::RmsNormNonQuantized;
 use candle_nn::{
-    embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
+    embedding, linear_no_bias as linear, Embedding, Linear, Module, RmsNorm, VarBuilder,
 };
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -201,14 +202,19 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: RmsNorm<RmsNormNonQuantized>,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: RmsNorm<RmsNormNonQuantized>,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(
+        rms_1: RmsNorm<RmsNormNonQuantized>,
+        attn: CausalSelfAttention,
+        rms_2: RmsNorm<RmsNormNonQuantized>,
+        mlp: Mlp,
+    ) -> Self {
         Self {
             rms_1,
             attn,
@@ -229,9 +235,13 @@
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+        let input_layernorm =
+            candle_nn::rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let post_attention_layernorm = candle_nn::rms_norm_non_quant(
+            cfg.dim,
+            cfg.norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -244,12 +254,17 @@
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: RmsNorm<RmsNormNonQuantized>,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+    fn new(
+        wte: Embedding,
+        blocks: Vec<Block>,
+        ln_f: RmsNorm<RmsNormNonQuantized>,
+        lm_head: Linear,
+    ) -> Self {
         Self {
             wte,
             blocks,
@@ -273,7 +288,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let norm = candle_nn::rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();