From 707995a3a8238e0c3557d3cc1318a883215c54c9 Mon Sep 17 00:00:00 2001 From: Yaning Liang Date: Thu, 14 Mar 2024 14:37:19 -0700 Subject: [PATCH] Allow clu.data.dataset_iterator.DatasetIterator in addition to tf.data.Iterator PiperOrigin-RevId: 615907993 --- t5x/checkpoints.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index dac5d9e8f..e21afa8b7 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2015,7 +2015,11 @@ class DatasetCheckpointHandler(ocp.CheckpointHandler): def __init__(self, checkpoint_filename: str): self._checkpoint_filename = checkpoint_filename - def save(self, directory: epath.Path, item: tf.data.Iterator): + def save( + self, + directory: epath.Path, + item: Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator], + ): """Saves the given item. Args: @@ -2025,13 +2029,20 @@ def save(self, directory: epath.Path, item: tf.data.Iterator): if jax.process_count() > 1: directory /= f'process_{jax.process_index()}-of-{jax.process_count()}' directory.mkdir(parents=False, exist_ok=False) - ckpt = tf.train.Checkpoint(ds=item) - ckpt.write(os.fspath(directory / self._checkpoint_filename)) + if isinstance(item, tf.data.Iterator): + ckpt = tf.train.Checkpoint(ds=item) + ckpt.write(os.fspath(directory / self._checkpoint_filename)) + elif isinstance(item, clu.data.dataset_iterator.DatasetIterator): + item.save(os.fspath(directory / self._checkpoint_filename)) multihost_utils.sync_global_devices('DatasetCheckpointHandler:save') def restore( - self, directory: epath.Path, item: Optional[tf.data.Iterator] = None - ) -> tf.data.Iterator: + self, + directory: epath.Path, + item: Optional[ + Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] + ] = None, + ) -> Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]: """Restores the given item. Args: @@ -2045,10 +2056,13 @@ def restore( raise ValueError('Must provide item to restore') if jax.process_count() > 1: directory /= f'process_{jax.process_index()}-of-{jax.process_count()}' - ckpt = tf.train.Checkpoint(ds=item) - ckpt.read( - os.fspath(directory / self._checkpoint_filename) - ).assert_consumed() + if isinstance(item, tf.data.Iterator): + ckpt = tf.train.Checkpoint(ds=item) + ckpt.read( + os.fspath(directory / self._checkpoint_filename) + ).assert_consumed() + elif isinstance(item, clu.data.dataset_iterator.DatasetIterator): + item.load(os.fspath(directory / self._checkpoint_filename)) return item def structure(self, directory: epath.Path) -> Any: @@ -2259,7 +2273,9 @@ def __init__( directory: str, train_state: train_state_lib.TrainState, partitioner: partitioning.BasePartitioner, - dataset_iterator: Optional[tf.data.Iterator] = None, + dataset_iterator: Optional[ + Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] + ] = None, save_dtype: Optional[jnp.dtype] = None, restore_dtype: Optional[jnp.dtype] = None, keep: Optional[int] = None, @@ -2276,11 +2292,14 @@ def __init__( del keep_dataset_checkpoints self._train_state = train_state self._partitioner = partitioner + if isinstance( + dataset_iterator, clu.data.dataset_iterator.TfDatasetIterator + ): + assert dataset_iterator._checkpoint self._dataset_iterator = dataset_iterator self._save_dtype = save_dtype self._restore_dtype = restore_dtype self._tmp_directory: Optional[epath.PathLike] = None - data_layout = partitioner.get_data_layout() dataset_ckpt_name = ( f'{_TRAIN_DS_PREFIX}-'