Changes in preparation for jax 0.4.21. (#235)
markblee authored Dec 8, 2023
1 parent c259655 commit 67eaa40
Showing 9 changed files with 74 additions and 88 deletions.
11 changes: 11 additions & 0 deletions axlearn/audio/asr_encoder_test.py
@@ -3,6 +3,7 @@
"""Tests speech encoder layers."""

import jax.random
import pytest
from absl.testing import parameterized
from jax import numpy as jnp

@@ -28,6 +29,7 @@ class SpeechFeatureLayerTest(TestCase):
"""Tests SpeechFeatureLayer."""

@parameterized.parameters([True, False])
@pytest.mark.fp64
def test_speech_feature_layer(self, is_training: bool):
num_filters, sample_rate, frame_size_ms, hop_size_ms = 80, 16000, 25, 10
hidden_dim, output_dim = 32, 16
@@ -68,6 +70,10 @@ def test_speech_feature_layer(self, is_training: bool):
prng_key=input_key, batch_size=batch_size, seq_len=seq_len
)

# Slightly higher diff without fp64 from conv subsampler on jax 0.4.21.
inputs = inputs.astype(jnp.float64)
layer_params = jax.tree_map(lambda x: x.astype(jnp.float64), layer_params)

output_batch, output_collections = F(
layer,
inputs=dict(inputs=inputs, paddings=paddings),
@@ -172,6 +178,7 @@ class ASREncoderTest(TestCase):
"""Tests ASREncoder."""

@parameterized.parameters([True, False])
@pytest.mark.fp64
def test_asr_encoder(self, is_training: bool):
conv_dim, output_dim = 12, 36
num_filters, sample_rate, frame_size_ms, hop_size_ms = 80, 16000, 25, 10
@@ -212,6 +219,10 @@ def test_asr_encoder(self, is_training: bool):
prng_key=input_key, batch_size=batch_size, seq_len=seq_len
)

# Slightly higher diff without fp64 from conv subsampler on jax 0.4.21.
inputs = inputs.astype(jnp.float64)
layer_params = jax.tree_map(lambda x: x.astype(jnp.float64), layer_params)

output_batch, _ = F(
layer,
inputs=dict(inputs=inputs, paddings=paddings),
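The new `pytest.mark.fp64` marker and the float64 casts in both tests above assume the test session enables JAX's 64-bit mode (e.g. via `jax_enable_x64`); a minimal sketch of that pattern, with an illustrative helper rather than code from this commit:

```python
# Minimal sketch, assuming the fp64 marker ultimately enables 64-bit mode in JAX.
import jax
import jax.numpy as jnp

# Without this, JAX keeps arrays in float32 and the conv subsampler accumulates
# a slightly larger numerical diff on jax 0.4.21.
jax.config.update("jax_enable_x64", True)

def cast_to_f64(tree):
    """Casts every array leaf of a (possibly nested) parameter tree to float64."""
    return jax.tree_map(lambda x: x.astype(jnp.float64), tree)

# Illustrative parameter tree, not the layer params from the tests above.
params = {"conv": {"weight": jnp.ones((3, 3), dtype=jnp.float32)}}
assert cast_to_f64(params)["conv"]["weight"].dtype == jnp.float64
```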
6 changes: 1 addition & 5 deletions axlearn/common/inference_test.py
@@ -19,7 +19,7 @@
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit

from axlearn.common import layers, test_utils, utils, utils_spmd
from axlearn.common import layers, test_utils, utils
from axlearn.common.base_model import BaseModel
from axlearn.common.checkpointer import CheckpointValidationType, TensorStoreStateStorage
from axlearn.common.config import Configurable, config_class, config_for_function
@@ -180,10 +180,6 @@ def is_supported(
class InferenceTest(test_utils.TestCase):
"""Inference tests."""

def setUp(self):
super().setUp()
utils_spmd.setup()

@parameterized.parameters(
(tf.constant("query"), "query"),
(tf.constant(["query"]), ["query"]),
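The dropped `setUp` override relies on SPMD setup having moved into the shared test base class (see the `test_utils.py` change later in this diff); an illustrative sketch of that relationship, not verbatim repository code:

```python
# Illustrative sketch: the shared TestCase now performs CPU-only SPMD setup,
# so subclasses such as InferenceTest no longer override setUp for it.
from absl.testing import parameterized
from axlearn.common import utils_spmd

class TestCase(parameterized.TestCase):
    def setUp(self):
        super().setUp()
        # Setup without distributed initialization (no coordinator/process args).
        utils_spmd.setup(jax_backend="cpu")

class InferenceTest(TestCase):
    # Inherits the SPMD setup above; no per-test utils_spmd.setup() call needed.
    def test_noop(self):
        self.assertTrue(True)
```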
46 changes: 12 additions & 34 deletions axlearn/common/launch.py
@@ -5,15 +5,15 @@
import os
import sys

tpu_type = os.environ.get("TPU_TYPE", "none")
instance_type = os.environ.get("TPU_TYPE", "none")

# Set LIBTPU_INIT_ARGS before importing jax!
libtpu_init_args = [
"--xla_tpu_spmd_rng_bit_generator_unsafe=1", # SPMD partition-aware RngBitGenerator.
"--xla_tpu_enable_latency_hiding_scheduler=true", # Try to schedule ops efficiently.
"--xla_tpu_perform_spmd_cse_prevention=false", # b/229655601: prevent OOM on gpt2-small-repeat.
]
if tpu_type.startswith("v4-"):
if instance_type.startswith("v4-"):
libtpu_init_args += [
# Per maggioni@google.com, the following flags are not supported by V3.
"--xla_enable_async_all_gather=true", # Allow async all-gather.
@@ -44,11 +44,9 @@
# tpu_library_init_fns.inc:98] TpuEmbeddingEngine_ExecutePartitioner not available in this library.
import jax # jax must be imported before tensorflow!

# NOTE: calling JAX distributed APIs (e.g. jax.default_backend(), jax.process_index() or
# jax.process_count()) on GPU causes JAX to only view one process' GPUs.
print(f"jax version={jax.__version__}", file=sys.stderr)
if tpu_type != "none":
print(f"instance_type={tpu_type} num_slices={num_tpu_slices}", file=sys.stderr)
if instance_type != "none":
print(f"instance_type={instance_type} num_slices={num_tpu_slices}", file=sys.stderr)

import logging as pylogging

@@ -76,48 +74,29 @@
"If 'FAKE', uses fake inputs.",
)
flags.DEFINE_integer("jax_profiler_port", None, "If not None, the profiler port.")
flags.DEFINE_string(
"jax_backend", None, "If not None, ensures that trainer runs on the specified XLA backend."
)
flags.DEFINE_string("jax_backend", None, "Specifies the XLA backend to use.", required=True)
flags.DEFINE_string(
"distributed_coordinator",
None,
"Set this None for tpu backend but it is required for multi-gpu environment",
"Distributed coordinator IP address. Must be None on tpu, otherwise required.",
)
flags.DEFINE_integer(
"num_processes", None, "Total number of hosts (nodes). Set this None for tpu backend."
)
flags.DEFINE_integer("process_id", None, "Host process id. Set this None for tpu backend.")
flags.DEFINE_integer(
"num_processes", None, "Total number of hosts (nodes). Must be None on tpu, otherwise required."
)
flags.DEFINE_integer(
"process_id", None, "Rank of the current process. Must be None on tpu, otherwise required."
)
flags.DEFINE_string(
"mesh_selector",
None,
"The mesh selector string. See `SpmdTrainer.Config.mesh_rules` for details.",
)
# TODO(markblee): Remove this flag.
flags.DEFINE_boolean(
"filter_info_logs",
None,
"If None (default), info log only on process 0 on TPUs, and on all processes on GPUs. "
"If True, info log only on process 0. "
"If False, info log on all processes.",
)


FLAGS = flags.FLAGS


def setup():
# Decide whether to filter logs.
if FLAGS.filter_info_logs is not None:
filter_info_logs = FLAGS.filter_info_logs
else:
# Infer from platform. For multi-node multi-gpu environment, filtering makes it so that only
# one process' devices are visible, so we disable it by default.
filter_info_logs = FLAGS.jax_backend is None or FLAGS.jax_backend != "gpu"

if filter_info_logs:
logging.get_absl_handler().addFilter(InfoLogOnlyOnMaster())

setup_spmd(
distributed_coordinator=FLAGS.distributed_coordinator,
num_processes=FLAGS.num_processes,
@@ -133,9 +112,8 @@ def setup():
logging.info("Devices: %s", devices)
local_devices = jax.local_devices()
logging.info("Local Devices: %s", local_devices)
if FLAGS.jax_backend is not None:
if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
if FLAGS.data_dir:
# TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
os.environ["DATA_DIR"] = FLAGS.data_dir
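With `--jax_backend` now required, the device-platform check above runs unconditionally; a hedged standalone sketch of that check (the wrapper function is illustrative, the logic mirrors the diff):

```python
# Standalone sketch of the unconditional backend check; the wrapper function
# is illustrative and not part of launch.py.
import jax

def check_backend(expected_backend: str):
    """Raises if the visible devices do not match the requested XLA backend."""
    devices = jax.devices()
    if not devices or not all(device.platform == expected_backend for device in devices):
        raise RuntimeError(f"Expected backend {expected_backend}. Got {devices}.")

check_backend("cpu")  # Passes on a CPU-only host; use "tpu" or "gpu" on accelerators.
```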
9 changes: 3 additions & 6 deletions axlearn/common/test_utils.py
@@ -52,8 +52,6 @@
)
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

# See utils_spmd.py for where we set "jax_default_prng_impl".
_default_prng_impl = "rbg"
_PYTEST_OPT_REGISTERED = {}


@@ -131,8 +129,9 @@ def data_dir(self):
return "FAKE"

def setUp(self):
utils_spmd.setup()
push_data_dir(self.data_dir)
# Setup without distributed initialization.
utils_spmd.setup(jax_backend="cpu")

def tearDown(self) -> None:
self.assertEqual(pop_data_dir(), self.data_dir)
@@ -310,11 +309,9 @@ def run(self, result=None):

@contextlib.contextmanager
def prng_impl(new_prng_impl: str):
old_prng_impl = _default_prng_impl
old_prng_impl = jax.config.jax_default_prng_impl

def switch(value):
global _default_prng_impl # pylint: disable=global-statement
_default_prng_impl = value
jax.config.update("jax_default_prng_impl", value)

switch(new_prng_impl)
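The `prng_impl` context manager now reads and restores the live `jax.config` value instead of tracking a module-level global; a minimal sketch of the resulting behavior, assuming the usual try/finally restore on exit:

```python
# Minimal sketch of the updated prng_impl behavior: the previous value is read
# from jax.config and restored on exit, with no module-level global involved.
import contextlib
import jax

@contextlib.contextmanager
def prng_impl(new_prng_impl: str):
    old_prng_impl = jax.config.jax_default_prng_impl
    jax.config.update("jax_default_prng_impl", new_prng_impl)
    try:
        yield
    finally:
        jax.config.update("jax_default_prng_impl", old_prng_impl)

with prng_impl("threefry2x32"):
    key = jax.random.PRNGKey(0)  # Uses threefry inside the context.
# Outside the context, the previous default (e.g. "rbg" set by utils_spmd.setup) applies.
```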
62 changes: 38 additions & 24 deletions axlearn/common/utils_spmd.py
@@ -2,6 +2,7 @@

"""SPMD related utils."""

import logging
import socket
from typing import Optional

@@ -15,29 +16,30 @@

def setup(
*,
jax_backend: str,
distributed_coordinator: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None,
jax_backend: Optional[str] = None,
):
"""Sets up the Jax environment for SPMD/pjit.
"""Sets up the JAX environment for SPMD.
Args:
distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>).
Needed only if not running on TPU *and* jax.process_count() > 1. Otherwise the
coordinator will be configured automatically.
num_processes: The number of processes (GPU backend: total number of gpus). Needed only
if not running on TPU *and* jax.process_count() > 1. Otherwise the coordinator will
be configured automatically.
process_id: The process id (GPU backend: the GPU rank). Needed only if not running
on TPU *and* jax.process_count() > 1. Otherwise the coordinator will be
configured automatically.
jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu".
By default, it would be configured automatically.
distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>).
Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the
coordinator will be configured automatically when `num_processes` and `process_id` are
provided.
num_processes: The number of processes. Needed only if distributed initialization is desired
for `jax_backend != "tpu"`.
process_id: The process ID (the process rank). Needed only if distributed initialization is
desired for `jax_backend != "tpu"`.
Raises:
ValueError: If distributed_coordinator, num_processes, or process_id are not None when
jax_backend is "tpu", or if distributed_coordinator is unsupported.
ValueError: If any of the following conditions are met:
* distributed_coordinator, num_processes, or process_id are not None when
jax_backend is "tpu";
* one of num_processes or process_id is None when jax_backend is not "tpu";
* distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1.
"""
# Use a GSPMD-friendly PRNG implementation.
jax.config.update("jax_default_prng_impl", "rbg")
@@ -46,29 +48,41 @@ def setup(

global _jax_distributed_initialized # pylint: disable=global-statement
if not _jax_distributed_initialized:
# NOTE: calling JAX distributed APIs (e.g. jax.default_backend(), jax.process_index() or
# jax.process_count()) on GPU causes JAX to only view one process' GPUs.
jax_backend = jax_backend or jax.default_backend()
if jax_backend == "tpu":
assert (
distributed_coordinator is None and num_processes is None and process_id is None
), ValueError(
"distributed_coordinator, num_processes, process_id "
"should all be None for tpu backend"
)
if not (
distributed_coordinator is None and num_processes is None and process_id is None
):
raise ValueError(
"distributed_coordinator, num_processes, and process_id "
"should all be None for tpu backend."
)
jax.distributed.initialize(
coordinator_address=_infer_tpu_coordinator_address(),
num_processes=jax.process_count(),
process_id=jax.process_index(),
)
else:
num_processes = num_processes if num_processes is not None else jax.process_count()
process_id = process_id if process_id is not None else jax.process_index()
if distributed_coordinator is None and num_processes is None and process_id is None:
logging.info(
"Skipping distributed initialization for %s backend, "
"since distributed_coordinator, num_processes, and process_id are all None.",
jax_backend,
)
return

if num_processes is None or process_id is None:
raise ValueError(
"num_processes and process_id should be provided together "
f"if distributed initialization is desired for backend {jax_backend}. "
f"Instead, got num_processes={num_processes}, process_id={process_id}."
)

if not distributed_coordinator:
if num_processes == 1:
distributed_coordinator = f"localhost:{portpicker.pick_unused_port()}"
else:
raise ValueError(f"Unknown distributed_coordinator: {distributed_coordinator}")

jax.distributed.initialize(
distributed_coordinator,
num_processes=num_processes,
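A hedged usage sketch of the revised `setup()` contract documented above; the coordinator address and process counts are placeholders:

```python
# Usage sketch only; the host address and process counts below are placeholders.
from axlearn.common import utils_spmd

# CPU or GPU without distributed init: all distributed args are None, so setup()
# configures the PRNG implementation and skips jax.distributed.initialize().
utils_spmd.setup(jax_backend="cpu")

# Multi-process GPU: coordinator, process count, and process rank are all required.
utils_spmd.setup(
    jax_backend="gpu",
    distributed_coordinator="10.0.0.1:8476",  # Placeholder <host>:<port>.
    num_processes=2,
    process_id=0,  # Rank of this host: 0 or 1 here.
)

# TPU: everything is inferred; passing any of the distributed args raises ValueError.
utils_spmd.setup(jax_backend="tpu")
```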
7 changes: 4 additions & 3 deletions axlearn/experiments/text/gpt/c4_trainer.py
@@ -5,7 +5,7 @@
mkdir -p /tmp/gpt_c4_test;
python3 -m axlearn.common.launch_trainer_main \
--module=text.gpt.c4_trainer --config=fuji-test \
--trainer_dir=/tmp/gpt_c4_test --data_dir=FAKE
--trainer_dir=/tmp/gpt_c4_test --data_dir=FAKE --jax_backend=cpu
GS_ROOT=gs://my-bucket; \
CONFIG=fuji-7B; \
@@ -18,7 +18,7 @@
--module=text.gpt.c4_trainer --config=$CONFIG \
--trainer_dir=$OUTPUT_DIR \
--data_dir=$GS_ROOT/tensorflow_datasets \
--mesh_selector=$INSTANCE_TYPE
--mesh_selector=$INSTANCE_TYPE --jax_backend=tpu
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh; \
bash Miniconda3-latest-Linux-x86_64.sh; \
@@ -32,7 +32,8 @@
mkdir -p /tmp/test_trainer; \
python3 -m axlearn.common.launch_trainer_main \
--module=text.gpt.c4_trainer --config=fuji-7B-single \
--trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets
--trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu
"""

from typing import Dict
6 changes: 3 additions & 3 deletions axlearn/experiments/vision/resnet/imagenet_trainer.py
@@ -13,7 +13,7 @@
mkdir -p /tmp/resnet_test;
python3 -m axlearn.common.launch_trainer_main \
--module=vision.resnet.imagenet_trainer --config=ResNet-Test \
--trainer_dir=/tmp/resnet_test --data_dir=FAKE
--trainer_dir=/tmp/resnet_test --data_dir=FAKE --jax_backend=cpu
# Launch training on a v4-8 TPU, reading and writing from GCS.
#
@@ -26,7 +26,7 @@
axlearn gcp launch --instance_type=tpu-v4-8 --output_dir=$OUTPUT_DIR -- \
python3 -m axlearn.common.launch_trainer_main \
--module=vision.resnet.imagenet_trainer --config=ResNet-50 \
--trainer_dir=$OUTPUT_DIR --data_dir=${GS_ROOT}/tensorflow_datasets
--trainer_dir=$OUTPUT_DIR --data_dir=${GS_ROOT}/tensorflow_datasets --jax_backend=tpu
# Sample docker launch.
GS_ROOT=gs://my-bucket; \
@@ -40,7 +40,7 @@
--bundler_spec=target=tpu -- \
python3 -m axlearn.common.launch_trainer_main \
--module=vision.resnet.imagenet_trainer --config=ResNet-50 \
--trainer_dir=$OUTPUT_DIR --data_dir=${GS_ROOT}/tensorflow_datasets
--trainer_dir=$OUTPUT_DIR --data_dir=${GS_ROOT}/tensorflow_datasets --jax_backend=tpu
```
"""
4 changes: 2 additions & 2 deletions docs/01-start.md
@@ -740,7 +740,7 @@ DATA_DIR=gs://path/to/tensorflow_datasets
axlearn gcp tpu start --tpu_type=v4-8 --output_dir=$OUTPUT_DIR -- \
python3 -m axlearn.common.launch_trainer_main \
--module=vision.resnet.imagenet_trainer --config=ResNet-50 \
--trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR
--trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR --jax_backend=tpu
```

If you have been following along with the code, assuming you have a file `axlearn/experiments/tutorial.py`, you can also launch your own experiment with:
@@ -752,7 +752,7 @@
python3 -m axlearn.common.launch_trainer_main \
- --module=vision.resnet.imagenet_trainer --config=ResNet-50 \
+ --module=tutorial --config=ResNet-50 \
--trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR
--trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR --jax_backend=tpu
```

Both commands are similar to the one from the previous section except we run the trainer defined by `--module` and `--config` instead of simply printing `jax.devices()`.
11 changes: 0 additions & 11 deletions pyproject.toml
@@ -96,17 +96,6 @@ tpu = [
"axlearn[gcp]",
"jax[tpu]==0.4.20", # must be >=0.4.19 for compat with v5p.
]
# Jax-triton. Can only be installed on a GPU machine.
# Note: Specify -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html during install.
jax_triton = [
"cmake",
# TODO(markblee): Find compatible version(s) for jax 0.4.20.
"jax-triton@git+https://github.com/jax-ml/jax-triton.git@06d35175fbb65771243f607b1f098ec4a75556c9",
"jaxlib==0.4.18+cuda12.cudnn89",
"nvidia-cudnn-cu12==8.9.2.26",
# To avoid conflicting with jax.
"tensorflow-cpu==2.8.0",
]
# Vertex AI tensorboard.
vertexai_tensorboard = [
# Required to fix a `distro-info` bug we run into when using `tb_gcp_tensorboard` from
