Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some fast Metal MLX SDPA kernels #2584

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Commits on Oct 29, 2024

  1. Add some fast Metal MLX SDPA kernels (#32)

    * Sketch the sdpa kernel
    
    * Add full sdpa kernel,
    
    * Add test
    
    * Add vectorized kernel for decoding
    
    * Update tests
    
    * Add some docs
    
    * Fix sdpa_vector names
    
    * Add softcapping for vectorized sdpa
    
    * Add softcapping for full sdpa
    
    * Add support for head dim 32, 96, 256
    
    * Add support for head dim 32, 96, 256
    
    * Update docs
    
    * Add update notice
    
    * Clippy and format
    EricLBuehler committed Oct 29, 2024
    Configuration menu
    Copy the full SHA
    eff930e View commit details
    Browse the repository at this point in the history
  2. Configuration menu
    Copy the full SHA
    49c7255 View commit details
    Browse the repository at this point in the history