Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(closes #35) Changes for jax-metal #52

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

(closes #35) Changes for jax-metal #52

wants to merge 4 commits into from

Conversation

nix
Copy link

@nix nix commented Oct 8, 2024

Three changes to get jax working with metal:

  • Prefer METAL device if available.
  • Compute RoPE with float instead of complex.
  • Reimplement top_k to avoid broken jax.lax.top_k on jax-metal.

The new implementation of top_k is enabled with a flag. It should
check the current device and use jax.lax.top_k if not running on
jax-metal.
This is not a general-purpose implementation of top_k but hopefully
sufficient for sampling.
@artus-LYTiQ
Copy link

Added conditionals based on jax.extended.backend.get_backend(). Added jax-metal to toml. Verified that it is running on both Metal and CPU. In PR #57

@Arrabonae
Copy link
Contributor

i think the 'frog' branch needs to be merged into main to see if this implementing works with the new structure or not.

@artus-LYTiQ
Copy link

I still don't know how to push updates to this branch so I am continuing with the "frog" integration in my version (pr-57).

@nix nix changed the title (Closes #35) Changes for jax-metal (closes #35) Changes for jax-metal Oct 10, 2024
@nix
Copy link
Author

nix commented Oct 10, 2024

With all the jit annotations commented out, this jax-metal version is about 10x slower than the native MLX version: https://github.com/samefarrar/entropix_mlx
On an M2 Air I get about 1.5 tps with this jax-metal implementation (no jit), 4.5 tps with pytorch metal, and 15tps with entropix_mlx.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants