Add softcapping support to flash attention #2437

Closed · wants to merge 72 commits
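For context: logit soft-capping (used by models such as Gemma 2) bounds the pre-softmax attention scores to (-softcap, softcap) by applying softcap · tanh(scores / softcap). Because flash attention never materializes the full score matrix, the cap must be applied inside the kernel, between the Q·Kᵀ product and the softmax, which is what this PR adds. Below is a minimal candle-level sketch of the math only; `softcap_scores` is a hypothetical helper for illustration, while the PR's real implementation is fused into the CUDA kernels.

```rust
use candle_core::{Result, Tensor};

// Reference semantics of attention-logit soft-capping:
// scores <- softcap * tanh(scores / softcap)
fn softcap_scores(scores: &Tensor, softcap: f64) -> Result<Tensor> {
    scores
        .affine(1.0 / softcap, 0.0)? // scores / softcap
        .tanh()? // squash into (-1, 1)
        .affine(softcap, 0.0) // rescale to (-softcap, softcap)
}
```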
Commits (72) · Changes from all commits
83a9e88
Mistral.rs Squash Changes (#4)
EricLBuehler May 15, 2024
4e82fab
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 15, 2024
37cafcc
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 16, 2024
5892fac
fix issue with cuda header file for A10G (#5)
joshpopelka20 May 16, 2024
9b151f5
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 18, 2024
ea49ea2
Remove candle-layer-norm (#6)
EricLBuehler May 19, 2024
38f8d9e
Merge
EricLBuehler May 19, 2024
c10fc33
Merge
EricLBuehler May 27, 2024
527ebcc
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 28, 2024
bfc197b
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 29, 2024
0c2ac76
Merge remote-tracking branch 'upstream/main'
EricLBuehler May 30, 2024
cb3dbc2
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 1, 2024
faa9435
Add a set_dtype method
EricLBuehler Jun 3, 2024
462d948
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 3, 2024
5c06acd
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 4, 2024
696acaa
Add more capability to slice_assign (#7)
EricLBuehler Jun 9, 2024
0936406
Implement unfold (#8)
EricLBuehler Jun 9, 2024
636de1d
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jun 11, 2024
f52e234
Bump cudarc to 0.11.5 (#10)
EricLBuehler Jun 11, 2024
bb8f6f0
Add QTensor::quantize_onto (#12)
EricLBuehler Jun 29, 2024
5b04d96
implement Slice op (#2260)
shua Jun 12, 2024
f7095bb
Fix the fast bf16 gemm cublas kernels. (#2274)
LaurentMazare Jun 18, 2024
b55b360
Fix a bug in the metal implementation of col2im1d. (#2284)
LaurentMazare Jun 22, 2024
08e93a6
Depth Anything v2 (#2279)
jeroenvlek Jun 24, 2024
5df1ae2
Adding Gemm and ArgMax operators to candle-onnx (#2231)
socathie Jun 28, 2024
0bb678c
Add DINOv2Reg4 + PlantCLEF2024 (#2293)
v-espitalier Jun 29, 2024
b438cba
make up for the missing last token output of phi2 example (#2299)
Czxck001 Jun 29, 2024
b7a3e34
Patch metal function
EricLBuehler Jun 30, 2024
c967be9
Complete merge
EricLBuehler Jul 15, 2024
9e09d7f
Expose cublas handle
EricLBuehler Jul 26, 2024
8b357f6
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jul 26, 2024
2064fb0
Merge remote-tracking branch 'upstream/main'
EricLBuehler Jul 31, 2024
1a48767
Add sdpa function with cublaslt
EricLBuehler Aug 4, 2024
7bbcf00
Update docs
EricLBuehler Aug 4, 2024
1bf7101
Add matmul_bias_and_scale
EricLBuehler Aug 4, 2024
d6d3d18
Rename
EricLBuehler Aug 4, 2024
e20d85a
Add a simple test and fix for cpu
EricLBuehler Aug 4, 2024
8d2f32a
Update sdpa function
EricLBuehler Aug 4, 2024
9f144d6
Add matmul_alpha
EricLBuehler Aug 4, 2024
c830f26
Use matmul_with_alpha in sdpa
EricLBuehler Aug 4, 2024
86d0876
Add it to mistral
EricLBuehler Aug 5, 2024
8d8889c
Add it to q llama
EricLBuehler Aug 5, 2024
d18eb13
Add attention benches
EricLBuehler Aug 5, 2024
d71b7d7
Fixes
EricLBuehler Aug 5, 2024
412e9f4
Merge commit 'd71b7d78396a944817876c56f1677bd17633234d'
EricLBuehler Aug 5, 2024
27ca77e
Simplify things a bit
EricLBuehler Aug 7, 2024
7ad6494
Mistral.rs GPTQ dev PR (#14)
EricLBuehler Aug 9, 2024
6f0e190
Fix on metal
EricLBuehler Aug 14, 2024
ec55f58
Add the flux model for image generation. (#2390)
LaurentMazare Aug 4, 2024
0a146d7
Simplify handling of flux modulations. (#2394)
LaurentMazare Aug 4, 2024
0f55c37
optimize gradient for silu a bit (#2393)
MilkFather Aug 4, 2024
aef4eba
Support the flux-dev model too. (#2395)
LaurentMazare Aug 4, 2024
c301efa
Support for mistral-nemo. (#2396)
LaurentMazare Aug 4, 2024
fd0e933
add models support and example for THUDM/glm-4 (#2362)
donjuanplatinum Aug 5, 2024
f8e2b36
Add the MMDiT model of Stable Diffusion 3 (#2397)
Czxck001 Aug 5, 2024
0e78d29
Add the import script for the T5 tokenizer. (#2399)
LaurentMazare Aug 5, 2024
1b796b9
fix: usage of `actions/checkout@v2` (#2403)
hamirmahal Aug 6, 2024
c9cdd54
Fix issues in the encodec example README.md (#2407)
jnises Aug 10, 2024
283a5cf
Soft Non-Maximum Suppression (#2400)
onichmath Aug 10, 2024
de719a2
Add documentation examples for `Tensor::i` and `Tensor::narrow` metho…
csicar Aug 10, 2024
2e72a3d
Add Based LLM from Hazy Research. (#2411)
janimo Aug 12, 2024
d7a9bd0
Fix the device for the bert attention mask. (#2414)
LaurentMazare Aug 14, 2024
3d40ffc
Clippy fixes. (#2415)
LaurentMazare Aug 14, 2024
c5c5d49
Update flash_fwd_launch_template.h with fix for kernels (#16)
joshpopelka20 Aug 14, 2024
2386e4e
Build fixes
EricLBuehler Aug 14, 2024
a38053f
Merge branch 'sdpa'
EricLBuehler Aug 14, 2024
1b1974e
Add GGUF BF16 support (#17)
EricLBuehler Aug 21, 2024
d632eb5
Expose the softcap methods
EricLBuehler Aug 22, 2024
da095a6
Add some tests
EricLBuehler Aug 22, 2024
36bd9f9
Merge remote-tracking branch 'upstream/main'
EricLBuehler Aug 22, 2024
6fbddd6
Complete merge
EricLBuehler Aug 22, 2024
a3431d1
Merge branch 'main' into flash_attn_softcap
EricLBuehler Aug 22, 2024
5 changes: 4 additions & 1 deletion .vscode/settings.json
@@ -7,5 +7,8 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": [
"cuda", "flash-attn",
],
}
2 changes: 2 additions & 0 deletions README.md
@@ -4,6 +4,8 @@
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)

**This is an optimized implementation by Eric Buehler.**

Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
4 changes: 2 additions & 2 deletions candle-core/benches/benchmarks/mod.rs
@@ -6,7 +6,7 @@ pub(crate) mod random;
pub(crate) mod unary;
pub(crate) mod where_cond;

use candle_core::{Device, Result};
use candle_core::{cuda::WrapErr, Device, Result};

pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
@@ -20,7 +20,7 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
return Ok(device.synchronize().w()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
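A note on the `.w()` call added above: it comes from candle's `WrapErr` helper, which adapts backend errors (here, the cudarc error returned by `device.synchronize()`) into `candle_core::Error` so that `?` can propagate them in candle `Result`s. A rough sketch of the idea, with an assumed trait shape; the real definition lives in candle's CUDA backend:

```rust
use candle_core::{Error, Result};

// Assumed shape of the WrapErr idea, for illustration only: turn any
// std error into a candle_core::Error so `?` works in candle Results.
trait WrapErrSketch<O> {
    fn w(self) -> Result<O>;
}

impl<O, E: std::error::Error + Send + Sync + 'static> WrapErrSketch<O>
    for std::result::Result<O, E>
{
    fn w(self) -> Result<O> {
        self.map_err(|e| Error::Msg(e.to_string()))
    }
}
```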
16 changes: 15 additions & 1 deletion candle-core/src/backend.rs
@@ -89,9 +89,23 @@
_: usize,
) -> Result<Self>;

fn matmul(
#[allow(clippy::too_many_arguments)]
fn matmul_with_alpha_beta(
&self,
_: &Self,
_: &mut Self,
_: Option<f64>,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
_: &Layout,
) -> Result<()>;

#[allow(clippy::too_many_arguments)]
fn matmul_with_alpha(
&self,
_: &Self,
_: Option<f64>,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
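The two new backend entry points follow BLAS GEMM conventions: `matmul_with_alpha` scales the product by an optional α, and `matmul_with_alpha_beta` additionally accumulates into an existing buffer (the `&mut Self` argument). Assuming the usual c ← α·(a·b) + c semantics with α = 1 when `None` (layout and batching details live in each backend implementation), a naive scalar reference for one batch element:

```rust
// Naive reference (assumed semantics) for a single (m, n, k) problem:
// c[i][j] += alpha * sum_l a[i][l] * b[l][j]
fn matmul_alpha_beta_ref(
    a: &[f32], b: &[f32], c: &mut [f32],
    m: usize, n: usize, k: usize, alpha: Option<f64>,
) {
    let alpha = alpha.unwrap_or(1.0) as f32;
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for l in 0..k {
                acc += a[i * k + l] * b[l * n + j];
            }
            c[i * n + j] += alpha * acc;
        }
    }
}
```

Folding the scale into the matmul is presumably what lets the new `sdpa` path apply the 1/√d factor during the Q·Kᵀ product rather than in a separate pass (see the "Use matmul_with_alpha in sdpa" commit above).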
5 changes: 5 additions & 0 deletions candle-core/src/convert.rs
@@ -130,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
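The new `DType::I32` arm mirrors the surrounding integer cases: each value is written little-endian via the `byteorder` crate. A self-contained sketch of the same pattern:

```rust
use byteorder::{LittleEndian, WriteBytesExt};

// Write i32 values in little-endian byte order, as the DType::I32 arm
// above does for tensor contents.
fn write_i32s<W: std::io::Write>(w: &mut W, vs: &[i32]) -> std::io::Result<()> {
    for &v in vs {
        w.write_i32::<LittleEndian>(v)?;
    }
    Ok(())
}
```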
83 changes: 81 additions & 2 deletions candle-core/src/cpu/avx.rs
@@ -1,10 +1,10 @@
use super::{Cpu, CpuF16};
use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;
use half::{bf16, f16};

pub struct CurrentCpu {}

@@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

pub struct CurrentCpuBF16 {}
impl CpuBF16<ARR> for CurrentCpuBF16 {
type Unit = __m256;
type Array = [__m256; ARR];

const STEP: usize = STEP;
const EPR: usize = EPR;

fn n() -> usize {
ARR
}

unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}

unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}

unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}

#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}

unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}

#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
}
}

unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
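A note on `vec_reduce` above: it folds the accumulator array pairwise (halving `offset` each pass) and then horizontally adds the lanes of the surviving register. A scalar model of the pairwise step, assuming a power-of-two number of accumulators:

```rust
// Scalar model of vec_reduce's pairwise folding: repeatedly add the upper
// half of the accumulators onto the lower half, then return the survivor.
fn tree_reduce(mut acc: Vec<f32>) -> f32 {
    let mut n = acc.len();
    assert!(n.is_power_of_two());
    while n > 1 {
        n /= 2;
        for i in 0..n {
            acc[i] += acc[n + i];
        }
    }
    acc[0]
}
```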
18 changes: 18 additions & 0 deletions candle-core/src/cpu/kernels.rs
@@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}

#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
*res = half::bf16::from_f32(res_f32);
}
}
impl VecOps for u8 {
#[inline(always)]
@@ -144,6 +151,17 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
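The bf16 `vec_dot` above accumulates in f32 and rounds to bf16 once at the end, which loses much less precision than accumulating in bf16 itself. A safe scalar reference with the same contract (`dot_bf16` is a hypothetical helper, for illustration):

```rust
use half::bf16;

// Dot product with f32 accumulation, converted to bf16 once at the end,
// matching the contract of VecOps::vec_dot for half::bf16.
fn dot_bf16(lhs: &[bf16], rhs: &[bf16]) -> bf16 {
    let sum: f32 = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(a, b)| a.to_f32() * b.to_f32())
        .sum();
    bf16::from_f32(sum)
}
```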
62 changes: 60 additions & 2 deletions candle-core/src/cpu/mod.rs
@@ -36,14 +36,33 @@ trait CpuF16<const ARR: usize> {
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;

#[allow(unused)]
trait CpuBF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;

fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
@@ -170,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
*c = sumf;
}

#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuBF16::STEP - 1);

let mut sum = CurrentCpuBF16::zero_array();
let mut ax = CurrentCpuBF16::zero_array();
let mut ay = CurrentCpuBF16::zero_array();

for i in (0..np).step_by(CurrentCpuBF16::STEP) {
for j in 0..CurrentCpuBF16::n() {
ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));

sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
}
}

CurrentCpuBF16::vec_reduce(sum, &mut sumf);

// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
@@ -180,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
}
*c = sum;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}
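In both AVX dot products above, `let np = k & !(STEP - 1);` rounds `k` down to a multiple of `STEP`, which is valid because `STEP` is a power of two; the vectorized loop then covers `0..np` and the scalar "leftovers" loop finishes `np..k`. A tiny illustration:

```rust
// Round k down to a multiple of a power-of-two step, splitting the work
// into full SIMD strides plus a scalar tail.
fn round_down_to_step(k: usize, step: usize) -> usize {
    assert!(step.is_power_of_two());
    k & !(step - 1)
}

fn main() {
    // e.g. with a STEP of 32: 100 elements -> 96 vectorized, 4 scalar.
    assert_eq!(round_down_to_step(100, 32), 96);
}
```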