From 829465cbcf96914de7f27062f20fd55d004c8684 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Mon, 7 Oct 2024 16:38:27 +0000
Subject: [PATCH] Pass batch size to dataset class

---
 diffusion/datasets/image_caption_latents.py |  1 +
 diffusion/train.py                          | 16 ++++++++--------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py
index 01ae1e8d..8ca46c14 100644
--- a/diffusion/datasets/image_caption_latents.py
+++ b/diffusion/datasets/image_caption_latents.py
@@ -273,6 +273,7 @@ def build_streaming_image_caption_latents_dataloader(
         text_latent_shapes=text_latent_shapes,
         attention_mask_keys=attention_mask_keys,
         latent_dtype=dtype,
+        batch_size=batch_size,
         **streaming_kwargs,
     )
 
diff --git a/diffusion/train.py b/diffusion/train.py
index 10f73c1e..becff0f1 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -103,18 +103,19 @@ def train(config: DictConfig) -> None:
     else:
         optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters())
 
-    # Load train dataset. Need to ensure that the per-device batch size is added as a streaming kwarg
-    per_device_train_batch_size = config.dataset.train_batch_size // dist.get_world_size()
+    # Load train dataset. Currently this expects to load according to the datasetHparam method.
+    # This means adding external datasets is currently not super easy. Will refactor or check for
+    # upstream composer changes that could make this easier.
     if tokenizer:
         train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate(
             config.dataset.train_dataset,
             tokenizer=tokenizer,
-            batch_size=per_device_train_batch_size,
+            batch_size=config.dataset.train_batch_size // dist.get_world_size(),
         )
     else:
         train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate(
             config.dataset.train_dataset,
-            batch_size=per_device_train_batch_size,
+            batch_size=config.dataset.train_batch_size // dist.get_world_size(),
         )
     # Need to sleep for a bit to avoid dataloader crash
     time.sleep(10)
@@ -147,14 +148,13 @@ def train(config: DictConfig) -> None:
 
         eval_set = evaluators
     else:
-        # Need to ensure that the per-device batch size is added as a streaming kwarg
-        per_device_eval_batch_size = config.dataset.eval_batch_size // dist.get_world_size()
         if tokenizer:
             eval_set = hydra.utils.instantiate(config.dataset.eval_dataset,
                                                tokenizer=model.tokenizer,
-                                               batch_size=per_device_eval_batch_size)
+                                               batch_size=config.dataset.eval_batch_size // dist.get_world_size())
         else:
-            eval_set = hydra.utils.instantiate(config.dataset.eval_dataset, batch_size=per_device_eval_batch_size)
+            eval_set = hydra.utils.instantiate(config.dataset.eval_dataset,
+                                               batch_size=config.dataset.eval_batch_size // dist.get_world_size())
     # Need to sleep for a bit to avoid dataloader crash
     time.sleep(10)
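
Note for context: a minimal sketch of how the newly threaded `batch_size` kwarg would be consumed downstream. This assumes `StreamingImageCaptionLatentsDataset` forwards it to the `streaming` library's `StreamingDataset`, whose optional `batch_size` parameter is the per-device (not global) batch size; the remote path and batch-size values below are hypothetical.

    # Hypothetical values; StreamingDataset's batch_size is per-device, not global.
    from streaming import StreamingDataset

    global_batch_size = 2048                        # e.g. config.dataset.train_batch_size (assumed)
    world_size = 8                                  # e.g. dist.get_world_size() (assumed)
    per_device_batch_size = global_batch_size // world_size  # 256

    # Passing the per-device batch size lets the streaming dataset partition
    # shards so each rank draws complete batches, which matters for
    # deterministic sharding and resumption.
    dataset = StreamingDataset(
        remote='s3://my-bucket/latents',    # hypothetical remote location
        local='/tmp/latents-cache',         # hypothetical local cache dir
        batch_size=per_device_batch_size,
    )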