Skip to content

Commit

Permalink
Fix partial checkpoint loading caused by changes in jax's flatten_up_to behaviour.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 683097942
  • Loading branch information
T5X Team authored and t5-copybara committed Oct 7, 2024
1 parent 705247b commit 79d58f1
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,15 +2504,19 @@ def restore(
  # After restoration, some values may still be non-sharded arrays from
  # fallback state.
  def _maybe_make_sharded_array_helper(arr, info):
-   return _maybe_make_sharded_array(
-       arr,
-       self._partitioner.mesh,
-       axes=info.mesh_axes,
-       restore_dtype=self._restore_dtype,
-   )
+   if arr is not None:
+     return _maybe_make_sharded_array(
+         arr,
+         self._partitioner.mesh,
+         axes=info.mesh_axes,
+         restore_dtype=self._restore_dtype,
+     )

  state_dict = jax.tree_util.tree_map(
-     _maybe_make_sharded_array_helper, state_dict, param_infos
+     _maybe_make_sharded_array_helper,
+     state_dict,
+     param_infos,
+     is_leaf=lambda x: x is None,
  )

  train_state = self._train_state.restore_state(state_dict)
Expand Down

0 comments on commit 79d58f1

Please sign in to comment.