From 6b98ffb88dc34d774d5cc85e7e7251aeb053ac59 Mon Sep 17 00:00:00 2001 From: Charles Tang Date: Tue, 15 Aug 2023 17:07:56 -0700 Subject: [PATCH] Add runtime error in train.py if yaml config is improperly formatted with extraneous or missing values (#506) ## Description This PR sanity checks our train YAML configuration files before we run the full training pipeline, which lets us catch errors in the config before a training run starts. If a YAML config is improperly formatted with extraneous or missing values, a `NameError` is raised. (Brief usage sketches of the new `pop_config` helper follow the diff.) ## Unit Test: Added unit tests to make sure we raise a `NameError` and/or warn the user if the YAML is incorrectly formatted. Warnings are used when the parameter has an optional default value. ## Integration Test: Before and after training runs show the same loss curves throughout training. [Screenshot of loss curves: 2023-08-15 at 10:57:55 AM] **mpt-125m training branch master**: https://wandb.ai/mosaic-ml/chuck-runs/runs/2jq1uy86 **mpt-125m training branch chuck/add_yaml_sanity_check_train**: https://wandb.ai/mosaic-ml/chuck-runs/runs/n8wlqwec All runs at: https://wandb.ai/mosaic-ml/chuck-runs?workspace=user-chuck-tang98 --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/config_utils.py | 27 ++- scripts/train/train.py | 354 ++++++++++++++++++++++--------- tests/test_train_inputs.py | 89 ++++++++ 3 files changed, 371 insertions(+), 99 deletions(-) create mode 100644 tests/test_train_inputs.py diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 6c12775bfc..aa210b3b37 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -4,7 +4,7 @@ import contextlib import math import warnings -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from composer.utils import dist from omegaconf import DictConfig @@ -13,6 +13,26 @@ from llmfoundry.models.utils import init_empty_weights +def pop_config(cfg: DictConfig, + key: str, + must_exist: bool = True, + default_value: Any = None) -> Any: + """Pop a value from the main config file and return it. + + If the key does not exist, return the default_value or raise a NameError + depending on the must_exist flag. + """ + value = cfg.pop(key, None) + if value is not None: + return value + elif must_exist: + raise NameError( + f'The {key} parameter is missing and must exist for execution. Please check your yaml.' + ) + else: + return default_value + + def calculate_batch_size_info(global_batch_size: int, device_microbatch_size: Union[int, str]): if global_batch_size % dist.get_world_size() != 0: @@ -90,6 +110,11 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): def log_config(cfg: DictConfig): + """Logs the current config and updates the wandb and mlflow configs. + + This function can be called multiple times to update the wandb and MLflow + config with different variables.
+ """ print(om.to_yaml(cfg)) if 'wandb' in cfg.get('loggers', {}): try: diff --git a/scripts/train/train.py b/scripts/train/train.py index 7cceed46d3..2953d080fc 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,15 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import os import sys import warnings -from typing import Dict +from typing import Dict, List, Optional, Union import torch from composer import Trainer from composer.core import Evaluator from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase @@ -21,7 +22,8 @@ build_icl_evaluators, build_logger, build_optimizer, build_scheduler, build_tokenizer) -from llmfoundry.utils.config_utils import (log_config, process_init_device, +from llmfoundry.utils.config_utils import (log_config, pop_config, + process_init_device, update_batch_size_info) @@ -119,7 +121,7 @@ def build_composer_peft_model( print('Building model from HuggingFace checkpoint...') model = MPTForCausalLM.from_pretrained( - cfg.model.pretrained_model_name_or_path, trust_remote_code=True) + model_cfg.pretrained_model_name_or_path, trust_remote_code=True) print('Model built!') print('Adding Lora modules...') @@ -164,177 +166,333 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, tokenizer, device_batch_size, ) - else: raise ValueError(f'Not sure how to build dataloader with config: {cfg}') def main(cfg: DictConfig): - # Check for incompatibilities between the model and data loaders - validate_config(cfg) - - max_split_size_mb = cfg.get('max_split_size_mb', None) - if max_split_size_mb is not None: - os.environ[ - 'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}' - # Filter deprecation warning from torch internal usage warnings.filterwarnings( action='ignore', category=UserWarning, message= - f'torch.distributed.*_base is a private function and will be deprecated.*' + 'torch.distributed.*_base is a private function and will be deprecated.*' ) - cfg.dist_timeout = cfg.get('dist_timeout', 600.0) + # Check for incompatibilities between the model and data loaders + validate_config(cfg) + + # Resolve all interpolation variables as early as possible + om.resolve(cfg) + + # Create copy of config for logging + logged_cfg: DictConfig = copy.deepcopy(cfg) + + # Get max split size mb + max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None) + if max_split_size_mb is not None: + os.environ[ + 'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}' - reproducibility.seed_all(cfg.seed) - dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout) + # Set seed first + seed: int = pop_config(cfg, 'seed', must_exist=True) + reproducibility.seed_all(seed) - # Run Name - if cfg.get('run_name') is None: - cfg.run_name = os.environ.get('RUN_NAME', 'llm') + # Initialize pytorch distributed training process groups + dist_timeout: Union[int, float] = pop_config(cfg, + 'dist_timeout', + must_exist=False, + default_value=600.0) + dist.initialize_dist(get_device(None), timeout=dist_timeout) - # Get batch size info + # Get global and device batch size information from distributed/single node setting cfg = update_batch_size_info(cfg) + logged_cfg.update(cfg, merge=True) + + # Mandatory model training configs + model_config: DictConfig = pop_config(cfg, 'model', must_exist=True) + 
tokenizer_config: DictConfig = pop_config(cfg, 'tokenizer', must_exist=True) + optimizer_config: DictConfig = pop_config(cfg, 'optimizer', must_exist=True) + scheduler_config: DictConfig = pop_config(cfg, 'scheduler', must_exist=True) + train_loader_config: DictConfig = pop_config(cfg, + 'train_loader', + must_exist=True) + + # Optional fsdp data, fine-tuning, and eval configs + fsdp_dict_config: Optional[DictConfig] = pop_config(cfg, + 'fsdp_config', + must_exist=False, + default_value=None) + fsdp_config: Optional[Dict] = om.to_container( + fsdp_dict_config + ) if fsdp_dict_config is not None else None # type: ignore + lora_config: Optional[DictConfig] = pop_config(cfg, + 'lora', + must_exist=False, + default_value=None) + eval_loader_config: Optional[DictConfig] = pop_config(cfg, + 'eval_loader', + must_exist=False, + default_value=None) + icl_tasks_config: Optional[ListConfig] = pop_config(cfg, + 'icl_tasks', + must_exist=False, + default_value=None) + + # Optional logging, evaluation and callback configs + logger_configs: Optional[DictConfig] = pop_config(cfg, + 'loggers', + must_exist=False, + default_value=None) + callback_configs: Optional[DictConfig] = pop_config(cfg, + 'callbacks', + must_exist=False, + default_value=None) + algorithm_configs: Optional[DictConfig] = pop_config(cfg, + 'algorithms', + must_exist=False, + default_value=None) + + # Mandatory hyperparameters for training + device_train_batch_size: int = pop_config(cfg, + 'device_train_batch_size', + must_exist=True) + device_eval_batch_size: int = pop_config(cfg, + 'device_eval_batch_size', + must_exist=True) + max_duration: Union[int, str] = pop_config(cfg, + 'max_duration', + must_exist=True) + eval_interval: Union[int, str] = pop_config(cfg, + 'eval_interval', + must_exist=True) + precision: str = pop_config(cfg, 'precision', must_exist=True) + max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) + + # Optional parameters will be set to default values if not specified. 
+ default_run_name: str = os.environ.get('RUN_NAME', 'llm') + run_name: str = pop_config(cfg, + 'run_name', + must_exist=False, + default_value=default_run_name) + save_folder: Optional[str] = pop_config(cfg, + 'save_folder', + must_exist=False, + default_value=None) + save_latest_filename: str = pop_config(cfg, + 'save_latest_filename', + must_exist=False, + default_value='latest-rank{rank}.pt') + save_overwrite: bool = pop_config(cfg, + 'save_overwrite', + must_exist=False, + default_value=False) + save_weights_only: bool = pop_config(cfg, + 'save_weights_only', + must_exist=False, + default_value=False) + save_filename: str = pop_config( + cfg, + 'save_filename', + must_exist=False, + default_value='ep{epoch}-ba{batch}-rank{rank}.pt') + save_interval: Union[str, int] = pop_config(cfg, + 'save_interval', + must_exist=False, + default_value='1000ba') + save_num_checkpoints_to_keep: int = pop_config( + cfg, 'save_num_checkpoints_to_keep', must_exist=False, default_value=-1) + progress_bar = pop_config(cfg, + 'progress_bar', + must_exist=False, + default_value=False) + log_to_console: bool = pop_config(cfg, + 'log_to_console', + must_exist=False, + default_value=True) + python_log_level: str = pop_config(cfg, + 'python_log_level', + must_exist=False, + default_value='debug') + console_log_interval: Union[int, str] = pop_config(cfg, + 'console_log_interval', + must_exist=False, + default_value='1ba') + device_train_microbatch_size: Union[str, int] = pop_config( + cfg, + 'device_train_microbatch_size', + must_exist=False, + default_value='auto') + eval_subset_num_batches: int = pop_config(cfg, + 'eval_subset_num_batches', + must_exist=False, + default_value=-1) + eval_first: bool = pop_config(cfg, + 'eval_first', + must_exist=False, + default_value=False) + load_path: str = pop_config(cfg, + 'load_path', + must_exist=False, + default_value=None) + load_weights_only: bool = pop_config(cfg, + 'load_weights_only', + must_exist=False, + default_value=False) + load_ignore_keys: Optional[List[str]] = pop_config(cfg, + 'load_ignore_keys', + must_exist=False, + default_value=None) + # Enable autoresume from model checkpoints if possible + autoresume_default: bool = False + if logged_cfg.get('run_name', None) is not None \ + and save_folder is not None \ + and not save_overwrite \ + and not save_weights_only: + print('As run_name, save_folder, and save_latest_filename are set, \ + changing autoresume default to True...') + autoresume_default = True + autoresume: bool = pop_config(cfg, + 'autoresume', + must_exist=False, + default_value=autoresume_default) + + # Pop known unused parameters that are used as interpolation variables or + # created by update_batch_size_info. + pop_config(cfg, 'data_local', must_exist=False) + pop_config(cfg, 'data_remote', must_exist=False) + pop_config(cfg, 'global_seed', must_exist=False) + pop_config(cfg, 'global_train_batch_size', must_exist=False) + pop_config(cfg, 'n_gpus', must_exist=False) + pop_config(cfg, 'device_train_grad_accum', must_exist=False) + + # Warn users for unused parameters + for key in cfg: + warnings.warn( + f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.' 
+ ) - # Read FSDP Config as a dict - fsdp_config = cfg.get('fsdp_config', None) - fsdp_config = om.to_container(fsdp_config, - resolve=True) if fsdp_config else None - assert isinstance(fsdp_config, Dict) or fsdp_config is None + # Warn if fsdp is enabled but user only has 1 GPU if dist.get_world_size() == 1 and fsdp_config is not None: warnings.warn( 'FSDP is not applicable for single-GPU training. Reverting to DDP.') - cfg.pop('fsdp_config') fsdp_config = None - init_context = process_init_device(cfg.model, fsdp_config) + # Initialize context + init_context = process_init_device(model_config, fsdp_config) + logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) - # build tokenizer - tokenizer = build_tokenizer(cfg.tokenizer) + # Build tokenizer + tokenizer = build_tokenizer(tokenizer_config) # Build Model print('Initializing model...') with init_context: - if cfg.get('lora', - None) is not None: # frozen model + trainable lora modules + if lora_config is not None: # frozen model + trainable lora modules model: ComposerHFCausalLM = build_composer_peft_model( - cfg.model, cfg.lora, tokenizer) + model_config, lora_config, tokenizer) print_trainable_parameters(model) # should not be 100% else: # standard model - model = build_composer_model(cfg.model, tokenizer) - cfg.n_params = sum(p.numel() for p in model.parameters()) - print(f'{cfg.n_params=:.2e}') + model = build_composer_model(model_config, tokenizer) + + # Log number of parameters + n_params = sum(p.numel() for p in model.parameters()) + logged_cfg.update({'n_params': n_params}) # Dataloaders print('Building train loader...') train_loader = build_dataloader( - cfg.train_loader, + train_loader_config, tokenizer, - cfg.device_train_batch_size, + device_train_batch_size, ) + + ## Evaluation print('Building eval loader...') evaluators = [] - if 'eval_loader' in cfg: + if eval_loader_config is not None: assert model.train_metrics is not None + eval_dataloader = build_dataloader(eval_loader_config, tokenizer, + device_eval_batch_size) + eval_metric_names = list(model.train_metrics.keys()) eval_loader = Evaluator(label='eval', - dataloader=build_dataloader( - cfg.eval_loader, tokenizer, - cfg.device_eval_batch_size), - metric_names=list(model.train_metrics.keys())) + dataloader=eval_dataloader, + metric_names=eval_metric_names) evaluators.append(eval_loader) - if 'icl_tasks' in cfg: - icl_evaluators, _ = build_icl_evaluators(cfg.icl_tasks, tokenizer, - cfg.max_seq_len, - cfg.device_eval_batch_size) + if icl_tasks_config is not None: + icl_evaluators, _ = build_icl_evaluators(icl_tasks_config, tokenizer, + max_seq_len, + device_eval_batch_size) evaluators.extend(icl_evaluators) # Optimizer - optimizer = build_optimizer(cfg.optimizer, model) + optimizer = build_optimizer(optimizer_config, model) # Scheduler - scheduler = build_scheduler(cfg.scheduler) + scheduler = build_scheduler(scheduler_config) # Loggers loggers = [ - build_logger(name, logger_cfg) - for name, logger_cfg in (cfg.get('loggers') or {}).items() - ] + build_logger(str(name), logger_cfg) + for name, logger_cfg in logger_configs.items() + ] if logger_configs else None # Callbacks callbacks = [ - build_callback(name, callback_cfg) - for name, callback_cfg in (cfg.get('callbacks') or {}).items() - ] + build_callback(str(name), callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else None # Algorithms algorithms = [ - build_algorithm(name, algorithm_cfg) - for name, algorithm_cfg in (cfg.get('algorithms') or {}).items() - ] - - # Set 
autoresume default on if possible - save_folder = cfg.get('save_folder', None) - save_latest_filename = cfg.get('save_latest_filename', - 'latest-rank{rank}.pt') - save_overwrite = cfg.get('save_overwrite', False) - save_weights_only = cfg.get('save_weights_only', False) - autoresume_default = False - if cfg.run_name is not None and save_folder is not None and save_latest_filename is not None and not save_overwrite and not save_weights_only: - print( - 'As run_name, save_folder, and save_latest_filename are set, changing autoresume default to True...' - ) - autoresume_default = True + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] if algorithm_configs else None # Build the Trainer print('Building trainer...') trainer = Trainer( - run_name=cfg.run_name, - seed=cfg.seed, + run_name=run_name, + seed=seed, model=model, train_dataloader=train_loader, eval_dataloader=evaluators, optimizers=optimizer, schedulers=scheduler, - max_duration=cfg.max_duration, - eval_interval=cfg.eval_interval, - eval_subset_num_batches=cfg.get('eval_subset_num_batches', -1), - progress_bar=cfg.get('progress_bar', False), - log_to_console=cfg.get('log_to_console', True), - console_log_interval=cfg.get('console_log_interval', '1ba'), + max_duration=max_duration, + eval_interval=eval_interval, + eval_subset_num_batches=eval_subset_num_batches, + progress_bar=progress_bar, + log_to_console=log_to_console, + console_log_interval=console_log_interval, loggers=loggers, callbacks=callbacks, - precision=cfg.precision, + precision=precision, algorithms=algorithms, - device_train_microbatch_size=cfg.get('device_train_microbatch_size', - 'auto'), + device_train_microbatch_size=device_train_microbatch_size, fsdp_config=fsdp_config, # type: ignore save_folder=save_folder, - save_filename=cfg.get('save_filename', - 'ep{epoch}-ba{batch}-rank{rank}.pt'), + save_filename=save_filename, save_latest_filename=save_latest_filename, - save_interval=cfg.get('save_interval', '1000ba'), - save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', - -1), + save_interval=save_interval, + save_num_checkpoints_to_keep=save_num_checkpoints_to_keep, save_overwrite=save_overwrite, save_weights_only=save_weights_only, - load_path=cfg.get('load_path', None), - load_weights_only=cfg.get('load_weights_only', False), - load_ignore_keys=cfg.get('load_ignore_keys', None), - autoresume=cfg.get('autoresume', autoresume_default), - python_log_level=cfg.get('python_log_level', 'debug'), - dist_timeout=cfg.dist_timeout, + load_path=load_path, + load_weights_only=load_weights_only, + load_ignore_keys=load_ignore_keys, + autoresume=autoresume, + python_log_level=python_log_level, + dist_timeout=dist_timeout, ) + print('Logging config') + log_config(logged_cfg) torch.cuda.empty_cache() - print('Logging config...') - log_config(cfg) - - if cfg.get('eval_first', - False) and trainer.state.timestamp.batch.value == 0: + # Eval first if requested + if eval_first and trainer.state.timestamp.batch.value == 0: trainer.eval() print('Starting training...') diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py new file mode 100644 index 0000000000..857a2b8c57 --- /dev/null +++ b/tests/test_train_inputs.py @@ -0,0 +1,89 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import copy +import os +import sys +import warnings + +import omegaconf +import pytest +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +# Add repo root to path so we 
can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) + +from scripts.train.train import main # noqa: E402 + + +class TestTrainingYAMLInputs: + """Validates and tests error handling for the input YAML file.""" + + @pytest.fixture + def cfg(self) -> DictConfig: + """Create YAML cfg fixture for testing purposes.""" + conf_path: str = os.path.join( + repo_dir, 'scripts/train/yamls/pretrain/testing.yaml') + with open(conf_path, 'r', encoding='utf-8') as config: + test_cfg = om.load(config) + assert isinstance(test_cfg, DictConfig) + return test_cfg + + def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None: + """Check that misspelled mandatory inputs fail to train.""" + cfg.trai_loader = cfg.pop('train_loader') + with pytest.raises(omegaconf.errors.ConfigAttributeError): + main(cfg) + + def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: + """Check that missing mandatory parameters fail to train.""" + mandatory_params = [ + 'train_loader', + 'model', + 'tokenizer', + 'optimizer', + 'scheduler', + 'max_duration', + 'eval_interval', + 'precision', + 'max_seq_len', + ] + for param in mandatory_params: + orig_param = cfg.pop(param) + with pytest.raises( + (omegaconf.errors.ConfigAttributeError, NameError)): + main(cfg) + cfg[param] = orig_param + + def test_optional_mispelled_params_raise_warning(self, + cfg: DictConfig) -> None: + """Check that warnings are raised for misspelled optional parameters.""" + optional_params = [ + 'save_weights_only', + 'save_filename', + 'run_name', + 'progress_bar', + 'python_log_level', + 'eval_first', + 'autoresume', + 'save_folder', + 'fsdp_config', + 'lora_config', + 'eval_loader_config', + 'icl_tasks_config', + ] + old_cfg = copy.deepcopy(cfg) + for param in optional_params: + orig_value = cfg.pop(param, None) + updated_param = param + '-mispelling' + cfg[updated_param] = orig_value + with warnings.catch_warnings(record=True) as warning_list: + try: + main(cfg) + except: + pass + assert any(f'Unused parameter {updated_param} found in cfg.' in + str(warning.message) for warning in warning_list) + # restore configs. + cfg = copy.deepcopy(old_cfg)
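
For reference, a minimal usage sketch of the new `pop_config` helper (not part of the diff; the config values below are invented for illustration, though the keys mirror ones used in `train.py`):

```python
from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import pop_config

# Toy config: 'max_duration' is deliberately missing.
cfg = om.create({'seed': 17, 'precision': 'amp_bf16'})

seed = pop_config(cfg, 'seed', must_exist=True)  # returns 17 and removes the key from cfg
run_name = pop_config(cfg, 'run_name', must_exist=False,
                      default_value='llm')  # optional key: falls back to the default
max_duration = pop_config(cfg, 'max_duration',
                          must_exist=True)  # mandatory key: raises NameError
```

Keys are popped rather than read so that anything still left in `cfg` once `main` has consumed every known parameter can be flagged as unused.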
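
And a sketch of the warning path for leftover keys, mirroring the loop at the end of config parsing in `train.py` (the misspelled key is invented for illustration):

```python
import warnings

from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import pop_config

cfg = om.create({'seed': 17, 'run_nmae': 'my-run'})  # 'run_name' misspelled on purpose

seed = pop_config(cfg, 'seed', must_exist=True)
run_name = pop_config(cfg, 'run_name', must_exist=False,
                      default_value='llm')  # misspelled key is not found, so 'llm' is used

# Whatever remains in cfg was never claimed by a known parameter, so warn the user.
for key in cfg:
    warnings.warn(
        f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.'
    )
```

This is what `test_optional_mispelled_params_raise_warning` asserts on: training can still proceed with the default value, but each unrecognized key produces a warning.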