Torch Implementation of the Sampler on the frog branch + torch eval script + torch implementation bugfix + tests #72

Open: wants to merge 11 commits into base: frog

@nreHieW commented Oct 12, 2024

This PR does the following:

  • Implement the sampler on the frog branch
  • Implement a working evaluation script for use with the Eleuther Eval Harness. Verified to work with eq_bench.
  • Update the torch implementation to match the JAX implementation. This should also fix issue #50 ("Torch doesn't work on mac.").
    - (For documentation purposes): The existing implementation mishandles the q, k, and v dtypes after RoPE. In JAX, the first time the kvcache is populated (when cur_pos = 0), the keys and values are in float32. For cur_pos != 0, the cache is in bf16, and JAX automatically converts to fp32 so that xq @ k is computed in fp32. In torch, even though the post-RoPE keys are in fp32, the cache buffers are in bf16, so the update method returns bf16. We therefore need to cast explicitly to fp32 to match the JAX implementation.
  • Add tests to check that the torch implementation matches JAX
    - Because of numerical differences involving bf16, JAX jit, and torch, the tests run in fp32 with jit, except for the attention test, which compares the torch version with the non-jit JAX version.
    - Note: The test_each_layer test may fail around 3% of the time because of 'unluckily' initialised inputs. Even when it fails, fewer than 0.5% of elements mismatch (fewer than 5 elements).
  • Other QOL changes to match main.py
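
For reference, the dtype point above can be illustrated with a small standalone sketch. This is not the code in this PR: `update_and_score`, `mismatch_fraction`, the tensor shapes, and the tolerances are all hypothetical and chosen only for illustration. It stores fp32 keys in a bf16 cache buffer, upcasts explicitly before the matmul, and then measures the fraction of mismatched elements against two references.

```python
import torch

def update_and_score(xq, xk, cache_k, cur_pos):
    """Hypothetical helper: store post-RoPE keys in a bf16 cache, score in fp32."""
    seqlen = xk.shape[1]
    # Storing into the bf16 buffer rounds the fp32 keys to bf16.
    cache_k[:, cur_pos:cur_pos + seqlen] = xk.to(cache_k.dtype)
    keys = cache_k[:, :cur_pos + seqlen]  # still bf16 at this point
    # Explicit upcast so that xq @ k^T is computed in fp32.
    return torch.matmul(xq.float(), keys.float().transpose(-1, -2))

def mismatch_fraction(actual, expected, rtol=1e-5, atol=1e-5):
    """Fraction of elements that are not close under the given tolerances."""
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
    return 1.0 - close.float().mean().item()

torch.manual_seed(0)
bsz, seqlen, max_seq, head_dim = 2, 4, 16, 8  # illustrative sizes only
xq = torch.randn(bsz, seqlen, head_dim)
xk = torch.randn(bsz, seqlen, head_dim)
cache_k = torch.zeros(bsz, max_seq, head_dim, dtype=torch.bfloat16)

scores = update_and_score(xq, xk, cache_k, cur_pos=0)

# Reference that also sees bf16-rounded keys: expected to agree with scores.
ref_rounded = torch.matmul(xq, xk.to(torch.bfloat16).float().transpose(-1, -2))
# Reference that keeps the original fp32 keys: exposes the rounding effect.
ref_fp32 = torch.matmul(xq, xk.transpose(-1, -2))

frac_vs_rounded = mismatch_fraction(scores, ref_rounded)
frac_vs_fp32 = mismatch_fraction(scores, ref_fp32)
print(scores.dtype, frac_vs_rounded, frac_vs_fp32)
```

The sketch does not model the cur_pos = 0 special case described above, where the keys are used in fp32 without passing through the cache. The comparison against `ref_fp32` only shows why that case and the cached case can differ numerically.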

Let me know if a PR to main is preferred and I'll update!
