diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 75e4ac331..d429c305d 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2056,7 +2056,7 @@ def _construct_save_args( """Create SaveArgs for Orbax saving.""" if param_info.name.split('.')[0] != 'target': dtype = None - return ocp.SaveArgs(aggregate=False, dtype=dtype) + return ocp.SaveArgs(dtype=dtype) def _construct_restore_args(