Skip to content

Commit

Permalink
Move the model creation to the last step before trainer creation (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Aug 25, 2023
1 parent 795ab4a commit 3e3c7d3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 29 deletions.
67 changes: 38 additions & 29 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,29 +390,6 @@ def main(cfg: DictConfig):
# Build tokenizer
tokenizer = build_tokenizer(tokenizer_config)

# Build Model
print('Initializing model...')
with init_context:
if lora_config is not None: # frozen model + trainable lora modules
model: ComposerHFCausalLM = build_composer_peft_model(
model_config.pretrained_model_name_or_path, lora_config['args'],
tokenizer)
print_trainable_parameters(model) # should not be 100%
else: # standard model
model = build_composer_model(model_config, tokenizer)
if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
model = model.to(dtype=torch.bfloat16)
elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
model = model.to(dtype=torch.float16)

# Log number of parameters
n_params = sum(p.numel() for p in model.parameters())
logged_cfg.update({'n_params': n_params})

# Optimizer
optimizer_name: str = optimizer_config.pop('name')
optimizer = build_optimizer(model, optimizer_name, optimizer_config)

# Scheduler
scheduler_name: str = scheduler_config.pop('name')
scheduler = build_scheduler(scheduler_name, scheduler_config)
Expand Down Expand Up @@ -446,22 +423,54 @@ def main(cfg: DictConfig):
## Evaluation
print('Building eval loader...')
evaluators = []
eval_loader = None
if eval_loader_config is not None:
assert model.train_metrics is not None
eval_dataloader = build_dataloader(eval_loader_config, tokenizer,
device_eval_batch_size)
eval_metric_names = list(model.train_metrics.keys())
eval_loader = Evaluator(label='eval',
dataloader=eval_dataloader,
metric_names=eval_metric_names)
evaluators.append(eval_loader)
eval_loader = Evaluator(
label='eval',
dataloader=eval_dataloader,
metric_names=[], # we will add these after model is created
)

if icl_tasks_config is not None:
icl_evaluators, _ = build_icl_evaluators(icl_tasks_config, tokenizer,
max_seq_len,
device_eval_batch_size)
evaluators.extend(icl_evaluators)

# Build Model
print('Initializing model...')
with init_context:
if lora_config is not None: # frozen model + trainable lora modules
model: ComposerHFCausalLM = build_composer_peft_model(
model_config.pretrained_model_name_or_path, lora_config['args'],
tokenizer)
print_trainable_parameters(model) # should not be 100%
else: # standard model
model = build_composer_model(model_config, tokenizer)

if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
model = model.to(dtype=torch.bfloat16)
elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
model = model.to(dtype=torch.float16)

# Log number of parameters
n_params = sum(p.numel() for p in model.parameters())
logged_cfg.update({'n_params': n_params})

# Optimizer
optimizer_name: str = optimizer_config.pop('name')
optimizer = build_optimizer(model, optimizer_name, optimizer_config)

# Now add the eval metrics
if eval_loader_config is not None:
assert eval_loader is not None
assert model.train_metrics is not None
eval_metric_names = list(model.train_metrics.keys())
eval_loader.metric_names = eval_metric_names
evaluators.insert(0, eval_loader) # Put the base eval_loader first

# Build the Trainer
print('Building trainer...')
trainer = Trainer(
Expand Down
43 changes: 43 additions & 0 deletions tests/test_train_inputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import json
import os
import sys
import warnings
Expand All @@ -17,6 +18,38 @@
from scripts.train.train import main # noqa: E402


def make_fake_index_file(path: str) -> None:
    """Create a fake index file in the path.

    Writes a minimal JSON index describing a single zstd-compressed MDS
    shard. Missing parent directories are created first. If something
    already exists at ``path`` it is left untouched.

    Args:
        path: Location of the index file to create. May be a bare filename,
            in which case the file is written to the current directory.
    """
    if os.path.exists(path):
        return

    fake_index = {
        'shards': [{
            'column_encodings': ['bytes'],
            'column_names': ['tokens'],
            'column_sizes': [None],
            'compression': 'zstd',
            'format': 'mds',
            'hashes': [],
            'raw_data': {
                'basename': 'shard.00000.mds',
                'bytes': 5376759,
                'hashes': {},
            },
            'samples': 328,
            'size_limit': 67108864,
            'version': 2,
            'zip_data': {
                'basename': 'shard.00000.mds.zstd',
                'bytes': 564224,
                'hashes': {},
            }
        }],
        'version': 2
    }

    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, 'w') as f:
        json.dump(fake_index, f)


class TestTrainingYAMLInputs:
"""Validate and tests error handling for the input YAML file."""

Expand Down Expand Up @@ -90,13 +123,23 @@ def test_optional_mispelled_params_raise_warning(self,

def test_extra_params_in_optimizer_cfg_errors(self,
                                              cfg: DictConfig) -> None:
    """Check that ``main`` raises ``TypeError`` for a bad optimizer entry.

    Both loaders are pointed at a local directory holding fake index files
    before ``main`` runs.
    """
    # NOTE(review): presumably the fake index files let the dataloaders be
    # built so that execution reaches optimizer construction -- confirm.
    dataset_root = './my-copy-c4-opt1'
    make_fake_index_file(f'{dataset_root}/train/index.json')
    make_fake_index_file(f'{dataset_root}/val/index.json')

    # The value under test: 'beta2' is set to a string on the optimizer config.
    cfg.optimizer.beta2 = 'extra-parameter'
    for loader_cfg in (cfg.train_loader, cfg.eval_loader):
        loader_cfg.dataset.local = dataset_root

    with pytest.raises(TypeError):
        main(cfg)

def test_invalid_name_in_optimizer_cfg_errors(self,
cfg: DictConfig) -> None:
data_local = './my-copy-c4-opt2'
make_fake_index_file(f'{data_local}/train/index.json')
make_fake_index_file(f'{data_local}/val/index.json')
cfg.optimizer.name = 'invalid-optimizer'
cfg.train_loader.dataset.local = data_local
cfg.eval_loader.dataset.local = data_local
with pytest.raises(ValueError) as exception_info:
main(cfg)
assert str(exception_info.value
Expand Down

0 comments on commit 3e3c7d3

Please sign in to comment.