Skip to content

Commit

Permalink
Save type information to allow for fully self-describing checkpoints…
Browse files Browse the repository at this point in the history
…. Metadata is stored in JSON format to maximize cross-platform compatibility and ease of use. Allow users to access metadata via API (previously only possible with TensorStore). Note that the feature is not yet turned on by default.

PiperOrigin-RevId: 537384891
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Jul 31, 2023
1 parent 104b658 commit 13ded7a
Showing 1 changed file with 51 additions and 14 deletions.
65 changes: 51 additions & 14 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import re
import subprocess
import time
import typing
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union

from absl import logging
Expand Down Expand Up @@ -1927,7 +1928,7 @@ def restore(self,

def structure(self, directory: epath.Path) -> Any:
"""Unimplemented. See parent class."""
return NotImplementedError
pass


def _step_from_train_state(train_state: train_state_lib.TrainState) -> int:
Expand Down Expand Up @@ -1960,7 +1961,6 @@ def _construct_restore_args(
if param_info.mesh_axes is None:
return ocp.RestoreArgs(dtype=dtype)
return ocp.ArrayRestoreArgs(
restore_type=jax.Array,
mesh=mesh,
mesh_axes=param_info.mesh_axes,
dtype=dtype,
Expand Down Expand Up @@ -2000,7 +2000,15 @@ def _construct_orbax_restoration_transforms(
)
assert state_subdir.is_dir()
use_orbax_format = state_subdir.stem == _STATE_KEY # Standard Orbax format
structure = manager._checkpointers[_STATE_KEY].structure(state_subdir) # pylint: disable=protected-access
checkpointer = typing.cast(
ocp.Checkpointer, manager._checkpointers[_STATE_KEY] # pylint: disable=protected-access
)
handler = typing.cast(
ocp.PyTreeCheckpointHandler, checkpointer._handler # pylint: disable=protected-access
)
structure = handler._read_aggregate_file( # pylint: disable=protected-access
state_subdir
)
# Note: Ideally we would use Orbax's `transform_fn` to do this logic, but
# the problem is we need to modify `restore_args`, and there isn't a great
# way to do that within Orbax.
Expand All @@ -2021,26 +2029,55 @@ def _construct_orbax_restoration_transforms(
)

def _transform_fn(
item: PyTree, structure: PyTree, param_infos: PyTree
item_: PyTree, structure_: PyTree, param_infos_: PyTree
) -> Tuple[PyTree, PyTree]:
# When this function is called from within PyTreeCheckpointHandler,
# transforms will already have been performed (see above), but use this
# function to hack param_infos to return the needed values.
# This structure is unneeded, because we already restored and transformed
# it.
del structure
del param_infos
# Construct param_infos from item because item is the transformed
# structure.
# pylint: disable=protected-access
param_infos, _ = ocp.pytree_checkpoint_handler._get_restore_parameters(
manager._get_save_directory(step, directory, key_name=_STATE_KEY),
del structure_, param_infos_

def _make_orbax_internal_metadata(value: Any, args: ocp.RestoreArgs):
if ocp.utils.leaf_is_placeholder(value):
if isinstance(args, ocp.ArrayRestoreArgs):
restore_type = 'jax.Array'
else:
restore_type = 'np.ndarray'
return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access
restore_type=restore_type
)
else:
return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access
restore_type=None,
skip_deserialize=True,
aggregate_value=value,
)

directory_ = manager._get_save_directory( # pylint: disable=protected-access
step, directory, key_name=_STATE_KEY
)

def _modify_orbax_param_info(info, value):
if ocp.utils.leaf_is_placeholder(value):
name = ocp.utils.name_from_leaf_placeholder(value)
return dataclasses.replace(info, path=directory_ / name)
return info

item_ = jax.tree_util.tree_map(
_make_orbax_internal_metadata, item_, restore_args
)
param_infos_, _ = ocp.pytree_checkpoint_handler._get_restore_parameters( # pylint: disable=protected-access
directory_,
None,
item,
item_,
None,
None,
)
return item, param_infos
param_infos_ = jax.tree_util.tree_map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
)
return item_, param_infos_

return state_dict_to_restore, restore_args, _transform_fn

Expand Down Expand Up @@ -2233,7 +2270,7 @@ def save(
step, items, save_kwargs=save_kwargs, force=force
)

# Record JAX montioring events.
# Record JAX monitoring events.
end_time = time.time()
monitoring.record_event_duration_secs(
_WRITE_CHECKPOINT_EVENT, end_time - start_time
Expand Down

0 comments on commit 13ded7a

Please sign in to comment.