From b5aa661d8eb76711a7409fc7119d650384ebb996 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sun, 6 Oct 2024 04:56:55 +0000 Subject: [PATCH] Don't autoset per-device batch size --- diffusion/train.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/diffusion/train.py b/diffusion/train.py index e0d08da0..10f73c1e 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -105,11 +105,6 @@ def train(config: DictConfig) -> None: # Load train dataset. Need to ensure that the per-device batch size is added as a streaming kwarg per_device_train_batch_size = config.dataset.train_batch_size // dist.get_world_size() - if 'streaming_kwargs' in config.dataset.train_dataset: - if 'batch_size' not in config.dataset.train_dataset.streaming_kwargs: - config.dataset.train_dataset.streaming_kwargs['batch_size'] = per_device_train_batch_size - else: - config.dataset.train_dataset.streaming_kwargs = {'batch_size': per_device_train_batch_size} if tokenizer: train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate( config.dataset.train_dataset, @@ -154,11 +149,6 @@ def train(config: DictConfig) -> None: else: # Need to ensure that the per-device batch size is added as a streaming kwarg per_device_eval_batch_size = config.dataset.eval_batch_size // dist.get_world_size() - if 'streaming_kwargs' in config.dataset.eval_dataset: - if 'batch_size' not in config.dataset.eval_dataset.streaming_kwargs: - config.dataset.eval_dataset.streaming_kwargs['batch_size'] = per_device_eval_batch_size - else: - config.dataset.eval_dataset.streaming_kwargs = {'batch_size': per_device_eval_batch_size} if tokenizer: eval_set = hydra.utils.instantiate(config.dataset.eval_dataset, tokenizer=model.tokenizer,