A toy MNIST example using torch_xla2 (XLA2) does not work with the latest JAX (0.4.34) on a 64-core Trillium TPU (v6e-64), but downgrading to JAX 0.4.33 fixes the issue
🐛 Bug
To Reproduce
Download the toy training example from here.
Allocate a v6e-64 Trillium TPU on GCP.
Copy that file to all the VM hosts with gcloud scp.
Prepare an environment containing torch_xla2 (refer to the README here).
Install jax/jaxlib 0.4.33 from pip.
Run the training and verify that it works.
Upgrade to jax/jaxlib 0.4.34 and run the training again; it no longer works.
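Because the failure depends on the exact JAX version, it can help to confirm which versions each host is actually running before and after the upgrade (for example, when filling in the Environment section). The following is a minimal stdlib-only sketch, not part of the original report, that could be run on every VM. The package list, including the distribution name torch_xla2, is an assumption; a package installed under a different name is simply reported as not installed.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions named in this report: 0.4.33 worked, 0.4.34 did not.
KNOWN_GOOD = "0.4.33"


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Assumed distribution names; adjust to match the actual environment.
    versions = installed_versions(["jax", "jaxlib", "torch", "torch_xla2"])
    for name, ver in versions.items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
    if versions.get("jax") != KNOWN_GOOD:
        print(f"warning: jax is not the known-good {KNOWN_GOOD}", file=sys.stderr)
```

Running it on each host and comparing the output makes it easy to spot a VM that was left on a different version.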
Expected behavior
Only small variations in the results between runs on different JAX versions.
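One way to make "small variations" concrete is to compare the per-step losses logged by a 0.4.33 run and a 0.4.34 run within a tolerance. This is a minimal sketch under stated assumptions: it assumes each run logs one loss value per step, and the tolerances and loss values below are hypothetical, chosen only for illustration. A NaN loss never compares as close, so a diverged run is reported as a mismatch.

```python
import math


def runs_agree(losses_a, losses_b, rel_tol=1e-2, abs_tol=1e-4):
    """True if both loss curves have the same length and every pair of
    per-step losses is close within the given (hypothetical) tolerances."""
    if len(losses_a) != len(losses_b):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(losses_a, losses_b)
    )


# Hypothetical per-step losses, for illustration only.
run_0433 = [2.302, 0.412, 0.215]
run_0434 = [2.302, 0.414, 0.213]
print(runs_agree(run_0433, run_0434))
```

Using `math.isclose` with both a relative and an absolute tolerance keeps the check meaningful both at the start of training, when losses are large, and near convergence, when they approach zero.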
Environment