(Closes #35, updates #52) Changes for Jax-Metal #57

Open: wants to merge 12 commits into base: main

artus-LYTiQ

Based on #52

Added conditionals based on jax.extend.backend.get_backend(). Added jax-metal to the toml file. Verified that the code runs on both Metal and CPU.
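For readers following along, a backend conditional of this kind could look roughly like the sketch below. It is not the code from this PR: it uses the public `jax.default_backend()` call instead of `get_backend()`, the helper names `running_on_metal` and `select_impl` are hypothetical, and it assumes the jax-metal plugin reports a platform name equal to "metal" when case is ignored.

```python
import jax


def running_on_metal() -> bool:
    """Return True when JAX's default backend looks like jax-metal."""
    # jax.default_backend() returns the platform name as a string,
    # e.g. "cpu", "gpu" or "tpu". ASSUMPTION: jax-metal reports a name
    # that equals "metal" once lowercased.
    return jax.default_backend().lower() == "metal"


def select_impl(metal_impl, default_impl):
    """Pick the Metal-specific implementation only when running on jax-metal.

    Hypothetical helper: both arguments are whatever objects (usually
    functions) the caller wants to choose between.
    """
    return metal_impl if running_on_metal() else default_impl


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    print("metal detected:", running_on_metal())
```

Keeping the check in one small function means the rest of the code can stay identical across TPU, GPU, CPU, and Metal.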

nix and others added 6 commits October 8, 2024 16:13
The new implementation of top_k is enabled with a flag. It should
check the current device and use jax.lax.top_k if not running on
jax-metal.
This is not a general-purpose implementation of top_k but hopefully
sufficient for sampling.
…ers code back for gpu/cpu; added jax-metal library to poetry"
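The commit message above describes a device-dependent top_k. A minimal sketch of that idea follows; it is not the implementation from #52. The sort-based fallback, the `use_fallback` flag, and the assumption that `jnp.argsort` is usable on jax-metal are all illustrative.

```python
import jax
import jax.numpy as jnp


def top_k_fallback(x: jax.Array, k: int):
    """Sort-based stand-in for jax.lax.top_k along the last axis.

    ASSUMPTION: x holds floating-point values (e.g. logits), so negating
    it to sort in descending order is safe.
    """
    # argsort is ascending, so sort the negated values and keep the first k.
    indices = jnp.argsort(-x, axis=-1)[..., :k]
    values = jnp.take_along_axis(x, indices, axis=-1)
    return values, indices


def top_k(x: jax.Array, k: int, use_fallback=None):
    """Use jax.lax.top_k unless the fallback is requested or Metal is detected."""
    if use_fallback is None:
        # ASSUMPTION: jax-metal reports a platform name equal to "metal"
        # when lowercased.
        use_fallback = jax.default_backend().lower() == "metal"
    if use_fallback:
        return top_k_fallback(x, k)
    return jax.lax.top_k(x, k)


if __name__ == "__main__":
    logits = jnp.array([[0.1, 0.7, 0.3, 0.9]])
    print(top_k(logits, 2))
    print(top_k(logits, 2, use_fallback=True))
```

A full sort costs more than a dedicated top-k kernel and may order tied values differently, which is acceptable for sampling but is one reason to keep `jax.lax.top_k` on backends that support it.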
@artus-LYTiQ
Author

artus-LYTiQ commented Oct 9, 2024

BUG: I didn't notice that #52 removes the apply_scaling call that is normally present in precompute_freqs_cis when use_scaled is True. I will need to update this. @nix, can you please investigate? I am stuck writing the initial MLX version (that's how I found this issue).
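For context, the scaling step under discussion can be sketched as below. This is a vectorized JAX sketch modeled on the RoPE frequency scaling in the Llama 3.1 reference code, not the code in this repository: the four constants, the `theta` default, and the exact signatures are assumptions.

```python
import math

import jax.numpy as jnp

# ASSUMPTION: constants as in the Llama 3.1 reference implementation.
SCALE_FACTOR = 8.0
LOW_FREQ_FACTOR = 1.0
HIGH_FREQ_FACTOR = 4.0
OLD_CONTEXT_LEN = 8192.0


def apply_scaling(freqs):
    """Rescale RoPE frequencies: keep high ones, shrink low ones, blend between."""
    low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR
    high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR
    wavelen = 2.0 * math.pi / freqs
    # Interpolation weight for wavelengths between the two thresholds.
    smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (
        HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR
    )
    blended = (1.0 - smooth) * freqs / SCALE_FACTOR + smooth * freqs
    return jnp.where(
        wavelen < high_freq_wavelen,
        freqs,  # short wavelengths are left unchanged
        jnp.where(wavelen > low_freq_wavelen, freqs / SCALE_FACTOR, blended),
    )


def precompute_freqs_cis(dim, end, theta=500000.0, use_scaled=True):
    """Return complex rotations of shape (end, dim // 2)."""
    exponents = jnp.arange(0, dim, 2)[: dim // 2].astype(jnp.float32) / dim
    freqs = 1.0 / (theta ** exponents)
    if use_scaled:
        # This is the call whose absence is reported in the comment above.
        freqs = apply_scaling(freqs)
    t = jnp.arange(end, dtype=jnp.float32)
    angles = jnp.outer(t, freqs)
    return jnp.exp(1j * angles)
```

Without the `use_scaled` branch, the low-frequency components keep their original wavelengths, so a missing call changes the positional encoding silently instead of raising an error.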

@nix

nix commented Oct 9, 2024

#52 does not remove the scaling; you did that here.
I didn't put the conditionals in #52 because the original code seems to be aiming for concision, but will wait to hear what @xjdr-alt thinks.

@artus-LYTiQ
Author

Added back the scaled_rope that I had accidentally deleted earlier. Tested with jax-metal and with jax-cpu.

@nix Sure, we have to wait for them. My hunch is that they will want to run the original code on "their" devices (read: TPU and jax-gpu), so we should stay as close to that as possible and not force them to run watered-down code. But now @Arrabonae has the choice.

I will try to integrate the Frog branch next.

@artus-LYTiQ
Author

Integrated the Frog branch and ran main.py on jax-metal as well as jax-cpu. Our lass is happy. Can someone validate on their own rig, preferably on jax-gpu too?

Note that I have not tested eval_main.py yet because of the README update by xjdr.

@artus-LYTiQ
Author

Frog entropix/main.py successfully tested under:

  • jax-metal on M2
  • jax-cpu on M2
  • jax-gpu on Ubuntu 22.04, 4090, cu124


3 participants