Skip to content

Commit

Permalink
De-duplicate get_ts_context usages and move to ts_utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686540442
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Oct 16, 2024
1 parent 0ff8254 commit 133590b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def get_restore_parameters(
restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()
ts_context = ocp.serialization.ts_utils.get_ts_context(
use_ocdbt=is_ocdbt_checkpoint
)

def _get_param_info(
name: str,
Expand Down

0 comments on commit 133590b

Please sign in to comment.