diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 8424f73d1..3eec38798 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -34,6 +34,7 @@ import re import subprocess import time +import typing from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union from absl import logging @@ -1927,7 +1928,7 @@ def restore(self, def structure(self, directory: epath.Path) -> Any: """Unimplemented. See parent class.""" - return NotImplementedError + pass def _step_from_train_state(train_state: train_state_lib.TrainState) -> int: @@ -1960,7 +1961,6 @@ def _construct_restore_args( if param_info.mesh_axes is None: return ocp.RestoreArgs(dtype=dtype) return ocp.ArrayRestoreArgs( - restore_type=jax.Array, mesh=mesh, mesh_axes=param_info.mesh_axes, dtype=dtype, @@ -2000,7 +2000,15 @@ def _construct_orbax_restoration_transforms( ) assert state_subdir.is_dir() use_orbax_format = state_subdir.stem == _STATE_KEY # Standard Orbax format - structure = manager._checkpointers[_STATE_KEY].structure(state_subdir) # pylint: disable=protected-access + checkpointer = typing.cast( + ocp.Checkpointer, manager._checkpointers[_STATE_KEY] # pylint: disable=protected-access + ) + handler = typing.cast( + ocp.PyTreeCheckpointHandler, checkpointer._handler # pylint: disable=protected-access + ) + structure = handler._read_aggregate_file( # pylint: disable=protected-access + state_subdir + ) # Note: Ideally we would use Orbax's `transform_fn` to do this logic, but # the problem is we need to modify `restore_args`, and there isn't a great # way to do that within Orbax. @@ -2021,26 +2029,55 @@ def _construct_orbax_restoration_transforms( ) def _transform_fn( - item: PyTree, structure: PyTree, param_infos: PyTree + item_: PyTree, structure_: PyTree, param_infos_: PyTree ) -> Tuple[PyTree, PyTree]: # When this function is called from within PyTreeCheckpointHandler, # transforms will already have been performed (see above), but use this # function to hack param_infos to return the needed values. # This structure is unneeded, because we already restored and transformed # it. - del structure - del param_infos - # Construct param_infos from item because item is the transformed - # structure. - # pylint: disable=protected-access - param_infos, _ = ocp.pytree_checkpoint_handler._get_restore_parameters( - manager._get_save_directory(step, directory, key_name=_STATE_KEY), + del structure_, param_infos_ + + def _make_orbax_internal_metadata(value: Any, args: ocp.RestoreArgs): + if ocp.utils.leaf_is_placeholder(value): + if isinstance(args, ocp.ArrayRestoreArgs): + restore_type = 'jax.Array' + else: + restore_type = 'np.ndarray' + return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access + restore_type=restore_type + ) + else: + return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access + restore_type=None, + skip_deserialize=True, + aggregate_value=value, + ) + + directory_ = manager._get_save_directory( # pylint: disable=protected-access + step, directory, key_name=_STATE_KEY + ) + + def _modify_orbax_param_info(info, value): + if ocp.utils.leaf_is_placeholder(value): + name = ocp.utils.name_from_leaf_placeholder(value) + return dataclasses.replace(info, path=directory_ / name) + return info + + item_ = jax.tree_util.tree_map( + _make_orbax_internal_metadata, item_, restore_args + ) + param_infos_, _ = ocp.pytree_checkpoint_handler._get_restore_parameters( # pylint: disable=protected-access + directory_, None, - item, + item_, None, None, ) - return item, param_infos + param_infos_ = jax.tree_util.tree_map( + _modify_orbax_param_info, param_infos_, state_dict_to_restore + ) + return item_, param_infos_ return state_dict_to_restore, restore_args, _transform_fn @@ -2233,7 +2270,7 @@ def save( step, items, save_kwargs=save_kwargs, force=force ) - # Record JAX montioring events. + # Record JAX monitoring events. end_time = time.time() monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time