
Bump to jax 0.4.21. (#237)
markblee authored Dec 10, 2023
1 parent 67eaa40 commit 68f1200
Showing 2 changed files with 6 additions and 37 deletions.
33 changes: 1 addition & 32 deletions axlearn/common/utils_spmd.py
@@ -3,13 +3,10 @@
 """SPMD related utils."""

 import logging
-import socket
 from typing import Optional

 import jax
-import jax.numpy as jnp
 import portpicker
-from jax.experimental import multihost_utils

 _jax_distributed_initialized = False

@@ -56,11 +53,7 @@ def setup(
                 "distributed_coordinator, num_processes, and process_id "
                 "should all be None for tpu backend."
             )
-        jax.distributed.initialize(
-            coordinator_address=_infer_tpu_coordinator_address(),
-            num_processes=jax.process_count(),
-            process_id=jax.process_index(),
-        )
+        jax.distributed.initialize()
     else:
         if distributed_coordinator is None and num_processes is None and process_id is None:
             logging.info(
@@ -89,27 +82,3 @@ def setup(
             process_id=process_id,
         )
     _jax_distributed_initialized = True
-
-
-def _infer_tpu_coordinator_address() -> str:
-    """Infers a viable JAX coordination address on TPU (including over multiple TPU slices).
-
-    TODO(markblee,tom_gunter): Delete this when multi-slice init is fully supported by JAX.
-
-    Returns:
-        A coordinator address string as "ip:port".
-    """
-    slice_local_coordinator_ip = socket.gethostbyname(socket.gethostname())
-    # E.g. "172.31.4.83".
-    slice_local_coordinator_ip_as_nums = [int(num) for num in slice_local_coordinator_ip.split(".")]
-    # E.g. [172, 31, 4, 83].
-    global_coordinator_ip_as_nums = multihost_utils.broadcast_one_to_all(
-        jnp.asarray(slice_local_coordinator_ip_as_nums)
-    )
-    global_coordinator_ip = ".".join([str(num) for num in global_coordinator_ip_as_nums])
-    # E.g. "172.31.4.83" on all hosts on all slices.
-    global_coordinator_port = multihost_utils.broadcast_one_to_all(
-        jnp.asarray(portpicker.pick_unused_port())
-    )
-    global_coordinator_address = f"{global_coordinator_ip}:{global_coordinator_port}"
-    return global_coordinator_address
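Why this simplification works: the deleted `_infer_tpu_coordinator_address` helper existed only because older jax releases could not self-configure coordination across multiple TPU slices (see the TODO in the removed docstring). With the jax 0.4.21 pinned by this commit, `jax.distributed.initialize()` can be called with no arguments on TPU, and jax infers the coordinator address, process count, and process index from the environment. A minimal sketch of the simplified path (an illustration, not code from this commit; assumes a Cloud TPU environment where jax can discover the coordinator on its own):

    import jax

    # No coordinator_address, num_processes, or process_id needed on TPU;
    # jax auto-detects them from the TPU environment.
    jax.distributed.initialize()

    print(
        f"Process {jax.process_index()} of {jax.process_count()}; "
        f"{jax.local_device_count()} local device(s)."
    )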
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -17,8 +17,8 @@ dependencies = [
     "chex<0.1.81", # chex 0.1.81 depends on numpy>=1.25.0.
     "flax==0.7.4", # only for checkpoints.
     "importlab==0.7", # breaks pytype on 0.8
-    "jax>=0.4.18,<=0.4.20", # jax 0.4.20 runs into issues on GPU.
-    "jaxlib>=0.4.18,<=0.4.20",
+    "jax==0.4.21",
+    "jaxlib==0.4.21",
     "nltk==3.7", # for text preprocessing
     "numpy<1.24", # needed to pin to < 1.24; tf ragged_tensor depends on deprecated np.object.
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
@@ -43,8 +43,8 @@ apple-silicon = [
     "absl-py",
     "chex>=0.1.7",
     "flax==0.7.4", # only for checkpoints.
-    "jax>=0.4.18,<=0.4.20",
-    "jaxlib>=0.4.18,<=0.4.20",
+    "jax==0.4.21",
+    "jaxlib==0.4.21",
     "nltk==3.7", # for text preprocessing
     "optax>=0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -94,7 +94,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.20", # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.21", # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard.
 vertexai_tensorboard = [
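For reference, the note above the tpu extra translates to an install command along these lines (a hedged example, not taken from the repository's docs):

    pip install 'axlearn[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

The extra `-f` index is needed because the libtpu wheels that `jax[tpu]` depends on are hosted on that releases page rather than on PyPI.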
