T5X rosetta nightlies are broken #448

Closed
olupton opened this issue Jan 4, 2024 · 0 comments · Fixed by #444

Comments

@olupton
Collaborator

olupton commented Jan 4, 2024

The patch in google-research/t5x#1391 uses `jax.random.KeyArray` (https://github.com/google-research/t5x/blob/74d742f053dabbe594637ad1f481237a23065512/t5x/models.py#L643), which has been removed from JAX. The patch is referenced as:

mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)

The error is visible in the nightly runs, e.g. https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7395920816.

olupton added a commit that referenced this issue Jan 4, 2024
As of #405 the presubmit CI is based on package versions listed in the
`manifest.yaml` that is committed to the repository. This has not been
updated for ~1 month, so the presubmit CI is testing ~1 month old
versions of the ecosystem. This PR updates it using the commit from
https://github.com/NVIDIA/JAX-Toolbox/tree/znightly-2024-01-03-7395605285,
generated by the nightly CI run.

Because this bumps the JAX version by ~1 month, we have to include fixes
for deprecations. In particular, this replaces `jax.random.KeyArray` with
plain `jax.Array`
(nvjax-svc-0/t5x@4d5ec2f).
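As a rough illustration of the kind of change involved, the sketch below swaps the removed annotation for `jax.Array`. The helper `init_noise` is hypothetical and is not taken from the T5X patch; it only shows that a PRNG key can be annotated as a plain `jax.Array`.

```python
import jax

# Before (fails on recent JAX, where jax.random.KeyArray no longer exists):
#   def init_noise(rng: jax.random.KeyArray, shape: tuple[int, ...]) -> jax.Array: ...


# After: annotate the PRNG key as a plain jax.Array.
def init_noise(rng: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Hypothetical helper: draw standard-normal noise from a PRNG key."""
    return jax.random.normal(rng, shape)


key = jax.random.PRNGKey(0)
noise = init_noise(key, (2, 3))
```

Only the annotation changes; the runtime behaviour of the function is the same.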

The deprecated name is also used in older versions of the `chex` package,
which pip's dependency resolver selects despite newer versions being
available. We work around this by nudging pip towards a newer `numpy`
version, which allows it to select a newer `chex`. Similar issues could
easily arise in future with other packages.
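One way to notice this kind of resolver backsliding early is to check the installed versions against a set of minimums after installation. The sketch below is an assumption, not part of this commit: `find_stale`, `parse_release`, and the values in `FLOORS` are hypothetical placeholders, and the version parsing is deliberately simplistic (numeric release components only).

```python
from importlib import metadata


def parse_release(version: str) -> tuple[int, ...]:
    """Turn '1.26.4' into (1, 26, 4); stops at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # e.g. '2rc1': keep the 2, ignore the rest
            break
    return tuple(parts)


def find_stale(floors, lookup=metadata.version):
    """Return {package: installed_version} for packages below their floor."""
    stale = {}
    for name, floor in floors.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            continue  # not installed, so nothing to report
        if parse_release(installed) < parse_release(floor):
            stale[name] = installed
    return stale


# Placeholder floors for illustration only; not the versions used in the commit.
FLOORS = {"chex": "0.1.0", "numpy": "1.0"}
print(find_stale(FLOORS))
```

A check like this could run at the end of an image build and fail the job if the returned dictionary is non-empty.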

Closes #448.