diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 280ff0e0d5..cb287b029c 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -592,6 +592,8 @@ def train(cfg: DictConfig) -> Trainer: profiler=profiler, compile_config=compile_config, spin_dataloaders=train_cfg.spin_dataloaders, + accumulate_train_batch_on_tokens=train_cfg. + accumulate_train_batch_on_tokens, ) _sort_callbacks(trainer) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 18112c18aa..64514c528d 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -121,6 +121,7 @@ class TrainConfig: dist_timeout: Union[int, float] = 600.0 fsdp_config: Optional[dict[str, Any]] = None tp_config: Optional[dict[str, Any]] = None + accumulate_train_batch_on_tokens: bool = False # Evaluation parameters eval_interval: Union[int, str] = 1