diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 6c12775bfc..aa210b3b37 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
 import contextlib
 import math
 import warnings
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from composer.utils import dist
 from omegaconf import DictConfig
@@ -13,6 +13,26 @@
 from llmfoundry.models.utils import init_empty_weights
 
 
+def pop_config(cfg: DictConfig,
+               key: str,
+               must_exist: bool = True,
+               default_value: Any = None) -> Any:
+    """Pop a value from the main config file and return it.
+
+    If the key does not exist, return the default_value or raise a NameError
+    depending on the must_exist flag.
+    """
+    value = cfg.pop(key, None)
+    if value is not None:
+        return value
+    elif must_exist:
+        raise NameError(
+            f'The {key} parameter is missing and must exist for execution. Please check your yaml.'
+        )
+    else:
+        return default_value
+
+
 def calculate_batch_size_info(global_batch_size: int,
                               device_microbatch_size: Union[int, str]):
     if global_batch_size % dist.get_world_size() != 0:
@@ -90,6 +110,11 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
 
 
 def log_config(cfg: DictConfig):
+    """Logs the current config and updates the wandb and mlflow configs.
+
+    This function can be called multiple times to update the wandb and MLflow
+    config with different variables.
+    """
     print(om.to_yaml(cfg))
     if 'wandb' in cfg.get('loggers', {}):
         try:
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 7cceed46d3..2953d080fc 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -1,15 +1,16 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
 import os
 import sys
 import warnings
-from typing import Dict
+from typing import Dict, List, Optional, Union
 
 import torch
 from composer import Trainer
 from composer.core import Evaluator
 from composer.utils import dist, get_device, reproducibility
-from omegaconf import DictConfig
+from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
@@ -21,7 +22,8 @@
                                build_icl_evaluators, build_logger,
                                build_optimizer, build_scheduler,
                                build_tokenizer)
-from llmfoundry.utils.config_utils import (log_config, process_init_device,
+from llmfoundry.utils.config_utils import (log_config, pop_config,
+                                           process_init_device,
                                            update_batch_size_info)
 
 
@@ -119,7 +121,7 @@ def build_composer_peft_model(
 
     print('Building model from HuggingFace checkpoint...')
     model = MPTForCausalLM.from_pretrained(
-        cfg.model.pretrained_model_name_or_path, trust_remote_code=True)
+        model_cfg.pretrained_model_name_or_path, trust_remote_code=True)
     print('Model built!')
 
     print('Adding Lora modules...')
@@ -164,177 +166,333 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
             tokenizer,
             device_batch_size,
         )
-
     else:
         raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
 
 
 def main(cfg: DictConfig):
-    # Check for incompatibilities between the model and data loaders
-    validate_config(cfg)
-
-    max_split_size_mb = cfg.get('max_split_size_mb', None)
-    if max_split_size_mb is not None:
-        os.environ[
-            'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
-
     # Filter deprecation warning from torch internal usage
     warnings.filterwarnings(
         action='ignore',
         category=UserWarning,
         message=
-        f'torch.distributed.*_base is a private function and will be deprecated.*'
+        'torch.distributed.*_base is a private function and will be deprecated.*'
     )
 
-    cfg.dist_timeout = cfg.get('dist_timeout', 600.0)
+    # Check for incompatibilities between the model and data loaders
+    validate_config(cfg)
+
+    # Resolve all interpolation variables as early as possible
+    om.resolve(cfg)
+
+    # Create copy of config for logging
+    logged_cfg: DictConfig = copy.deepcopy(cfg)
+
+    # Get max split size mb
+    max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None)
+    if max_split_size_mb is not None:
+        os.environ[
+            'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'
 
-    reproducibility.seed_all(cfg.seed)
-    dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout)
+    # Set seed first
+    seed: int = pop_config(cfg, 'seed', must_exist=True)
+    reproducibility.seed_all(seed)
 
-    # Run Name
-    if cfg.get('run_name') is None:
-        cfg.run_name = os.environ.get('RUN_NAME', 'llm')
+    # Initialize pytorch distributed training process groups
+    dist_timeout: Union[int, float] = pop_config(cfg,
+                                                 'dist_timeout',
+                                                 must_exist=False,
+                                                 default_value=600.0)
+    dist.initialize_dist(get_device(None), timeout=dist_timeout)
 
-    # Get batch size info
+    # Get global and device batch size information from distributed/single node setting
     cfg = update_batch_size_info(cfg)
+    logged_cfg.update(cfg, merge=True)
+
+    # Mandatory model training configs
+    model_config: DictConfig = pop_config(cfg, 'model', must_exist=True)
+    tokenizer_config: DictConfig = pop_config(cfg, 'tokenizer', must_exist=True)
+    optimizer_config: DictConfig = pop_config(cfg, 'optimizer', must_exist=True)
+    scheduler_config: DictConfig = pop_config(cfg, 'scheduler', must_exist=True)
+    train_loader_config: DictConfig = pop_config(cfg,
+                                                 'train_loader',
+                                                 must_exist=True)
+
+    # Optional fsdp data, fine-tuning, and eval configs
+    fsdp_dict_config: Optional[DictConfig] = pop_config(cfg,
+                                                        'fsdp_config',
+                                                        must_exist=False,
+                                                        default_value=None)
+    fsdp_config: Optional[Dict] = om.to_container(
+        fsdp_dict_config
+    ) if fsdp_dict_config is not None else None  # type: ignore
+    lora_config: Optional[DictConfig] = pop_config(cfg,
+                                                   'lora',
+                                                   must_exist=False,
+                                                   default_value=None)
+    eval_loader_config: Optional[DictConfig] = pop_config(cfg,
+                                                          'eval_loader',
+                                                          must_exist=False,
+                                                          default_value=None)
+    icl_tasks_config: Optional[ListConfig] = pop_config(cfg,
+                                                        'icl_tasks',
+                                                        must_exist=False,
+                                                        default_value=None)
+
+    # Optional logging, evaluation and callback configs
+    logger_configs: Optional[DictConfig] = pop_config(cfg,
+                                                      'loggers',
+                                                      must_exist=False,
+                                                      default_value=None)
+    callback_configs: Optional[DictConfig] = pop_config(cfg,
+                                                        'callbacks',
+                                                        must_exist=False,
+                                                        default_value=None)
+    algorithm_configs: Optional[DictConfig] = pop_config(cfg,
+                                                         'algorithms',
+                                                         must_exist=False,
+                                                         default_value=None)
+
+    # Mandatory hyperparameters for training
+    device_train_batch_size: int = pop_config(cfg,
+                                              'device_train_batch_size',
+                                              must_exist=True)
+    device_eval_batch_size: int = pop_config(cfg,
+                                             'device_eval_batch_size',
+                                             must_exist=True)
+    max_duration: Union[int, str] = pop_config(cfg,
+                                               'max_duration',
+                                               must_exist=True)
+    eval_interval: Union[int, str] = pop_config(cfg,
+                                                'eval_interval',
+                                                must_exist=True)
+    precision: str = pop_config(cfg, 'precision', must_exist=True)
+    max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True)
+
+    # Optional parameters will be set to default values if not specified.
+    default_run_name: str = os.environ.get('RUN_NAME', 'llm')
+    run_name: str = pop_config(cfg,
+                               'run_name',
+                               must_exist=False,
+                               default_value=default_run_name)
+    save_folder: Optional[str] = pop_config(cfg,
+                                            'save_folder',
+                                            must_exist=False,
+                                            default_value=None)
+    save_latest_filename: str = pop_config(cfg,
+                                           'save_latest_filename',
+                                           must_exist=False,
+                                           default_value='latest-rank{rank}.pt')
+    save_overwrite: bool = pop_config(cfg,
+                                      'save_overwrite',
+                                      must_exist=False,
+                                      default_value=False)
+    save_weights_only: bool = pop_config(cfg,
+                                         'save_weights_only',
+                                         must_exist=False,
+                                         default_value=False)
+    save_filename: str = pop_config(
+        cfg,
+        'save_filename',
+        must_exist=False,
+        default_value='ep{epoch}-ba{batch}-rank{rank}.pt')
+    save_interval: Union[str, int] = pop_config(cfg,
+                                                'save_interval',
+                                                must_exist=False,
+                                                default_value='1000ba')
+    save_num_checkpoints_to_keep: int = pop_config(
+        cfg, 'save_num_checkpoints_to_keep', must_exist=False, default_value=-1)
+    progress_bar = pop_config(cfg,
+                              'progress_bar',
+                              must_exist=False,
+                              default_value=False)
+    log_to_console: bool = pop_config(cfg,
+                                      'log_to_console',
+                                      must_exist=False,
+                                      default_value=True)
+    python_log_level: str = pop_config(cfg,
+                                       'python_log_level',
+                                       must_exist=False,
+                                       default_value='debug')
+    console_log_interval: Union[int, str] = pop_config(cfg,
+                                                       'console_log_interval',
+                                                       must_exist=False,
+                                                       default_value='1ba')
+    device_train_microbatch_size: Union[str, int] = pop_config(
+        cfg,
+        'device_train_microbatch_size',
+        must_exist=False,
+        default_value='auto')
+    eval_subset_num_batches: int = pop_config(cfg,
+                                              'eval_subset_num_batches',
+                                              must_exist=False,
+                                              default_value=-1)
+    eval_first: bool = pop_config(cfg,
+                                  'eval_first',
+                                  must_exist=False,
+                                  default_value=False)
+    load_path: str = pop_config(cfg,
+                                'load_path',
+                                must_exist=False,
+                                default_value=None)
+    load_weights_only: bool = pop_config(cfg,
+                                         'load_weights_only',
+                                         must_exist=False,
+                                         default_value=False)
+    load_ignore_keys: Optional[List[str]] = pop_config(cfg,
+                                                       'load_ignore_keys',
+                                                       must_exist=False,
+                                                       default_value=None)
+    # Enable autoresume from model checkpoints if possible
+    autoresume_default: bool = False
+    if logged_cfg.get('run_name', None) is not None \
+        and save_folder is not None \
+        and not save_overwrite \
+        and not save_weights_only:
+        print('As run_name, save_folder, and save_latest_filename are set, \
+            changing autoresume default to True...')
+        autoresume_default = True
+    autoresume: bool = pop_config(cfg,
+                                  'autoresume',
+                                  must_exist=False,
+                                  default_value=autoresume_default)
+
+    # Pop known unused parameters that are used as interpolation variables or
+    # created by update_batch_size_info.
+    pop_config(cfg, 'data_local', must_exist=False)
+    pop_config(cfg, 'data_remote', must_exist=False)
+    pop_config(cfg, 'global_seed', must_exist=False)
+    pop_config(cfg, 'global_train_batch_size', must_exist=False)
+    pop_config(cfg, 'n_gpus', must_exist=False)
+    pop_config(cfg, 'device_train_grad_accum', must_exist=False)
+
+    # Warn users for unused parameters
+    for key in cfg:
+        warnings.warn(
+            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.'
+ ) - # Read FSDP Config as a dict - fsdp_config = cfg.get('fsdp_config', None) - fsdp_config = om.to_container(fsdp_config, - resolve=True) if fsdp_config else None - assert isinstance(fsdp_config, Dict) or fsdp_config is None + # Warn if fsdp is enabled but user only has 1 GPU if dist.get_world_size() == 1 and fsdp_config is not None: warnings.warn( 'FSDP is not applicable for single-GPU training. Reverting to DDP.') - cfg.pop('fsdp_config') fsdp_config = None - init_context = process_init_device(cfg.model, fsdp_config) + # Initialize context + init_context = process_init_device(model_config, fsdp_config) + logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) - # build tokenizer - tokenizer = build_tokenizer(cfg.tokenizer) + # Build tokenizer + tokenizer = build_tokenizer(tokenizer_config) # Build Model print('Initializing model...') with init_context: - if cfg.get('lora', - None) is not None: # frozen model + trainable lora modules + if lora_config is not None: # frozen model + trainable lora modules model: ComposerHFCausalLM = build_composer_peft_model( - cfg.model, cfg.lora, tokenizer) + model_config, lora_config, tokenizer) print_trainable_parameters(model) # should not be 100% else: # standard model - model = build_composer_model(cfg.model, tokenizer) - cfg.n_params = sum(p.numel() for p in model.parameters()) - print(f'{cfg.n_params=:.2e}') + model = build_composer_model(model_config, tokenizer) + + # Log number of parameters + n_params = sum(p.numel() for p in model.parameters()) + logged_cfg.update({'n_params': n_params}) # Dataloaders print('Building train loader...') train_loader = build_dataloader( - cfg.train_loader, + train_loader_config, tokenizer, - cfg.device_train_batch_size, + device_train_batch_size, ) + + ## Evaluation print('Building eval loader...') evaluators = [] - if 'eval_loader' in cfg: + if eval_loader_config is not None: assert model.train_metrics is not None + eval_dataloader = build_dataloader(eval_loader_config, tokenizer, + device_eval_batch_size) + eval_metric_names = list(model.train_metrics.keys()) eval_loader = Evaluator(label='eval', - dataloader=build_dataloader( - cfg.eval_loader, tokenizer, - cfg.device_eval_batch_size), - metric_names=list(model.train_metrics.keys())) + dataloader=eval_dataloader, + metric_names=eval_metric_names) evaluators.append(eval_loader) - if 'icl_tasks' in cfg: - icl_evaluators, _ = build_icl_evaluators(cfg.icl_tasks, tokenizer, - cfg.max_seq_len, - cfg.device_eval_batch_size) + if icl_tasks_config is not None: + icl_evaluators, _ = build_icl_evaluators(icl_tasks_config, tokenizer, + max_seq_len, + device_eval_batch_size) evaluators.extend(icl_evaluators) # Optimizer - optimizer = build_optimizer(cfg.optimizer, model) + optimizer = build_optimizer(optimizer_config, model) # Scheduler - scheduler = build_scheduler(cfg.scheduler) + scheduler = build_scheduler(scheduler_config) # Loggers loggers = [ - build_logger(name, logger_cfg) - for name, logger_cfg in (cfg.get('loggers') or {}).items() - ] + build_logger(str(name), logger_cfg) + for name, logger_cfg in logger_configs.items() + ] if logger_configs else None # Callbacks callbacks = [ - build_callback(name, callback_cfg) - for name, callback_cfg in (cfg.get('callbacks') or {}).items() - ] + build_callback(str(name), callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else None # Algorithms algorithms = [ - build_algorithm(name, algorithm_cfg) - for name, algorithm_cfg in (cfg.get('algorithms') or {}).items() - ] - - # Set 
autoresume default on if possible - save_folder = cfg.get('save_folder', None) - save_latest_filename = cfg.get('save_latest_filename', - 'latest-rank{rank}.pt') - save_overwrite = cfg.get('save_overwrite', False) - save_weights_only = cfg.get('save_weights_only', False) - autoresume_default = False - if cfg.run_name is not None and save_folder is not None and save_latest_filename is not None and not save_overwrite and not save_weights_only: - print( - 'As run_name, save_folder, and save_latest_filename are set, changing autoresume default to True...' - ) - autoresume_default = True + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] if algorithm_configs else None # Build the Trainer print('Building trainer...') trainer = Trainer( - run_name=cfg.run_name, - seed=cfg.seed, + run_name=run_name, + seed=seed, model=model, train_dataloader=train_loader, eval_dataloader=evaluators, optimizers=optimizer, schedulers=scheduler, - max_duration=cfg.max_duration, - eval_interval=cfg.eval_interval, - eval_subset_num_batches=cfg.get('eval_subset_num_batches', -1), - progress_bar=cfg.get('progress_bar', False), - log_to_console=cfg.get('log_to_console', True), - console_log_interval=cfg.get('console_log_interval', '1ba'), + max_duration=max_duration, + eval_interval=eval_interval, + eval_subset_num_batches=eval_subset_num_batches, + progress_bar=progress_bar, + log_to_console=log_to_console, + console_log_interval=console_log_interval, loggers=loggers, callbacks=callbacks, - precision=cfg.precision, + precision=precision, algorithms=algorithms, - device_train_microbatch_size=cfg.get('device_train_microbatch_size', - 'auto'), + device_train_microbatch_size=device_train_microbatch_size, fsdp_config=fsdp_config, # type: ignore save_folder=save_folder, - save_filename=cfg.get('save_filename', - 'ep{epoch}-ba{batch}-rank{rank}.pt'), + save_filename=save_filename, save_latest_filename=save_latest_filename, - save_interval=cfg.get('save_interval', '1000ba'), - save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', - -1), + save_interval=save_interval, + save_num_checkpoints_to_keep=save_num_checkpoints_to_keep, save_overwrite=save_overwrite, save_weights_only=save_weights_only, - load_path=cfg.get('load_path', None), - load_weights_only=cfg.get('load_weights_only', False), - load_ignore_keys=cfg.get('load_ignore_keys', None), - autoresume=cfg.get('autoresume', autoresume_default), - python_log_level=cfg.get('python_log_level', 'debug'), - dist_timeout=cfg.dist_timeout, + load_path=load_path, + load_weights_only=load_weights_only, + load_ignore_keys=load_ignore_keys, + autoresume=autoresume, + python_log_level=python_log_level, + dist_timeout=dist_timeout, ) + print('Logging config') + log_config(logged_cfg) torch.cuda.empty_cache() - print('Logging config...') - log_config(cfg) - - if cfg.get('eval_first', - False) and trainer.state.timestamp.batch.value == 0: + # Eval first if requested + if eval_first and trainer.state.timestamp.batch.value == 0: trainer.eval() print('Starting training...') diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py new file mode 100644 index 0000000000..857a2b8c57 --- /dev/null +++ b/tests/test_train_inputs.py @@ -0,0 +1,89 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import copy +import os +import sys +import warnings + +import omegaconf +import pytest +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +# Add repo root to path so we 
+repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(repo_dir)
+
+from scripts.train.train import main  # noqa: E402
+
+
+class TestTrainingYAMLInputs:
+    """Validates and tests error handling for the input YAML file."""
+
+    @pytest.fixture
+    def cfg(self) -> DictConfig:
+        """Create YAML cfg fixture for testing purposes."""
+        conf_path: str = os.path.join(
+            repo_dir, 'scripts/train/yamls/pretrain/testing.yaml')
+        with open(conf_path, 'r', encoding='utf-8') as config:
+            test_cfg = om.load(config)
+        assert isinstance(test_cfg, DictConfig)
+        return test_cfg
+
+    def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
+        """Check that misspelled mandatory inputs fail to train."""
+        cfg.trai_loader = cfg.pop('train_loader')
+        with pytest.raises(omegaconf.errors.ConfigAttributeError):
+            main(cfg)
+
+    def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
+        """Check that missing mandatory parameters fail to train."""
+        mandatory_params = [
+            'train_loader',
+            'model',
+            'tokenizer',
+            'optimizer',
+            'scheduler',
+            'max_duration',
+            'eval_interval',
+            'precision',
+            'max_seq_len',
+        ]
+        for param in mandatory_params:
+            orig_param = cfg.pop(param)
+            with pytest.raises(
+                (omegaconf.errors.ConfigAttributeError, NameError)):
+                main(cfg)
+            cfg[param] = orig_param
+
+    def test_optional_mispelled_params_raise_warning(self,
+                                                     cfg: DictConfig) -> None:
+        """Check that warnings are raised for misspelled optional parameters."""
+        optional_params = [
+            'save_weights_only',
+            'save_filename',
+            'run_name',
+            'progress_bar',
+            'python_log_level',
+            'eval_first',
+            'autoresume',
+            'save_folder',
+            'fsdp_config',
+            'lora_config',
+            'eval_loader_config',
+            'icl_tasks_config',
+        ]
+        old_cfg = copy.deepcopy(cfg)
+        for param in optional_params:
+            orig_value = cfg.pop(param, None)
+            updated_param = param + '-mispelling'
+            cfg[updated_param] = orig_value
+            with warnings.catch_warnings(record=True) as warning_list:
+                try:
+                    main(cfg)
+                except:
+                    pass
+                assert any(f'Unused parameter {updated_param} found in cfg.' in
+                           str(warning.message) for warning in warning_list)
+            # restore configs.
+            cfg = copy.deepcopy(old_cfg)
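
A minimal usage sketch of the new `pop_config` helper may help reviewers; it assumes `llmfoundry` is installed so the function is importable, and it simply mirrors the calls that `main()` makes above:

    from omegaconf import OmegaConf
    from llmfoundry.utils.config_utils import pop_config

    cfg = OmegaConf.create({'seed': 17, 'max_duration': '10ba'})

    # Mandatory key that exists: the value is returned and removed from cfg.
    seed = pop_config(cfg, 'seed', must_exist=True)  # -> 17
    # Optional key that is absent: the default is returned and nothing is raised.
    folder = pop_config(cfg, 'save_folder', must_exist=False, default_value=None)  # -> None
    # Mandatory key that is absent: a NameError points at the missing yaml field.
    model_cfg = pop_config(cfg, 'model', must_exist=True)  # raises NameError
    # Any key still left in cfg afterwards triggers the 'Unused parameter' warning in main().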