Skip to content

Commit

Permalink
Remove external references to pytree_checkpoint_handler.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669337504
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Aug 30, 2024
1 parent 59a473c commit 8414ac6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
61 changes: 60 additions & 1 deletion t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@

import enum
import os
from typing import Any, BinaryIO, Optional
from typing import Any, BinaryIO, Optional, Tuple, Union

from absl import logging
from etils import epath
import jax
import msgpack
import orbax.checkpoint as ocp
from tensorflow.io import gfile


# PINNED file in the checkpoint directory indicates that the checkpoint should
# not be removed during the automatic pruning of old checkpoints.
_PINNED_CHECKPOINT_FILENAME = 'PINNED'
Expand Down Expand Up @@ -246,3 +249,59 @@ def detect_checkpoint_type(
'written with T5X.',
)
return checkpoint_type


def get_restore_parameters(
directory: epath.Path,
structure: PyTree,
) -> Tuple[PyTree, PyTree]:
"""Construct parameters needed for restoration.
ParamInfos are
constructed from the structure of the original checkpoint, and restore_args
are serialized to a tree structure compatible with param_infos and structure.
Args:
directory: Checkpoint directory.
structure: The structure of the original checkpoint.
Returns:
Tuple of param_infos, and restore_args.
"""
flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True)
param_names = ocp.tree.get_param_names(structure)
flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True)
restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()

def _get_param_info(
name: str,
meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry],
) -> Union[ocp.type_handlers.ParamInfo, Any]:
if ocp.type_handlers.is_supported_empty_aggregation_type(meta_or_value):
# Empty node, ParamInfo should not be returned.
return meta_or_value
elif not isinstance(meta_or_value, ocp.metadata.tree.ValueMetadataEntry):
# Aggregated value.
skip_deserialize = True
else:
skip_deserialize = meta_or_value.skip_deserialize
return ocp.type_handlers.ParamInfo(
name=name,
path=directory / name,
parent_dir=directory,
skip_deserialize=skip_deserialize,
is_ocdbt_checkpoint=is_ocdbt_checkpoint,
ts_context=ts_context,
)

for key, meta in flat_structure.items():
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
restore_args = ocp.tree.serialize_tree(restore_args, keep_empty_nodes=True)

return (
ocp.tree.from_flat_dict(flat_param_infos, target=structure),
restore_args,
)
15 changes: 3 additions & 12 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,18 +2165,9 @@ def _modify_orbax_param_info(info, value):
)
return info

item_ = jax.tree_util.tree_map(
_make_orbax_internal_metadata, item_, restore_args
)
param_infos_, _ = ocp.pytree_checkpoint_handler._get_restore_parameters( # pylint: disable=protected-access
directory_,
None,
item_,
None,
None,
None,
)
param_infos_ = jax.tree_util.tree_map(
item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args)
param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = jax.tree.map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
)
return item_, param_infos_
Expand Down

0 comments on commit 8414ac6

Please sign in to comment.