Add runtime error in eval.py if yaml config is improperly formatted with extraneous or missing values (#521)

## Description

This PR enables us to sanity-check our eval YAML configuration files before the full eval pipeline runs, so that errors in the config are caught before an eval run starts. If a YAML config is improperly formatted, with extraneous or missing values, a runtime error is raised.

## Unit Test

Added unit tests to make sure we raise `omegaconf.errors` and/or warn users if the YAML is incorrectly formatted. A warning is used when the parameter has an optional default value.

## Integration Test

**Test 1**: Model reaches 0.56 COPA accuracy before and after the refactor (1 GPU):

```
composer eval/eval.py \
  eval/yamls/hf_eval.yaml \
  icl_tasks=eval/yamls/copa.yaml \
  model_name_or_path=mpt-125m-hf
```

**Test 2**: Model gauntlet runs: `mcli run -f mcli/mcli-hf-eval.yaml --follow`

Main branch (`mcli logs -f all-eval-main-nphii5`):

```
Ran mosaicml/mpt-7b-instruct eval in: 5852.91842341423 seconds
Printing gauntlet results for all models

| model_name               |   average |   world_knowledge |   commonsense_reasoning |   language_understanding |   symbolic_problem_solving |   reading_comprehension |
|:-------------------------|----------:|------------------:|------------------------:|-------------------------:|---------------------------:|------------------------:|
| mosaicml/mpt-7b-instruct |  0.354255 |          0.398764 |                0.415097 |                 0.371509 |                   0.171216 |                0.414691 |

Printing complete results for all models

| Category | Benchmark
```

`chuck/add_yaml_check_eval` branch (`mcli logs -f all-eval-chuck-branch-UqSCPR`):

```
Ran mosaicml/mpt-7b-instruct eval in: 5858.392446279526 seconds
Printing gauntlet results for all models

| model_name               |   average |   world_knowledge |   commonsense_reasoning |   language_understanding |   symbolic_problem_solving |   reading_comprehension |
|:-------------------------|----------:|------------------:|------------------------:|-------------------------:|---------------------------:|------------------------:|
| mosaicml/mpt-7b-instruct |  0.354255 |          0.398764 |                0.415097 |                 0.371509 |                   0.171216 |                0.414691 |

Printing complete results for all models

| Category | Benchmark
```

## Issues Addressed

This PR chips away at this tech-debt [issue](https://mosaicml.atlassian.net/browse/RESEARCH-717).
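The `eval.py` side of the diff is not rendered on this page, but the unit tests below pin down the contract: missing mandatory keys raise `omegaconf` errors, missing optional keys fall back to defaults, and anything left over in the config triggers an `Unused parameter ... found in cfg.` warning. Here is a minimal sketch of that pop/raise/warn pattern; the `check_cfg` name and the default values are illustrative assumptions, not the actual implementation:

```python
import warnings

from omegaconf import DictConfig


def check_cfg(cfg: DictConfig) -> None:
    """Hypothetical validator illustrating the pop/raise/warn pattern."""
    # Mandatory keys: DictConfig.pop without a default raises
    # omegaconf.errors.ConfigKeyError when the key is absent, and resolving
    # an interpolation such as ${max_seq_len} elsewhere in the YAML raises
    # InterpolationKeyError -- the two errors the unit tests expect.
    max_seq_len = cfg.pop('max_seq_len')
    device_eval_batch_size = cfg.pop('device_eval_batch_size')
    precision = cfg.pop('precision')

    # Optional keys: a default makes a missing key harmless.
    seed = cfg.pop('seed', 17)  # default value is illustrative
    run_name = cfg.pop('run_name', None)

    # Whatever remains in cfg is extraneous (e.g. a misspelled key), so warn
    # rather than fail; the message matches the assertion in the tests below.
    for key in cfg:
        warnings.warn(f'Unused parameter {key} found in cfg.')
```

Popping keys as they are consumed is what makes the leftover-key warning cheap: by the end of validation, anything still in `cfg` is, by construction, unused.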
Showing 2 changed files with 172 additions and 39 deletions.
New test file (74 lines added):

```python
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import os
import sys
import warnings

import omegaconf
import pytest
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.eval.eval import main  # noqa: E402


class TestEvalYAMLInputs:
    """Validate and test error handling for the input YAML file."""

    @pytest.fixture
    def cfg(self) -> DictConfig:
        """Create YAML cfg fixture for testing purposes."""
        conf_path: str = os.path.join(repo_dir,
                                      'scripts/eval/yamls/hf_eval.yaml')
        with open(conf_path, 'r', encoding='utf-8') as config:
            test_cfg = om.load(config)
        assert isinstance(test_cfg, DictConfig)
        return test_cfg

    def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
        """Check that misspelled mandatory inputs make eval fail."""
        mandatory_params = [
            'max_seq_len',
            'device_eval_batch_size',
            'precision',
            'model_configs',
        ]
        mandatory_configs = ['models', 'icl_tasks']
        for p in mandatory_params + mandatory_configs:
            with pytest.raises((omegaconf.errors.ConfigKeyError,
                                omegaconf.errors.InterpolationKeyError)):
                # Rename the key so the mandatory value appears missing.
                cfg[p + '-mispelled'] = cfg.pop(p)
                main(cfg)
            cfg[p] = cfg.pop(p + '-mispelled')

    def test_optional_mispelled_params_raise_warning(self,
                                                     cfg: DictConfig) -> None:
        """Check that warnings are raised for misspelled optional params."""
        optional_params = [
            'seed',
            'dist_timeout',
            'run_name',
            'num_retries',
            'loggers',
            'model_gauntlet',
            'fsdp_config',
        ]
        old_cfg = copy.deepcopy(cfg)
        for param in optional_params:
            orig_value = cfg.pop(param, None)
            updated_param = param + '-mispelling'
            cfg[updated_param] = orig_value
            with warnings.catch_warnings(record=True) as warning_list:
                try:
                    main(cfg)
                except:  # main may fail downstream; only the warning matters
                    pass
                assert any(f'Unused parameter {updated_param} found in cfg.' in
                           str(warning.message) for warning in warning_list)
            # Restore the original config for the next parameter.
            cfg = copy.deepcopy(old_cfg)
```
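The optional-parameter test leans on `warnings.catch_warnings(record=True)`, which collects emitted warnings into a list instead of printing them, so the test can assert on the message text even when `main` later fails for an unrelated reason. A self-contained illustration of that mechanism (the warning text mirrors the assertion above; the key name is just an example):

```python
import warnings

with warnings.catch_warnings(record=True) as warning_list:
    warnings.simplefilter('always')  # ensure no warning is filtered out
    # Stand-in for the warning the validator emits for a leftover key.
    warnings.warn('Unused parameter seed-mispelling found in cfg.')

# Each recorded entry wraps the warning; str(w.message) yields its text.
assert any('Unused parameter seed-mispelling found in cfg.' in
           str(w.message) for w in warning_list)
```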