Skip to content

Commit

Permalink
Orbax no longer supports lazy restoration of checkpoints.
Browse files Browse the repository at this point in the history
If our checkpoint is `{'a': 2, 'b': 4}` and we wish to "lazily" load the tree so that only 'a' is materialized, equivalent functionality can be achieved with either of the following:

```
ckptr = ocp.PyTreeCheckpointer()
tree = {'a': <dummy_value>, 'b': <dummy_value>}
# Returns a tree with 'a' materialized and 'b' left as its dummy value.
ckptr.restore(path, item=tree, transforms={'b': ocp.Transform(use_fallback=True)})
{'a': 2, 'b': <dummy_value>}
```
OR
```
ckptr = ocp.PyTreeCheckpointer()
tree = {'a': <dummy_value>}
# Returns a tree just containing 'a', materialized.
ckptr.restore(path, item=tree, transforms={})
{'a': 2}
```

Previously, this could also have been done with the (now removed) lazy restoration feature:
```
ckptr = ocp.PyTreeCheckpointer()
# Returned a tree with 'a' materialized and 'b' as an unresolved LazyValue.
ckptr.restore(path, restore_args={'a': RestoreArgs(), 'b': RestoreArgs(lazy=True)})
{'a': 2, 'b': LazyValue()}
```
PiperOrigin-RevId: 547941035
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Jul 18, 2023
1 parent eb08ffb commit 831342f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 38 deletions.
3 changes: 2 additions & 1 deletion docs/usage/gin.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,14 @@ inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`.
At the beginning of the primer, we saw a fully-specified run config. We can do
something similar with the previous example to create a self-contained run
configuration.
[t5_1_1/examples/base_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
[t5_1_1/examples/small_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin)
is just such an example that allows you to exactly duplicate the previous launch
command simply by calling:

```sh
python -m t5x.train \
--gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
--gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \
--logtostderr
```

Expand Down
3 changes: 1 addition & 2 deletions t5x/checkpoint_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import lazy_utils
import tensorflow as tf
import tensorstore as ts

ArrayType = Union[np.ndarray, jnp.ndarray, jax.Array]
ScalarOrArrayType = Union[int, float, ArrayType]


class LazyArray(lazy_utils.LazyValue, metaclass=abc.ABCMeta):
class LazyArray(metaclass=abc.ABCMeta):
"""Lazily and asynchronously loads an array.
LazyArray behaves in the same way as a `numpy` or `jax.numpy` array
Expand Down
55 changes: 27 additions & 28 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax.experimental.array_serialization import serialization as array_serialization
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import orbax.checkpoint as ocp
from t5x import checkpoint_importer
from t5x import checkpoint_utils
from t5x import optimizers
Expand All @@ -60,9 +60,7 @@
from tensorflow.io import gfile
import tensorstore as ts
import typing_extensions
from tensorboard.backend.event_processing import directory_watcher
from tensorboard.backend.event_processing import event_file_loader
from tensorboard.backend.event_processing import io_wrapper


PartitionSpec = partitioning.PartitionSpec
PyTree = Any
Expand Down Expand Up @@ -836,7 +834,7 @@ def save(self,
end_time = time.time()
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
end_time - start_time)
orbax.checkpoint.utils.record_saved_duration(start_time)
ocp.utils.record_saved_duration(start_time)

def _write_state_to_tensorstore(
self,
Expand Down Expand Up @@ -1886,7 +1884,7 @@ class _OrbaxParamInfo:
mesh_axes: partitioning.PartitionSpec


class DatasetCheckpointHandler(orbax.checkpoint.CheckpointHandler):
class DatasetCheckpointHandler(ocp.CheckpointHandler):
"""A CheckpointHandler implementation that handles tf.data.Iterator."""

def __init__(self, checkpoint_filename: str):
Expand Down Expand Up @@ -1942,34 +1940,30 @@ def _step_from_train_state(train_state: train_state_lib.TrainState) -> int:

def _construct_save_args(
param_info: _OrbaxParamInfo, dtype: jnp.dtype
) -> orbax.checkpoint.SaveArgs:
) -> ocp.SaveArgs:
"""Create SaveArgs for Orbax saving."""
if param_info.name.split('.')[0] != 'target':
dtype = None
return orbax.checkpoint.SaveArgs(
aggregate=param_info.mesh_axes is None, dtype=dtype
)
return ocp.SaveArgs(aggregate=param_info.mesh_axes is None, dtype=dtype)


def _construct_restore_args(
param_info: _OrbaxParamInfo,
dtype: jnp.dtype,
mesh: jax.sharding.Mesh,
lazy_parameters: bool,
) -> orbax.checkpoint.RestoreArgs:
) -> ocp.RestoreArgs:
"""Create RestoreArgs for Orbax restoration."""
if not isinstance(param_info, _OrbaxParamInfo): # from fallback
return orbax.checkpoint.RestoreArgs(dtype=dtype, lazy=lazy_parameters)
return ocp.RestoreArgs(dtype=dtype)
if param_info.name.split('.')[0] != 'target':
dtype = None
if param_info.mesh_axes is None:
return orbax.checkpoint.RestoreArgs(dtype=dtype, lazy=lazy_parameters)
return orbax.checkpoint.ArrayRestoreArgs(
return ocp.RestoreArgs(dtype=dtype)
return ocp.ArrayRestoreArgs(
restore_type=jax.Array,
mesh=mesh,
mesh_axes=param_info.mesh_axes,
dtype=dtype,
lazy=lazy_parameters,
)


Expand All @@ -1991,7 +1985,7 @@ def _construct_orbax_param_infos(


def _construct_orbax_restoration_transforms(
manager: orbax.checkpoint.CheckpointManager,
manager: ocp.CheckpointManager,
step: int,
directory: epath.Path,
state_dict: PyTree,
Expand Down Expand Up @@ -2035,11 +2029,16 @@ def _transform_fn(
# This structure is unneeded, because we already restored and transformed
# it.
del structure
del param_infos
# Construct param_infos from item because item is the transformed
# structure.
# pylint: disable=protected-access
param_infos = orbax.checkpoint.pytree_checkpoint_handler._get_param_infos_from_structure(
manager._get_save_directory(step, directory, key_name=_STATE_KEY), item
param_infos, _ = ocp.pytree_checkpoint_handler._get_restore_parameters(
manager._get_save_directory(step, directory, key_name=_STATE_KEY),
None,
item,
None,
None,
)
return item, param_infos

Expand Down Expand Up @@ -2079,9 +2078,9 @@ def _partition_parameter(maybe_arr: Any, param_info: _OrbaxParamInfo):


class OrbaxCheckpointManagerInterface:
"""Wrapper for orbax.checkpoint.CheckpointManager."""
"""Wrapper for ocp.CheckpointManager."""

class _CheckpointManagerImpl(orbax.checkpoint.CheckpointManager):
class _CheckpointManagerImpl(ocp.CheckpointManager):
"""CheckpointManager implementation to deal with metrics update."""

def _remove_old_checkpoints(self):
Expand Down Expand Up @@ -2136,20 +2135,20 @@ def __init__(
)

checkpointers = {
_STATE_KEY: orbax.checkpoint.Checkpointer(
_STATE_KEY: ocp.Checkpointer(
# TODO(b/273803615) Enable OCDBT.
orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False)
ocp.PyTreeCheckpointHandler(use_ocdbt=False)
),
}
if self._should_write_dataset_ckpt:
checkpointers[_DATASET_KEY] = orbax.checkpoint.Checkpointer(
checkpointers[_DATASET_KEY] = ocp.Checkpointer(
DatasetCheckpointHandler(checkpoint_filename=dataset_ckpt_name)
)

def best_fn(metrics):
return metrics[metric_name_to_monitor]

options = orbax.checkpoint.CheckpointManagerOptions(
options = ocp.CheckpointManagerOptions(
max_to_keep=keep,
save_interval_steps=period,
keep_period=force_keep_period,
Expand Down Expand Up @@ -2239,7 +2238,7 @@ def save(
monitoring.record_event_duration_secs(
_WRITE_CHECKPOINT_EVENT, end_time - start_time
)
orbax.checkpoint.utils.record_saved_duration(start_time)
ocp.utils.record_saved_duration(start_time)

return saved

Expand Down Expand Up @@ -2271,6 +2270,8 @@ def restore(
Returns:
The restored train state.
"""
if lazy_parameters:
logging.warning('Orbax does not support lazy restoration.')
start_time = time.time()
if step is not None and path is not None:
raise ValueError('Can only provide `step` or `path` but not both.')
Expand All @@ -2289,7 +2290,6 @@ def restore(
_construct_restore_args,
dtype=self._restore_dtype,
mesh=self._partitioner.mesh,
lazy_parameters=lazy_parameters,
),
param_infos,
)
Expand Down Expand Up @@ -2361,7 +2361,6 @@ def restore_from_tf_checkpoint(
full_state_dict = checkpoint_importer.restore_from_t5_checkpoint(
self._train_state.state_dict(),
path_or_dir,
lazy_parameters=False,
strict=strict,
translator=translator,
)
Expand Down
5 changes: 3 additions & 2 deletions t5x/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from jax.sharding import Mesh
import numpy as np
import seqio
import t5.data
from t5x import adafactor
from t5x import models
from t5x import partitioning
Expand Down Expand Up @@ -322,8 +323,8 @@ def partition(
):
pjitted = pjit(
fn,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources,
in_shardings=in_axis_resources,
out_shardings=out_axis_resources,
static_argnums=static_argnums,
donate_argnums=donate_argnums,
)
Expand Down
8 changes: 3 additions & 5 deletions t5x/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,8 +1033,8 @@ def get_data_layout(batch_size):
def partition(fn, in_axis_resources, out_axis_resources):
fn = pjit(
fn,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources,
in_shardings=in_axis_resources,
out_shardings=out_axis_resources,
)
return partitioning.PjittedFnWithContext(fn, global_mesh)

Expand Down Expand Up @@ -1064,9 +1064,7 @@ def partition(fn, in_axis_resources, out_axis_resources):

def as_sharded_array(arr, axes, mesh=None):
with mesh:
return pjit(
lambda x: x, in_axis_resources=None, out_axis_resources=axes
)(arr)
return pjit(lambda x: x, in_shardings=None, out_shardings=axes)(arr)

train_state.params = jax.tree_util.tree_map(
functools.partial(as_sharded_array, mesh=global_mesh),
Expand Down

0 comments on commit 831342f

Please sign in to comment.