Add some fast Metal MLX SDPA kernels #2584

Open

EricLBuehler wants to merge 2 commits into main
Conversation

EricLBuehler (Member) commented Oct 29, 2024

This PR adds some MLX SDPA kernels on Metal.

I observe roughly a 26% performance improvement with Llama 3.1 8B at q4k and q8_0 when testing through mistral.rs on my Candle fork. I updated the quantized_llama.rs file here to use the new function.

This PR adds a function, candle_nn::ops::sdpa. The MLX attention kernels don't support masking yet, so the performance gains are only for decoding on Metal. If and when they gain masking support, I'll update the kernels; otherwise, we can explore using llama.cpp's Flash Attention kernels for Metal.
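For reference, a minimal sketch of how the new op might be called from model code. The argument order `(q, k, v, scale, softcapping)` and the use of `1.0` to mean "no softcapping" are assumptions for illustration, not a confirmed signature:

```rust
use candle::{Result, Tensor};

// Hypothetical decode-time usage of the fused op; shapes follow the
// (bs, heads, seq, head_dim) convention used elsewhere in this PR.
fn attend(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Assumed argument order: (q, k, v, scale, softcapping); 1.0 is taken
    // here to mean "softcapping disabled".
    candle_nn::ops::sdpa(q, k, v, scale, 1.0)
}
```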

EricLBuehler and others added 2 commits October 29, 2024 06:34
* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format
@@ -7,5 +7,6 @@
     "candle-pyo3"
   ],
   "python.testing.unittestEnabled": false,
-  "python.testing.pytestEnabled": true
+  "python.testing.pytestEnabled": true,
+  "rust-analyzer.cargo.features": ["metal"]
Collaborator
Wouldn't that break the editor for non metal users?

// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)

let _hidden = q_shape[q_shape.len() - 1];
Collaborator
If variables such as _hidden are not planned to be used, maybe remove them?

Comment on lines +1812 to +1831
encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
encoder.set_buffer(3, Some(&output), 0);

encoder.set_bytes(
    4,
    std::mem::size_of::<MLXFastAttentionParams>() as u64,
    &params as *const MLXFastAttentionParams as *const c_void,
);
encoder.set_bytes(
    6,
    (std::mem::size_of::<i32>() * batch_shape.len()) as u64,
    batch_shape.as_ptr() as *const i32 as *const c_void,
);
encoder.set_bytes(
    7,
    (std::mem::size_of::<usize>() * batch_strides.len()) as u64,
    batch_strides.as_ptr() as *const c_void,
);
Collaborator
Couldn't the EncoderParam trait or better the set_params! macro be used to simplify this?

Comment on lines +1920 to +1949
encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
encoder.set_buffer(3, Some(&output), 0);

encoder.set_bytes(
    4,
    std::mem::size_of::<i32>() as u64,
    &gqa_factor as *const i32 as *const c_void,
);
encoder.set_bytes(
    5,
    std::mem::size_of::<i32>() as u64,
    &n as *const i32 as *const c_void,
);
encoder.set_bytes(
    6,
    std::mem::size_of::<usize>() as u64,
    &stride as *const usize as *const c_void,
);
encoder.set_bytes(
    7,
    std::mem::size_of::<f32>() as u64,
    &alpha as *const f32 as *const c_void,
);
encoder.set_bytes(
    8,
    std::mem::size_of::<f32>() as u64,
    &softcapping as *const f32 as *const c_void,
);
Collaborator
Same comment as above, set_params! should be an easy win here.
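For illustration, a sketch of how the vector-kernel argument setup above might collapse with set_params!. This assumes the macro assigns consecutive argument indices starting at 0 (matching slots 0-8 above), that the offsets are plain usize values, and that EncoderParam impls exist for the scalar types involved:

```rust
// Buffers take (buffer, offset) pairs; scalars are passed by value and
// bound in order as argument indices 0..=8.
set_params!(
    encoder,
    (
        (&q_buffer, q_offset),
        (&k_buffer, k_offset),
        (&v_buffer, v_offset),
        &output,
        gqa_factor,
        n,
        stride,
        alpha,
        softcapping
    )
);
```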

    candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}

let k_head = k_l.dims()[k_l.dims().len() - 1];
Collaborator
k_l.dim(D::Minus1)? would be simpler and make for a better error message than a panic (the same applies to a bunch of places in this function)
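A minimal sketch of that suggestion (assuming `k_l` exposes the `dim` accessor the comment references):

```rust
use candle::D;

// Resolves the last dimension through `dim`, which returns a Result with a
// descriptive shape error instead of panicking on an out-of-bounds index.
let k_head = k_l.dim(D::Minus1)?;
```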


impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "sdpa"
Collaborator
Let's call this metal-sdpa instead.
