From 78fa9a99a5101b47ad1629c753d53c2ec8bd8cc0 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 15 Oct 2024 23:26:05 -0700
Subject: [PATCH] pass through

---
 llmfoundry/command_utils/train.py | 1 +
 llmfoundry/utils/config_utils.py  | 1 +
 2 files changed, 2 insertions(+)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 280ff0e0d5..5bf984b1ce 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -592,6 +592,7 @@ def train(cfg: DictConfig) -> Trainer:
         profiler=profiler,
         compile_config=compile_config,
         spin_dataloaders=train_cfg.spin_dataloaders,
+        accumulate_train_batch_on_tokens=train_cfg.accumulate_train_batch_on_tokens,
     )

     _sort_callbacks(trainer)

diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 18112c18aa..64514c528d 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -121,6 +121,7 @@ class TrainConfig:
     dist_timeout: Union[int, float] = 600.0
     fsdp_config: Optional[dict[str, Any]] = None
     tp_config: Optional[dict[str, Any]] = None
+    accumulate_train_batch_on_tokens: bool = False

     # Evaluation parameters
     eval_interval: Union[int, str] = 1
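
Not part of the patch itself: a minimal sketch of how the new flag could be set from
user code, assuming an otherwise-complete llm-foundry training config. The config file
name is a hypothetical placeholder; the only thing this patch guarantees is that the
TrainConfig field flows through train() into the Composer Trainer call shown above,
and that the default of False leaves existing configs unchanged.

    # Hedged usage sketch; 'my_train_config.yaml' is a placeholder for an existing,
    # otherwise-complete llm-foundry training config.
    from omegaconf import OmegaConf as om

    from llmfoundry.command_utils.train import train

    cfg = om.load('my_train_config.yaml')

    # New field added to TrainConfig in this patch; defaults to False.
    cfg.accumulate_train_batch_on_tokens = True

    # train() now forwards the flag to the Composer Trainer.
    trainer = train(cfg)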