Update manifest and pin newer numpy in containers (#444)
As of #405, the presubmit CI is based on the package versions listed in the
`manifest.yaml` committed to the repository. That manifest had not been
updated for about a month, so the presubmit CI was testing roughly month-old
versions of the ecosystem. This PR updates it using the manifest from
https://github.com/NVIDIA/JAX-Toolbox/tree/znightly-2024-01-03-7395605285,
which was generated by the nightly CI run.

Because this bumps the JAX version by roughly a month, fixes for the resulting
deprecations have to be included as well. In particular, `jax.random.KeyArray`
is replaced with plain `jax.Array`
(nvjax-svc-0/t5x@4d5ec2f).
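
As a rough sketch of the mechanical part of that rename (the real t5x change was reviewed by hand; the paths and flags here are illustrative):

```bash
# Replace the removed alias with the plain array type throughout a checkout.
# jax.random.KeyArray was a type alias used in annotations, so a textual swap
# does not change behaviour.
grep -rl --include='*.py' 'jax\.random\.KeyArray' . \
  | xargs --no-run-if-empty sed -i 's/jax\.random\.KeyArray/jax.Array/g'
```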

The deprecated name is also used by older versions of the `chex` package,
which pip's dependency resolver was selecting even though newer versions are
available. We work around this by giving pip a helping hand: requiring a newer
`numpy` (see the `Dockerfile.jax` change below) nudges the resolver into
selecting a newer `chex`. Similar issues with other packages are easy to
imagine in the future.
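
For illustration, the nudge is just a version floor appended to one of the pip-tools input fragments used by the container build; the `pip-compile` invocation below is only a sketch of the later resolution step (the actual script name and flags in JAX-Toolbox may differ):

```bash
# Sketch only: append a numpy floor to the requirements fragment (this is the
# literal line added in the Dockerfile.jax diff below)...
echo "numpy >= 1.24.1" >> /opt/pip-tools.d/requirements-jax.in
# ...then resolve all fragments together; with the floor in place the resolver
# can pick a chex release that no longer references jax.random.KeyArray.
# Illustrative invocation, not the exact JAX-Toolbox build step.
pip-compile -o /opt/pip-tools.d/requirements.txt /opt/pip-tools.d/requirements-*.in
```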

Closes #448.
olupton authored Jan 4, 2024
1 parent 19fb059 commit 2c2d7f9
Showing 16 changed files with 161 additions and 104 deletions.
4 changes: 4 additions & 0 deletions .github/container/Dockerfile.jax
@@ -85,6 +85,10 @@ ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

RUN mkdir -p /opt/pip-tools.d
RUN <<"EOF" bash -ex
# Encourage a newer numpy so that pip's dependency resolver will allow newer
# versions of other packages that rely on newer numpy, but also include fixes
# for compatibility with newer JAX versions. e.g. chex.
echo "numpy >= 1.24.1" >> /opt/pip-tools.d/requirements-jax.in
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in
EOF
29 changes: 15 additions & 14 deletions .github/container/manifest.yaml
@@ -2,49 +2,50 @@
jax:
url: https://github.com/google/jax.git
tracking_ref: main
latest_verified_commit: 595117b70c11055e569480b80907d8c8a9901805
latest_verified_commit: afa2f1e420de3d2cfd684cff080a3808ee66daf5
mode: git-clone
xla:
url: https://github.com/openxla/xla.git
tracking_ref: main
latest_verified_commit: 78a5297d8e4301cb3ba2514061f56f89104e3d88
latest_verified_commit: 64a7946ffd048daf65ef330fc4ca5e4c3c1482a0
mode: git-clone
flax:
url: https://github.com/google/flax.git
mirror_url: https://github.com/nvjax-svc-0/flax.git
tracking_ref: main
latest_verified_commit: 230b0d77e98da22b6e574c3cbff743ca1504bfca
latest_verified_commit: 85dfad242e56098849dbf05e7e4657b3a40820f9
mode: git-clone
patches:
pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
transformer-engine:
url: https://github.com/NVIDIA/TransformerEngine.git
tracking_ref: main
latest_verified_commit: 92c1e500dd14608e54f75df8276baa1104c61d48
latest_verified_commit: d155eaac8e08d42e67d7efd812ee2a69954de816
mode: git-clone
t5x:
url: https://github.com/google-research/t5x.git
mirror_url: https://github.com/nvjax-svc-0/t5x.git
tracking_ref: main
latest_verified_commit: 1bfd2f15e5e77b09d60301367f67fdc9bb756b46
latest_verified_commit: dbc4b6f426862d5a742a2104a17524f53dd442f0
mode: git-clone
patches:
mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
mirror/patch/dali-support: file://patches/t5x/mirror-patch-dali-support.patch # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x
mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
# mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
mirror/ashors/fix_rng_dtype: file://patches/t5x/mirror-ashors-fix_rng_dtype.patch # fix on top of (and incorporating) https://github.com/google-research/t5x/pull/1391
paxml:
url: https://github.com/google/paxml.git
mirror_url: https://github.com/nvjax-svc-0/paxml.git
tracking_ref: main
latest_verified_commit: 7ae682d4d99630008e190b96c5296990297175c2
latest_verified_commit: 60e9e29bd3c6cc53bb4462f8c03bd5408daacd7b
mode: git-clone
patches:
pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support
praxis:
url: https://github.com/google/praxis.git
mirror_url: https://github.com/nvjax-svc-0/praxis.git
tracking_ref: main
latest_verified_commit: b6f32fa0fc6721db1cec75972b0f569c82095956
latest_verified_commit: 5b70196ffba154e78a5f78ce9175854b18cf936d
mode: git-clone
patches:
pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas.
@@ -53,7 +54,7 @@ lingvo:
# Used only in ARM pax builds
url: https://github.com/tensorflow/lingvo.git
tracking_ref: master
latest_verified_commit: 0274fa20b4ff194c1c118b94b5f778caa5d9a84a
latest_verified_commit: ab71210c31706b190ebdd3bd3573ed833e693587
mode: git-clone
tensorflow-text:
# Used only in ARM pax builds
@@ -68,18 +69,18 @@ pydantic:
fiddle:
url: https://github.com/google/fiddle.git
tracking_ref: main
latest_verified_commit: d409cf95164599a88e49d2b6a23a0972a7170b0b
latest_verified_commit: a9e98709d4b109bf04ed61cee5ff366c50c82463
mode: pip-vcs
# Used by t5x
airio:
url: https://github.com/google/airio.git
tracking_ref: main
latest_verified_commit: 69b3ec4ded478ad9cacdc97652a9d086a6a644c4
latest_verified_commit: 0e31f368b12d298e133b3a774e27d9bb0e85d087
mode: pip-vcs
clu:
url: https://github.com/google/CommonLoopUtils.git
tracking_ref: main
latest_verified_commit: 7ba2a9d83a3bc1a97b59482c2f02dc4b3614bc31
latest_verified_commit: f30bc441a14f0ccf8eaff79800f486a846613a8c
mode: pip-vcs
dllogger:
url: https://github.com/NVIDIA/dllogger.git
@@ -94,10 +95,10 @@ jestimator:
optax:
url: https://github.com/deepmind/optax.git
tracking_ref: master
latest_verified_commit: bf987e15eacf6efeb1a1a51b8868c094c3a15f9b
latest_verified_commit: bc22961422eb2397a4639ec945da0bea73d624d6
mode: pip-vcs
seqio:
url: https://github.com/google/seqio.git
tracking_ref: main
latest_verified_commit: 515d917bf58da4103a2bbf39c3716213c36aff03
latest_verified_commit: b582c96cb83f1472925c2b50b90059ad1da8c138
mode: pip-vcs
4 changes: 2 additions & 2 deletions .github/container/patches/flax/PR-3340.patch
@@ -373,7 +373,7 @@ index 076fd680..6eff2dd1 100644


--
2.25.1
2.43.0


From d1f3ec337b85b5c5377aab72d814adfc89dd4af5 Mon Sep 17 00:00:00 2001
@@ -436,5 +436,5 @@ index 999acf2c..8e031c77 100644
else:
bias = None
--
2.25.1
2.43.0

18 changes: 9 additions & 9 deletions .github/container/patches/paxml/PR-46.patch
@@ -683,7 +683,7 @@ index 587181d..e7fe54a 100644

train_state_partition_specs = (
--
2.25.1
2.43.0


From 9d6b6db6039d7e6658dd179e5838379c7dc967e3 Mon Sep 17 00:00:00 2001
@@ -717,7 +717,7 @@ index d44ca67..2b9dba4 100644
assert self.packed_input == False
assert len(self.moe_layers) == 0
--
2.25.1
2.43.0


From 1612dc7a1f77f0a515eb4801087a8b4f0756e5b9 Mon Sep 17 00:00:00 2001
@@ -744,7 +744,7 @@ index 2b9dba4..ef20305 100644
return x_out

--
2.25.1
2.43.0


From 71507dc4b1396252e6fa746d1299854c204f0c51 Mon Sep 17 00:00:00 2001
@@ -771,7 +771,7 @@ index e7fe54a..4093c3b 100644
vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
mdl_vars, excluded_for_learner
--
2.25.1
2.43.0


From 2a8233302c7e42b7dc7628c41abb637518d15c29 Mon Sep 17 00:00:00 2001
@@ -808,7 +808,7 @@ index ef20305..fed1601 100644
finally:
pass
--
2.25.1
2.43.0


From 2a6e5a960f438653b4c9cbeb0c016225af853279 Mon Sep 17 00:00:00 2001
@@ -975,7 +975,7 @@ index fed1601..5914e54 100644
def update_fp8_metas_if_needed(mdl_vars, grads):
return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads)
--
2.25.1
2.43.0


From b57188225e7890dfc54d70db7d89fcb32e61e762 Mon Sep 17 00:00:00 2001
@@ -1136,7 +1136,7 @@ index 4093c3b..2e8fc35 100644
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
--
2.25.1
2.43.0


From c43766ee2e8cda686176a3895e87150b10d5de5e Mon Sep 17 00:00:00 2001
@@ -1528,7 +1528,7 @@ index fd482df..b271258 100644
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
--
2.25.1
2.43.0


From abc0fabc3e2ffb42d1f62254ad42448a39cbd128 Mon Sep 17 00:00:00 2001
@@ -1563,5 +1563,5 @@ index b271258..cbac7cf 100644

class TransformerEngineHelperBase:
--
2.25.1
2.43.0

2 changes: 1 addition & 1 deletion .github/container/patches/praxis/PR-27.patch
@@ -34,5 +34,5 @@ index a35ce8b..52886bc 100644
self.add_summary('attention_mask', atten_mask)
if self.attention_extra_logit is None:
--
2.25.1
2.43.0

4 changes: 2 additions & 2 deletions .github/container/patches/praxis/PR-36.patch
@@ -247,7 +247,7 @@ index ab6cff3..c79dac9 100644
# Annotate the inputs before the pipeline to prevent unexpected
# propagation from earlier layers.
--
2.25.1
2.43.0


From ff1745796009cf1ec59f463f8e776c66f1286938 Mon Sep 17 00:00:00 2001
@@ -358,5 +358,5 @@ index e3b2f7c..b31526e 100644
trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert),
trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert),
--
2.25.1
2.43.0

(Diffs for the remaining changed files are not shown here.)