diff --git a/scripts/train/train.py b/scripts/train/train.py
index 0d9e4e9d10..ebfa563424 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -390,29 +390,6 @@ def main(cfg: DictConfig):
     # Build tokenizer
     tokenizer = build_tokenizer(tokenizer_config)
 
-    # Build Model
-    print('Initializing model...')
-    with init_context:
-        if lora_config is not None:  # frozen model + trainable lora modules
-            model: ComposerHFCausalLM = build_composer_peft_model(
-                model_config.pretrained_model_name_or_path, lora_config['args'],
-                tokenizer)
-            print_trainable_parameters(model)  # should not be 100%
-        else:  # standard model
-            model = build_composer_model(model_config, tokenizer)
-    if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
-        model = model.to(dtype=torch.bfloat16)
-    elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
-        model = model.to(dtype=torch.float16)
-
-    # Log number of parameters
-    n_params = sum(p.numel() for p in model.parameters())
-    logged_cfg.update({'n_params': n_params})
-
-    # Optimizer
-    optimizer_name: str = optimizer_config.pop('name')
-    optimizer = build_optimizer(model, optimizer_name, optimizer_config)
-
     # Scheduler
     scheduler_name: str = scheduler_config.pop('name')
     scheduler = build_scheduler(scheduler_name, scheduler_config)
@@ -446,15 +423,15 @@ def main(cfg: DictConfig):
     ## Evaluation
     print('Building eval loader...')
     evaluators = []
+    eval_loader = None
     if eval_loader_config is not None:
-        assert model.train_metrics is not None
         eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
                                            device_eval_batch_size)
-        eval_metric_names = list(model.train_metrics.keys())
-        eval_loader = Evaluator(label='eval',
-                                dataloader=eval_dataloader,
-                                metric_names=eval_metric_names)
-        evaluators.append(eval_loader)
+        eval_loader = Evaluator(
+            label='eval',
+            dataloader=eval_dataloader,
+            metric_names=[],  # we will add these after model is created
+        )
 
     if icl_tasks_config is not None:
         icl_evaluators, _ = build_icl_evaluators(icl_tasks_config, tokenizer,
@@ -462,6 +439,38 @@ def main(cfg: DictConfig):
                                                  device_eval_batch_size)
         evaluators.extend(icl_evaluators)
 
+    # Build Model
+    print('Initializing model...')
+    with init_context:
+        if lora_config is not None:  # frozen model + trainable lora modules
+            model: ComposerHFCausalLM = build_composer_peft_model(
+                model_config.pretrained_model_name_or_path, lora_config['args'],
+                tokenizer)
+            print_trainable_parameters(model)  # should not be 100%
+        else:  # standard model
+            model = build_composer_model(model_config, tokenizer)
+
+    if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
+        model = model.to(dtype=torch.bfloat16)
+    elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
+        model = model.to(dtype=torch.float16)
+
+    # Log number of parameters
+    n_params = sum(p.numel() for p in model.parameters())
+    logged_cfg.update({'n_params': n_params})
+
+    # Optimizer
+    optimizer_name: str = optimizer_config.pop('name')
+    optimizer = build_optimizer(model, optimizer_name, optimizer_config)
+
+    # Now add the eval metrics
+    if eval_loader_config is not None:
+        assert eval_loader is not None
+        assert model.train_metrics is not None
+        eval_metric_names = list(model.train_metrics.keys())
+        eval_loader.metric_names = eval_metric_names
+        evaluators.insert(0, eval_loader)  # Put the base eval_loader first
+
     # Build the Trainer
     print('Building trainer...')
     trainer = Trainer(
diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py
index e208a0a1ee..2f29f6e7b5 100644
--- a/tests/test_train_inputs.py
+++ b/tests/test_train_inputs.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 import copy
+import json
 import os
 import sys
 import warnings
@@ -17,6 +18,38 @@
 from scripts.train.train import main  # noqa: E402
 
 
+def make_fake_index_file(path: str) -> None:
+    """Create a fake index file in the path."""
+    fake_index = {
+        'shards': [{
+            'column_encodings': ['bytes'],
+            'column_names': ['tokens'],
+            'column_sizes': [None],
+            'compression': 'zstd',
+            'format': 'mds',
+            'hashes': [],
+            'raw_data': {
+                'basename': 'shard.00000.mds',
+                'bytes': 5376759,
+                'hashes': {},
+            },
+            'samples': 328,
+            'size_limit': 67108864,
+            'version': 2,
+            'zip_data': {
+                'basename': 'shard.00000.mds.zstd',
+                'bytes': 564224,
+                'hashes': {},
+            }
+        }],
+        'version': 2
+    }
+    if not os.path.exists(path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+    with open(path, 'w') as f:
+        json.dump(fake_index, f)
+
+
 class TestTrainingYAMLInputs:
     """Validate and tests error handling for the input YAML file."""
 
@@ -90,13 +123,23 @@ def test_optional_mispelled_params_raise_warning(self,
 
     def test_extra_params_in_optimizer_cfg_errors(self,
                                                   cfg: DictConfig) -> None:
+        data_local = './my-copy-c4-opt1'
+        make_fake_index_file(f'{data_local}/train/index.json')
+        make_fake_index_file(f'{data_local}/val/index.json')
+        cfg.train_loader.dataset.local = data_local
+        cfg.eval_loader.dataset.local = data_local
         cfg.optimizer.beta2 = 'extra-parameter'
         with pytest.raises(TypeError):
             main(cfg)
 
     def test_invalid_name_in_optimizer_cfg_errors(self,
                                                   cfg: DictConfig) -> None:
+        data_local = './my-copy-c4-opt2'
+        make_fake_index_file(f'{data_local}/train/index.json')
+        make_fake_index_file(f'{data_local}/val/index.json')
         cfg.optimizer.name = 'invalid-optimizer'
+        cfg.train_loader.dataset.local = data_local
+        cfg.eval_loader.dataset.local = data_local
         with pytest.raises(ValueError) as exception_info:
             main(cfg)
         assert str(exception_info.value