Skip to content

Commit

Permalink
Don't autoset per-device batch size
Browse files: browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Oct 6, 2024
1 parent 9825134 commit b5aa661
Showing 1 changed file with 0 additions and 10 deletions.
10 changes: 0 additions & 10 deletions diffusion/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,11 +105,6 @@ def train(config: DictConfig) -> None:

# Load train dataset. Need to ensure that the per-device batch size is added as a streaming kwarg
per_device_train_batch_size = config.dataset.train_batch_size // dist.get_world_size()
if 'streaming_kwargs' in config.dataset.train_dataset:
if 'batch_size' not in config.dataset.train_dataset.streaming_kwargs:
config.dataset.train_dataset.streaming_kwargs['batch_size'] = per_device_train_batch_size
else:
config.dataset.train_dataset.streaming_kwargs = {'batch_size': per_device_train_batch_size}
if tokenizer:
train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate(
config.dataset.train_dataset,
Expand Down Expand Up @@ -154,11 +149,6 @@ def train(config: DictConfig) -> None:
else:
# Need to ensure that the per-device batch size is added as a streaming kwarg
per_device_eval_batch_size = config.dataset.eval_batch_size // dist.get_world_size()
if 'streaming_kwargs' in config.dataset.eval_dataset:
if 'batch_size' not in config.dataset.eval_dataset.streaming_kwargs:
config.dataset.eval_dataset.streaming_kwargs['batch_size'] = per_device_eval_batch_size
else:
config.dataset.eval_dataset.streaming_kwargs = {'batch_size': per_device_eval_batch_size}
if tokenizer:
eval_set = hydra.utils.instantiate(config.dataset.eval_dataset,
tokenizer=model.tokenizer,
Expand Down

0 comments on commit b5aa661

Please sign in to comment.