From 831342fb4df44d62e299a0273fd840ede4710b5e Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Thu, 13 Jul 2023 15:23:00 -0700 Subject: [PATCH] Orbax no longer supports lazy restoration of checkpoints. If our checkpoint is `{'a': 2, 'b': 4}`, and we wish to 'lazily' load the tree such that only 'a' is materialized, equivalent functionality can be achieved using the following: ``` ckptr = ocp.PyTreeCheckpointer() tree = {'a': ..., 'b': ...} ckptr.restore(path, item=tree, transforms={'b': ocp.Transform(use_fallback=True)}) {'a': 2, 'b': ...} ``` OR ``` ckptr = ocp.PyTreeCheckpointer() tree = {'a': ...} # Returns a tree just containing 'a', materialized. ckptr.restore(path, item=tree, transforms={}) {'a': 2} ``` Previously this could have also been done with the lazy feature: ``` ckptr = ocp.PyTreeCheckpointer() ckptr.restore(path, restore_args={'a': RestoreArgs(), 'b': RestoreArgs(lazy=True)}) {'a': 2, 'b': LazyValue()} ``` PiperOrigin-RevId: 547941035 --- docs/usage/gin.md | 3 ++- t5x/checkpoint_importer.py | 3 +-- t5x/checkpoints.py | 55 +++++++++++++++++++------------------- t5x/test_utils.py | 5 ++-- t5x/utils_test.py | 8 +++--- 5 files changed, 36 insertions(+), 38 deletions(-) diff --git a/docs/usage/gin.md b/docs/usage/gin.md index d5ed767b2..0c6bce3f6 100644 --- a/docs/usage/gin.md +++ b/docs/usage/gin.md @@ -374,13 +374,14 @@ inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`. At the beginning of the primer, we saw a fully-specified run config. We can do something similar with the previous example to create a self-contained run configuration. 
-[t5_1_1/examples/base_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin) +[t5_1_1/examples/small_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin) is just such an example that allows you to exactly duplicate the previous launch command simply by calling: ```sh python -m t5x.train \ --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \ + --gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \ --logtostderr ``` diff --git a/t5x/checkpoint_importer.py b/t5x/checkpoint_importer.py index b6d000e8f..022e5c150 100644 --- a/t5x/checkpoint_importer.py +++ b/t5x/checkpoint_importer.py @@ -24,7 +24,6 @@ import jax from jax import numpy as jnp import numpy as np -from orbax.checkpoint import lazy_utils import tensorflow as tf import tensorstore as ts @@ -32,7 +31,7 @@ ScalarOrArrayType = Union[int, float, ArrayType] -class LazyArray(lazy_utils.LazyValue, metaclass=abc.ABCMeta): +class LazyArray(metaclass=abc.ABCMeta): """Lazily and asynchronously loads an array. 
LazyArray behaves in the same way as a `numpy` or `jax.numpy` array diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 1cf23f318..8424f73d1 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -49,7 +49,7 @@ from jax.experimental.array_serialization import serialization as array_serialization import jax.numpy as jnp import numpy as np -import orbax.checkpoint +import orbax.checkpoint as ocp from t5x import checkpoint_importer from t5x import checkpoint_utils from t5x import optimizers @@ -60,9 +60,7 @@ from tensorflow.io import gfile import tensorstore as ts import typing_extensions -from tensorboard.backend.event_processing import directory_watcher -from tensorboard.backend.event_processing import event_file_loader -from tensorboard.backend.event_processing import io_wrapper + PartitionSpec = partitioning.PartitionSpec PyTree = Any @@ -836,7 +834,7 @@ def save(self, end_time = time.time() monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, end_time - start_time) - orbax.checkpoint.utils.record_saved_duration(start_time) + ocp.utils.record_saved_duration(start_time) def _write_state_to_tensorstore( self, @@ -1886,7 +1884,7 @@ class _OrbaxParamInfo: mesh_axes: partitioning.PartitionSpec -class DatasetCheckpointHandler(orbax.checkpoint.CheckpointHandler): +class DatasetCheckpointHandler(ocp.CheckpointHandler): """A CheckpointHandler implementation that handles tf.data.Iterator.""" def __init__(self, checkpoint_filename: str): @@ -1942,34 +1940,30 @@ def _step_from_train_state(train_state: train_state_lib.TrainState) -> int: def _construct_save_args( param_info: _OrbaxParamInfo, dtype: jnp.dtype -) -> orbax.checkpoint.SaveArgs: +) -> ocp.SaveArgs: """Create SaveArgs for Orbax saving.""" if param_info.name.split('.')[0] != 'target': dtype = None - return orbax.checkpoint.SaveArgs( - aggregate=param_info.mesh_axes is None, dtype=dtype - ) + return ocp.SaveArgs(aggregate=param_info.mesh_axes is None, dtype=dtype) def _construct_restore_args( 
param_info: _OrbaxParamInfo, dtype: jnp.dtype, mesh: jax.sharding.Mesh, - lazy_parameters: bool, -) -> orbax.checkpoint.RestoreArgs: +) -> ocp.RestoreArgs: """Create RestoreArgs for Orbax restoration.""" if not isinstance(param_info, _OrbaxParamInfo): # from fallback - return orbax.checkpoint.RestoreArgs(dtype=dtype, lazy=lazy_parameters) + return ocp.RestoreArgs(dtype=dtype) if param_info.name.split('.')[0] != 'target': dtype = None if param_info.mesh_axes is None: - return orbax.checkpoint.RestoreArgs(dtype=dtype, lazy=lazy_parameters) - return orbax.checkpoint.ArrayRestoreArgs( + return ocp.RestoreArgs(dtype=dtype) + return ocp.ArrayRestoreArgs( restore_type=jax.Array, mesh=mesh, mesh_axes=param_info.mesh_axes, dtype=dtype, - lazy=lazy_parameters, ) @@ -1991,7 +1985,7 @@ def _construct_orbax_param_infos( def _construct_orbax_restoration_transforms( - manager: orbax.checkpoint.CheckpointManager, + manager: ocp.CheckpointManager, step: int, directory: epath.Path, state_dict: PyTree, @@ -2035,11 +2029,16 @@ def _transform_fn( # This structure is unneeded, because we already restored and transformed # it. del structure + del param_infos # Construct param_infos from item because item is the transformed # structure. 
# pylint: disable=protected-access - param_infos = orbax.checkpoint.pytree_checkpoint_handler._get_param_infos_from_structure( - manager._get_save_directory(step, directory, key_name=_STATE_KEY), item + param_infos, _ = ocp.pytree_checkpoint_handler._get_restore_parameters( + manager._get_save_directory(step, directory, key_name=_STATE_KEY), + None, + item, + None, + None, ) return item, param_infos @@ -2079,9 +2078,9 @@ def _partition_parameter(maybe_arr: Any, param_info: _OrbaxParamInfo): class OrbaxCheckpointManagerInterface: - """Wrapper for orbax.checkpoint.CheckpointManager.""" + """Wrapper for ocp.CheckpointManager.""" - class _CheckpointManagerImpl(orbax.checkpoint.CheckpointManager): + class _CheckpointManagerImpl(ocp.CheckpointManager): """CheckpointManager implementation to deal with metrics update.""" def _remove_old_checkpoints(self): @@ -2136,20 +2135,20 @@ def __init__( ) checkpointers = { - _STATE_KEY: orbax.checkpoint.Checkpointer( + _STATE_KEY: ocp.Checkpointer( # TODO(b/273803615) Enable OCDBT. - orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False) + ocp.PyTreeCheckpointHandler(use_ocdbt=False) ), } if self._should_write_dataset_ckpt: - checkpointers[_DATASET_KEY] = orbax.checkpoint.Checkpointer( + checkpointers[_DATASET_KEY] = ocp.Checkpointer( DatasetCheckpointHandler(checkpoint_filename=dataset_ckpt_name) ) def best_fn(metrics): return metrics[metric_name_to_monitor] - options = orbax.checkpoint.CheckpointManagerOptions( + options = ocp.CheckpointManagerOptions( max_to_keep=keep, save_interval_steps=period, keep_period=force_keep_period, @@ -2239,7 +2238,7 @@ def save( monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time ) - orbax.checkpoint.utils.record_saved_duration(start_time) + ocp.utils.record_saved_duration(start_time) return saved @@ -2271,6 +2270,8 @@ def restore( Returns: The restored train state. 
""" + if lazy_parameters: + logging.warning('Orbax does not support lazy restoration.') start_time = time.time() if step is not None and path is not None: raise ValueError('Can only provide `step` or `path` but not both.') @@ -2289,7 +2290,6 @@ def restore( _construct_restore_args, dtype=self._restore_dtype, mesh=self._partitioner.mesh, - lazy_parameters=lazy_parameters, ), param_infos, ) @@ -2361,7 +2361,6 @@ def restore_from_tf_checkpoint( full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( self._train_state.state_dict(), path_or_dir, - lazy_parameters=False, strict=strict, translator=translator, ) diff --git a/t5x/test_utils.py b/t5x/test_utils.py index 8aed733dc..6368ff776 100644 --- a/t5x/test_utils.py +++ b/t5x/test_utils.py @@ -26,6 +26,7 @@ from jax.sharding import Mesh import numpy as np import seqio +import t5.data from t5x import adafactor from t5x import models from t5x import partitioning @@ -322,8 +323,8 @@ def partition( ): pjitted = pjit( fn, - in_axis_resources=in_axis_resources, - out_axis_resources=out_axis_resources, + in_shardings=in_axis_resources, + out_shardings=out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums, ) diff --git a/t5x/utils_test.py b/t5x/utils_test.py index 0ee711368..193f75c81 100644 --- a/t5x/utils_test.py +++ b/t5x/utils_test.py @@ -1033,8 +1033,8 @@ def get_data_layout(batch_size): def partition(fn, in_axis_resources, out_axis_resources): fn = pjit( fn, - in_axis_resources=in_axis_resources, - out_axis_resources=out_axis_resources, + in_shardings=in_axis_resources, + out_shardings=out_axis_resources, ) return partitioning.PjittedFnWithContext(fn, global_mesh) @@ -1064,9 +1064,7 @@ def partition(fn, in_axis_resources, out_axis_resources): def as_sharded_array(arr, axes, mesh=None): with mesh: - return pjit( - lambda x: x, in_axis_resources=None, out_axis_resources=axes - )(arr) + return pjit(lambda x: x, in_shardings=None, out_shardings=axes)(arr) train_state.params = 
jax.tree_util.tree_map( functools.partial(as_sharded_array, mesh=global_mesh),