XLA2 does not work with jax 0.4.34 (but did work on jax 0.4.33) #8240

Open
Chaosruler972 opened this issue Oct 9, 2024 · 2 comments

🐛 Bug

A toy MNIST example using XLA2 does not work with the latest version of jax (0.4.34) on a 64-core Trillium machine (V6e-64), but downgrading to 0.4.33 fixes the issue.

To Reproduce

  1. Download the toy training example from here

  2. Allocate a V6e-64 Trillium TPU on GCP

  3. Copy that file to all the VMs using GCP scp

  4. Prepare an environment containing torch_xla2 (refer to the readme here)

  5. Install jax/jaxlib 0.4.33 from pip:

pip install jax==0.4.33 jaxlib==0.4.33 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

  6. Run your training and verify that it works

  7. Upgrade to jax 0.4.34:

pip install jax==0.4.33 jaxlib==0.4.33 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

  8. Run your training again; note that the training loop exits without any warning or message after the loss is extracted
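
Because the last step reports an exit with no warning or message, it can help to capture how the process actually ended. The following is a minimal sketch, not part of the original report: it launches a command as a child process and describes its return code, which separates a clean exit from a non-zero exit or a kill by signal. The helper name `run_and_report` and the script name `train_mnist.py` in the comment are hypothetical placeholders.

```python
import signal
import subprocess
import sys


def run_and_report(cmd):
    """Run cmd as a child process and describe how it ended.

    Returns (returncode, human-readable status).
    """
    proc = subprocess.run(cmd)
    code = proc.returncode
    if code == 0:
        status = "exited normally (code 0)"
    elif code < 0:
        # On POSIX, a negative return code means the child was killed by a signal.
        try:
            name = signal.Signals(-code).name
        except ValueError:
            name = "unknown signal"
        status = f"killed by signal {-code} ({name})"
    else:
        status = f"exited with code {code}"
    return code, status


# Hypothetical usage; "train_mnist.py" stands in for the toy training script.
# "-X faulthandler" asks Python to dump a traceback on fatal signals such as SIGSEGV.
# code, status = run_and_report([sys.executable, "-X", "faulthandler", "train_mnist.py"])
# print(f"training process {status}", file=sys.stderr)
```

A status of "exited normally" together with a truncated training loop would suggest the script itself returned early; a signal would point toward a crash in native code.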

Expected behavior

Results should vary only slightly between runs of the script on different versions of jax.

Environment

  • Reproducible on XLA backend TPU
  • Using a Trillium 64-core machine (V6e-64)
  • torch_xla2 version: 0.0.1
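
Since the report hinges on exactly which versions were installed for each run, a short script that prints them on every VM can remove ambiguity. This is a sketch under assumptions: it reads the installed distribution metadata through the standard-library `importlib.metadata`, and the distribution names in `PACKAGES` (in particular `torch_xla2`) are assumed to match what pip installed in this environment.

```python
from importlib import metadata

# Assumed distribution names; adjust to match the actual environment.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly", "torch", "torch_xla2")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def format_report(versions):
    """Render the versions one per line, pip-style where a version is known."""
    lines = []
    for name, ver in versions.items():
        lines.append(f"{name}=={ver}" if ver else f"{name}: not installed")
    return "\n".join(lines)


print(format_report(installed_versions()))
```

Running this before each experiment and attaching the output to the issue makes the two setups directly comparable.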

qihqi (Collaborator) commented Oct 9, 2024

pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

qihqi self-assigned this Oct 9, 2024

Chaosruler972 (Author) commented

I made a mistake in the issue description. For the second experiment, I installed using

pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

which led to the issue.
