
Commit

Add runtime error in eval.py if yaml config is improperly formatted with extraneous or missing values (#521)

## Description
This PR sanity-checks our eval YAML configuration files before the full eval pipeline runs, so errors in the config are caught before an eval run starts. If a YAML config is improperly formatted with extraneous or missing values, a runtime error is raised.
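
At a high level, `main` now pops every expected key off the config up front and treats anything left over as a mistake. A condensed, illustrative sketch of that pattern is below (the helper name `check_eval_cfg` is hypothetical; the real logic lives inline in `main`, shown in the diff further down):

```python
import warnings

from omegaconf import DictConfig

from llmfoundry.utils.config_utils import pop_config


def check_eval_cfg(cfg: DictConfig) -> None:  # hypothetical helper, for illustration only
    # Mandatory parameters: pop_config(..., must_exist=True) errors out if the key is absent.
    max_seq_len = pop_config(cfg, 'max_seq_len', must_exist=True)
    precision = pop_config(cfg, 'precision', must_exist=True)

    # Optional parameters fall back to a default value instead of raising.
    seed = pop_config(cfg, 'seed', must_exist=False, default_value=17)

    # Anything still left in cfg is extraneous (e.g. a misspelled key), so warn about it.
    for key in cfg:
        warnings.warn(
            f'Unused parameter {key} found in cfg. Please check your yaml to '
            'ensure this parameter is necessary.')
```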

## Unit Test:
Added unit tests to make sure we raise errors from `omegaconf.errors` and/or warn users when the YAML is incorrectly formatted. A warning (rather than an error) is used when the misconfigured parameter has an optional default value.
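
For example, a minimal reproduction in the spirit of the new `tests/test_eval_inputs.py` (a sketch that assumes it runs from the repo root so `scripts` is importable):

```python
import omegaconf
import pytest
from omegaconf import OmegaConf as om

from scripts.eval.eval import main  # requires the repo root on sys.path

with open('scripts/eval/yamls/hf_eval.yaml', 'r', encoding='utf-8') as f:
    cfg = om.load(f)

# Misspell a mandatory key: the run now fails fast at config-parsing time
# instead of partway through the eval pipeline.
cfg['max_seq_len-mispelled'] = cfg.pop('max_seq_len')
with pytest.raises((omegaconf.errors.ConfigKeyError,
                    omegaconf.errors.InterpolationKeyError)):
    main(cfg)
```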

## Integration Test:
**Test 1**: Model reaches 0.56 COPA accuracy before and after the refactor (1 GPU)
```bash
composer eval/eval.py \
  eval/yamls/hf_eval.yaml \
  icl_tasks=eval/yamls/copa.yaml \
  model_name_or_path=mpt-125m-hf
```

**Test 2**: Model gauntlet runs: `mcli run -f mcli/mcli-hf-eval.yaml --follow`

Main branch: `mcli logs -f all-eval-main-nphii5`
```
Ran mosaicml/mpt-7b-instruct eval in: 5852.91842341423 seconds
Printing gauntlet results for all models
| model_name               |   average |   world_knowledge |   commonsense_reasoning |   language_understanding |   symbolic_problem_solving |   reading_comprehension |
|:-------------------------|----------:|------------------:|------------------------:|-------------------------:|---------------------------:|------------------------:|
| mosaicml/mpt-7b-instruct |  0.354255 |          0.398764 |                0.415097 |                 0.371509 |                   0.171216 |                0.414691 |
Printing complete results for all models
| Category                 | Benchmark
```

chuck/add_yaml_check_eval branch: `mcli logs -f all-eval-chuck-branch-UqSCPR`
```
Ran mosaicml/mpt-7b-instruct eval in: 5858.392446279526 seconds
Printing gauntlet results for all models
| model_name               |   average |   world_knowledge |   commonsense_reasoning |   language_understanding |   symbolic_problem_solving |   reading_comprehension |
|:-------------------------|----------:|------------------:|------------------------:|-------------------------:|---------------------------:|------------------------:|
| mosaicml/mpt-7b-instruct |  0.354255 |          0.398764 |                0.415097 |                 0.371509 |                   0.171216 |                0.414691 |
Printing complete results for all models
| Category                 | Benchmark
```

## Issues Addressed
This PR chips away at the tech debt tracked in this [issue](https://mosaicml.atlassian.net/browse/RESEARCH-717).
j316chuck authored Aug 16, 2023
1 parent 6b98ffb commit 148c079
Showing 2 changed files with 172 additions and 39 deletions.
137 changes: 98 additions & 39 deletions scripts/eval/eval.py
@@ -5,15 +5,16 @@
import re
import sys
import time
from typing import Dict, List, Optional
import warnings
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import torch
from composer.loggers import InMemoryLogger, LoggerDestination
from composer.models.base import ComposerModel
from composer.trainer import Trainer
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from transformers import (AutoModelForCausalLM, PreTrainedTokenizerBase,
T5ForConditionalGeneration)
@@ -23,7 +24,7 @@
from llmfoundry.models.mpt import MPTForCausalLM
from llmfoundry.utils.builders import (build_icl_evaluators, build_logger,
build_tokenizer)
from llmfoundry.utils.config_utils import process_init_device
from llmfoundry.utils.config_utils import pop_config, process_init_device


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
@@ -91,42 +92,41 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
)


def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str,
def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
run_name: str, icl_tasks: Union[str, ListConfig],
max_seq_len: int, device_eval_batch_size: int,
model_gauntlet_config: Optional[Union[str, DictConfig]],
fsdp_config: Optional[Dict], num_retries: int,
loggers_cfg: Dict[str, Any], precision: str,
model_gauntlet_df: Optional[pd.DataFrame]):
print(f'Evaluating model: {model_cfg.model_name}', flush=True)
# Build tokenizer and model
tokenizer = build_tokenizer(model_cfg.tokenizer)

evaluators, logger_keys = build_icl_evaluators(cfg.icl_tasks, tokenizer,
cfg.max_seq_len,
cfg.device_eval_batch_size)
if hasattr(cfg, 'model_gauntlet'):
if isinstance(cfg.model_gauntlet, str):
with open(cfg.model_gauntlet, 'r') as icl_f:
model_gauntlet_cfg = om.load(icl_f)
model_gauntlet = model_gauntlet_cfg.model_gauntlet
evaluators, logger_keys = build_icl_evaluators(icl_tasks, tokenizer,
max_seq_len,
device_eval_batch_size)
model_gauntlet: Optional[DictConfig] = None
model_gauntlet_callback: Optional[ModelGauntlet] = None
if model_gauntlet_config is not None:
if isinstance(model_gauntlet_config, str):
with open(model_gauntlet_config, 'r', encoding='utf-8') as icl_f:
loaded_model_gauntlet_config = om.load(icl_f)
model_gauntlet = loaded_model_gauntlet_config.model_gauntlet
else:
model_gauntlet = cfg.model_gauntlet
model_gauntlet.logger_keys = logger_keys
model_gauntlet.benchmark_sizes = {
model_gauntlet = model_gauntlet_config
model_gauntlet.logger_keys = logger_keys # type: ignore
model_gauntlet.benchmark_sizes = { # type: ignore
e.label: e.dataloader.num_samples for e in evaluators
}
model_gauntlet_callback = ModelGauntlet(**model_gauntlet)
else:
model_gauntlet = None
model_gauntlet_callback = None

fsdp_config = cfg.get('fsdp_config', None)
fsdp_config = om.to_container(
fsdp_config, resolve=True) if fsdp_config is not None else None
assert isinstance(fsdp_config, Dict) or fsdp_config is None
model_gauntlet_callback = ModelGauntlet(**
model_gauntlet) # type: ignore

if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'):
composer_model = load_peft_model(model_cfg.model, tokenizer,
cfg.get('num_retries', 3))
num_retries)
else:
composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
cfg.get('num_retries', 3))
num_retries)

if model_gauntlet_df is None and model_gauntlet is not None:
model_gauntlet_df = pd.DataFrame(
@@ -136,7 +136,7 @@ def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str,
in_memory_logger = InMemoryLogger() # track metrics in the in_memory_logger
loggers: List[LoggerDestination] = [
build_logger(name, logger_cfg)
for name, logger_cfg in (cfg.get('loggers') or {}).items()
for name, logger_cfg in loggers_cfg.items()
]
loggers.append(in_memory_logger)

@@ -147,13 +147,13 @@ def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str,
run_name=run_name,
model=composer_model,
loggers=loggers,
precision=cfg.precision,
precision=precision,
fsdp_config=fsdp_config, # type: ignore
load_path=load_path,
load_weights_only=True,
progress_bar=False,
log_to_console=True,
dist_timeout=cfg.dist_timeout,
dist_timeout=dist_timeout,
)

if torch.cuda.is_available():
@@ -169,20 +169,79 @@ def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str,


def main(cfg: DictConfig):
cfg.dist_timeout = cfg.get('dist_timeout', 600.0)
if cfg.get('run_name') is None:
cfg.run_name = os.environ.get('RUN_NAME', 'llm')
om.resolve(cfg)
model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True)
model_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config(
cfg, 'model_gauntlet', must_exist=False, default_value=None)
fsdp_dict_cfg: Optional[DictConfig] = pop_config(cfg,
'fsdp_config',
must_exist=False,
default_value=None)
fsdp_config: Optional[Dict] = om.to_container(
fsdp_dict_cfg,
resolve=True) if fsdp_dict_cfg is not None else None # type: ignore
assert isinstance(fsdp_config, Dict) or fsdp_config is None

# Mandatory Evaluation Parameters
icl_tasks: Union[str, ListConfig] = pop_config(cfg,
'icl_tasks',
must_exist=True)
max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True)
device_eval_batch_size: int = pop_config(cfg,
'device_eval_batch_size',
must_exist=True)
precision: str = pop_config(cfg, 'precision', must_exist=True)

# Optional Evaluation Parameters with default values
seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17)
dist_timeout: Union[float, int] = pop_config(cfg,
'dist_timeout',
must_exist=False,
default_value=600.0)
default_run_name: str = os.environ.get('RUN_NAME', 'llm')
run_name: str = pop_config(cfg,
'run_name',
must_exist=False,
default_value=default_run_name)
num_retries: int = pop_config(cfg,
'num_retries',
must_exist=False,
default_value=3)
loggers_cfg: Dict[str, Any] = pop_config(cfg,
'loggers',
must_exist=False,
default_value={})

# Pop out interpolation variables.
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None)

# Warn for unused parameters
for key in cfg:
warnings.warn(
f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.'
)

reproducibility.seed_all(cfg.seed)
dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout)
reproducibility.seed_all(seed)
dist.initialize_dist(get_device(None), timeout=dist_timeout)

model_gauntlet_df = None
models_df = None
composite_scores = None
for model_cfg in cfg.models:
for model_cfg in model_configs:
(in_memory_logger, logger_keys, model_gauntlet_callback, model_gauntlet,
model_gauntlet_df) = evaluate_model(model_cfg, cfg, cfg.run_name,
model_gauntlet_df)
model_gauntlet_df) = evaluate_model(
model_cfg=model_cfg,
dist_timeout=dist_timeout,
run_name=run_name,
icl_tasks=icl_tasks,
max_seq_len=max_seq_len,
device_eval_batch_size=device_eval_batch_size,
model_gauntlet_config=model_gauntlet_config,
fsdp_config=fsdp_config,
num_retries=num_retries,
loggers_cfg=loggers_cfg,
precision=precision,
model_gauntlet_df=model_gauntlet_df)

if model_gauntlet_callback is not None:
# TODO(bmosaicml) This needs to be refactored to fix the typing issue
@@ -205,7 +264,7 @@ def main(cfg: DictConfig):
else:
models_df = pd.concat([models_df, model_results], ignore_index=True)

if model_gauntlet_df is not None and model_gauntlet is not None and model_gauntlet_df is not None:
if model_gauntlet_df is not None and model_gauntlet is not None:
assert composite_scores is not None
row = {'model_name': model_cfg['model_name']}
row.update({
74 changes: 74 additions & 0 deletions tests/test_eval_inputs.py
@@ -0,0 +1,74 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import os
import sys
import warnings

import omegaconf
import pytest
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)

from scripts.eval.eval import main # noqa: E402


class TestEvalYAMLInputs:
"""Validate and tests error handling for the input YAML file."""

@pytest.fixture
def cfg(self) -> DictConfig:
"""Create YAML cfg fixture for testing purposes."""
conf_path: str = os.path.join(repo_dir,
'scripts/eval/yamls/hf_eval.yaml')
with open(conf_path, 'r', encoding='utf-8') as config:
test_cfg = om.load(config)
assert isinstance(test_cfg, DictConfig)
return test_cfg

def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
"""Check that mandatory mispelled inputs fail to train."""
mandatory_params = [
'max_seq_len',
'device_eval_batch_size',
'precision',
'model_configs',
]
mandatory_configs = ['models', 'icl_tasks']
for p in mandatory_params + mandatory_configs:
with pytest.raises((omegaconf.errors.ConfigKeyError,
omegaconf.errors.InterpolationKeyError)):
cfg[p + '-mispelled'] = cfg.pop(p)
main(cfg)
cfg[p] = cfg.pop(p + '-mispelled')

def test_optional_mispelled_params_raise_warning(self,
cfg: DictConfig) -> None:
"""Check that warnings are raised for optional mispelled parameters."""
optional_params = [
'seed',
'dist_timeout',
'run_name',
'num_retries',
'loggers',
'model_gauntlet',
'fsdp_config',
]
old_cfg = copy.deepcopy(cfg)
for param in optional_params:
orig_value = cfg.pop(param, None)
updated_param = param + '-mispelling'
cfg[updated_param] = orig_value
with warnings.catch_warnings(record=True) as warning_list:
try:
main(cfg)
except:
pass
assert any(f'Unused parameter {updated_param} found in cfg.' in
str(warning.message) for warning in warning_list)
# restore configs.
cfg = copy.deepcopy(old_cfg)
