diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index b39b874fb..315604c87 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2504,15 +2504,19 @@ def restore( # After restoration, some values may still be non-sharded arrays from # fallback state. def _maybe_make_sharded_array_helper(arr, info): - return _maybe_make_sharded_array( - arr, - self._partitioner.mesh, - axes=info.mesh_axes, - restore_dtype=self._restore_dtype, - ) + if arr is not None: + return _maybe_make_sharded_array( + arr, + self._partitioner.mesh, + axes=info.mesh_axes, + restore_dtype=self._restore_dtype, + ) state_dict = jax.tree_util.tree_map( - _maybe_make_sharded_array_helper, state_dict, param_infos + _maybe_make_sharded_array_helper, + state_dict, + param_infos, + is_leaf=lambda x: x is None, ) train_state = self._train_state.restore_state(state_dict)