Add runtime error in train.py if yaml config is improperly formatted with extraneous or missing values (#506)

## Description
This PR sanity checks our train yaml configuration files before the full training pipeline runs, so errors in the config are caught before a training run starts. If a yaml config is improperly formatted with extraneous or missing values, a `NameError` is raised.
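
A minimal sketch of the pattern described above, assuming a plain `dict` stands in for the `DictConfig` so the example has no third-party dependencies. `pop_config` mirrors the helper added in this commit. `parse_train_config`, the key names, and the default of `17` are hypothetical and only illustrate how popping every known key leaves the extraneous ones behind. The actual `train.py` changes are not shown here and may handle leftovers differently.

```python
from typing import Any, Dict


def pop_config(cfg: Dict[str, Any],
               key: str,
               must_exist: bool = True,
               default_value: Any = None) -> Any:
    """Same logic as the helper added in this commit, typed for a plain dict."""
    value = cfg.pop(key, None)
    if value is not None:
        return value
    if must_exist:
        raise NameError(
            f'The {key} parameter is missing and must exist for execution. '
            'Please check your yaml.')
    return default_value


def parse_train_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical caller: pop every known key, then reject leftovers."""
    cfg = dict(cfg)  # work on a copy so the caller's config is untouched
    parsed = {
        # Required: a missing key raises NameError.
        'max_duration': pop_config(cfg, 'max_duration', must_exist=True),
        # Optional: a missing key falls back to the (assumed) default.
        'seed': pop_config(cfg, 'seed', must_exist=False, default_value=17),
    }
    # Anything still in cfg was never consumed, so it is extraneous.
    if cfg:
        raise NameError(
            f'Unused parameters found in config: {sorted(cfg)}. '
            'Please check your yaml.')
    return parsed
```

With these assumptions, a config that omits `max_duration` or that contains an unknown key such as `max_duratoin` fails immediately with a `NameError`, before any model or dataloader is built.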

## Unit Test: 
Added unit tests that check we raise a `NameError` or warn the user when the yaml is incorrectly formatted. A warning is used instead of an error when the parameter is optional and has a default value.
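
As an illustration of the kind of checks involved, here is a sketch of tests for the `pop_config` helper alone, written with the standard-library `unittest` module and a plain `dict` in place of a `DictConfig`. Both choices are assumptions made to keep the example self-contained; the tests actually added in this PR are not shown on this page, and the warning path is not covered here because it is not part of `pop_config`.

```python
import unittest
from typing import Any, Dict


def pop_config(cfg: Dict[str, Any],
               key: str,
               must_exist: bool = True,
               default_value: Any = None) -> Any:
    """Same logic as the helper added in this commit, typed for a plain dict."""
    value = cfg.pop(key, None)
    if value is not None:
        return value
    if must_exist:
        raise NameError(
            f'The {key} parameter is missing and must exist for execution. '
            'Please check your yaml.')
    return default_value


class PopConfigTest(unittest.TestCase):

    def test_missing_required_key_raises(self):
        with self.assertRaises(NameError):
            pop_config({}, 'max_duration', must_exist=True)

    def test_missing_optional_key_returns_default(self):
        value = pop_config({}, 'seed', must_exist=False, default_value=17)
        self.assertEqual(value, 17)

    def test_present_key_is_returned_and_removed(self):
        cfg = {'seed': 1}
        self.assertEqual(pop_config(cfg, 'seed'), 1)
        self.assertNotIn('seed', cfg)

    def test_explicit_none_counts_as_missing(self):
        # A key set to null in the yaml is indistinguishable from an absent
        # key, because the helper compares the popped value against None.
        with self.assertRaises(NameError):
            pop_config({'seed': None}, 'seed', must_exist=True)
```

Saved in a test module, these can be run with `python -m unittest`. The last case documents a consequence of the `is not None` check: a required key explicitly set to null is reported as missing.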

## Integration Test:
Training runs before and after this change show the same loss curves throughout training.
<img width="333" alt="Screenshot 2023-08-15 at 10 57 55 AM" src="https://github.com/mosaicml/llm-foundry/assets/13524881/a0329a77-c54b-4d0b-961b-058e35021664">


**mpt-125m training branch master**: https://wandb.ai/mosaic-ml/chuck-runs/runs/2jq1uy86
**mpt-125m training branch chuck/add_yaml_sanity_check_train**: https://wandb.ai/mosaic-ml/chuck-runs/runs/n8wlqwec

All runs at: https://wandb.ai/mosaic-ml/chuck-runs?workspace=user-chuck-tang98 

---------

Co-authored-by: Daniel King <[email protected]>
j316chuck and dakinggg authored Aug 16, 2023
1 parent aff3eaa commit 6b98ffb
Showing 3 changed files with 371 additions and 99 deletions.
27 changes: 26 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
import contextlib
import math
import warnings
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union

from composer.utils import dist
from omegaconf import DictConfig
@@ -13,6 +13,26 @@
from llmfoundry.models.utils import init_empty_weights


def pop_config(cfg: DictConfig,
               key: str,
               must_exist: bool = True,
               default_value: Any = None) -> Any:
    """Pop a value from the main config file and return it.

    If the key does not exist, return the default_value or raise a NameError
    depending on the must_exist flag.
    """
    value = cfg.pop(key, None)
    if value is not None:
        return value
    elif must_exist:
        raise NameError(
            f'The {key} parameter is missing and must exist for execution. Please check your yaml.'
        )
    else:
        return default_value


def calculate_batch_size_info(global_batch_size: int,
                              device_microbatch_size: Union[int, str]):
    if global_batch_size % dist.get_world_size() != 0:
@@ -90,6 +110,11 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):


def log_config(cfg: DictConfig):
    """Logs the current config and updates the wandb and mlflow configs.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables.
    """
    print(om.to_yaml(cfg))
    if 'wandb' in cfg.get('loggers', {}):
        try: