Add some fast Metal MLX SDPA kernels #2584
base: main
Conversation
* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
@@ -7,5 +7,6 @@
     "candle-pyo3"
   ],
   "python.testing.unittestEnabled": false,
-  "python.testing.pytestEnabled": true
+  "python.testing.pytestEnabled": true,
+  "rust-analyzer.cargo.features": ["metal"]
Wouldn't that break the editor for non-Metal users?
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)

let _hidden = q_shape[q_shape.len() - 1];
If variables such as _hidden are not planned to be used, maybe remove them?
encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
encoder.set_buffer(3, Some(&output), 0);

encoder.set_bytes(
    4,
    std::mem::size_of::<MLXFastAttentionParams>() as u64,
    &params as *const MLXFastAttentionParams as *const c_void,
);
encoder.set_bytes(
    6,
    (std::mem::size_of::<i32>() * batch_shape.len()) as u64,
    batch_shape.as_ptr() as *const i32 as *const c_void,
);
encoder.set_bytes(
    7,
    (std::mem::size_of::<usize>() * batch_strides.len()) as u64,
    batch_strides.as_ptr() as *const c_void,
);
Couldn't the EncoderParam trait, or better the set_params! macro, be used to simplify this?
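For context (not part of this PR's diff): the EncoderParam trait in candle-metal-kernels wraps set_bytes/set_buffer for primitives, slices, and buffers, so the struct parameter would need its own impl before set_params! could cover it. A rough sketch of such an impl is below, assuming the trait keeps roughly its current shape and that MLXFastAttentionParams is a Copy, #[repr(C)] struct; note also that binding index 5 is skipped above, so the macro's consecutive indexing would need handling.

```rust
use std::ffi::c_void;

// Hypothetical sketch mirroring the `primitive!` pattern in candle-metal-kernels.
// EncoderParam, MLXFastAttentionParams and the metal types are assumed in scope.
impl EncoderParam for MLXFastAttentionParams {
    fn set_param(encoder: &metal::ComputeCommandEncoderRef, position: u64, data: Self) {
        // Same raw set_bytes call as above, just hidden behind the trait.
        encoder.set_bytes(
            position,
            std::mem::size_of::<MLXFastAttentionParams>() as u64,
            &data as *const MLXFastAttentionParams as *const c_void,
        );
    }
}
```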
encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
encoder.set_buffer(3, Some(&output), 0);

encoder.set_bytes(
    4,
    std::mem::size_of::<i32>() as u64,
    &gqa_factor as *const i32 as *const c_void,
);
encoder.set_bytes(
    5,
    std::mem::size_of::<i32>() as u64,
    &n as *const i32 as *const c_void,
);
encoder.set_bytes(
    6,
    std::mem::size_of::<usize>() as u64,
    &stride as *const usize as *const c_void,
);
encoder.set_bytes(
    7,
    std::mem::size_of::<f32>() as u64,
    &alpha as *const f32 as *const c_void,
);
encoder.set_bytes(
    8,
    std::mem::size_of::<f32>() as u64,
    &softcapping as *const f32 as *const c_void,
);
Same comment as above: set_params! should be an easy win here.
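Since the binding indices here run contiguously (buffers 0-3, then scalar bytes 4-8) and all the scalars are primitives that EncoderParam already covers, the macro version could look roughly like the sketch below. This is only an illustration of the suggestion, assuming the usual set_params! form where indices are assigned in order starting at 0 and (buffer, offset) tuples are supported.

```rust
// Rough sketch of the reviewer's suggestion, not code from this PR.
set_params!(
    encoder,
    (
        (&q_buffer, q_offset), // buffer 0
        (&k_buffer, k_offset), // buffer 1
        (&v_buffer, v_offset), // buffer 2
        &output,               // buffer 3
        gqa_factor,            // bytes 4: i32
        n,                     // bytes 5: i32
        stride,                // bytes 6: usize
        alpha,                 // bytes 7: f32
        softcapping            // bytes 8: f32
    )
);
```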
    candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}

let k_head = k_l.dims()[k_l.dims().len() - 1];
k_l.dim(D::Minus1)? would be simpler and make for a better error message than a panic (the same applies to a bunch of places in this function).
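As a concrete illustration of the suggested change (a sketch, not the PR's code):

```rust
use candle::D;

// Checked accessor: returns a proper error instead of panicking on a bad index.
let k_head = k_l.dim(D::Minus1)?;
```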
impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "sdpa"
Let's call this metal-sdpa instead.
This PR adds some MLX SDPA kernels on Metal.
I observe about a 26% performance improvement with Llama 3.1 8b at q4k and q8_0 when testing through mistral.rs on my Candle fork. I updated the quantized_llama.rs file here to use the new function.
This PR also adds a function, candle_nn::ops::sdpa. The MLX attention kernels don't support masking yet, so the performance gains apply only to decoding on Metal. Once/if they do, I'll update them; otherwise, we can explore using the Flash Attention kernels for Metal from llama.cpp.
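As a usage illustration (a hedged sketch: the exact signature of candle_nn::ops::sdpa is defined by this PR; the scale and softcapping parameters are assumed here to be f32, matching the kernel arguments above):

```rust
use candle::{Device, Result, Tensor};

fn sdpa_decode_example() -> Result<Tensor> {
    let dev = Device::new_metal(0)?;
    // Decode-time shapes, following the comments above: (bs, n_heads, seq, head_dim),
    // with a single query token attending over a 512-token KV cache and GQA (32 q heads, 8 kv heads).
    let q = Tensor::randn(0f32, 1.0, (1, 32, 1, 128), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 512, 128), &dev)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 512, 128), &dev)?;
    let scale = 1f32 / (128f32).sqrt();
    // Assumed call shape: sdpa(q, k, v, scale, softcapping); softcapping = 1.0 disables capping.
    candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)
}
```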