Refactor build optimizer and peft models to use kwargs syntax (#525)
Refactor the optimizer and PEFT model builders used by `train/train.py` to use `**kwargs` syntax, so that configs are plain dictionaries, classes/functions state their arguments explicitly in their signatures, and the `build_***` helpers initialize classes with `**kwargs`.

Changes:
- `build_optimizer`
- `build_scheduler`
- `build_composer_peft_model`

Additionally, moved `build_dataloader` to just before the `Trainer` is initialized, so non-data-related issues in the YAML are caught faster.
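
For reference, a minimal sketch of the calling convention this refactor enables (the optimizer section and its values below are illustrative, not copied from a repo YAML):

```python
from typing import Any, Dict

# Illustrative optimizer config, as returned by
# pop_config(cfg, 'optimizer', must_exist=True, convert=True):
optimizer_config: Dict[str, Any] = {
    'name': 'decoupled_adamw',  # selects the optimizer class
    'lr': 6.0e-4,               # everything else maps onto the constructor
    'betas': [0.9, 0.95],
    'eps': 1.0e-8,
    'weight_decay': 0.0,
}

# train.py now pops the name and forwards the rest via **kwargs:
optimizer_name: str = optimizer_config.pop('name')
# optimizer = build_optimizer(model, optimizer_name, optimizer_config)
```

Because the remaining keys are splatted directly into the constructor, a misspelled or extra key fails fast with a `TypeError` instead of being silently ignored.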

## Unit Tests 
- Added unit tests for optimizer states and `build_scheduler`.

## Integration Tests
- Trained a PEFT model and got the same fine-tuning scores:
![Screenshot 2023-08-16 at 5.39.17 PM](https://github.com/mosaicml/llm-foundry/assets/13524881/57226d52-a35e-4229-8e01-4ff3e7d51d29)

- [LoRA run before the change](https://wandb.ai/mosaic-ml/chuck-runs/runs/ckod4o5r/logs?workspace=user-chuck-tang98)
- [LoRA run after the change](https://wandb.ai/mosaic-ml/chuck-runs/runs/69aiy9uk?workspace=user-chuck-tang98)

## Issues Closed
Chipping away at tech debt: [RESEARCH-717](https://mosaicml.atlassian.net/browse/RESEARCH-717).
j316chuck authored Aug 18, 2023
1 parent 9946fc6 commit 36cee64
Showing 4 changed files with 112 additions and 85 deletions.
57 changes: 19 additions & 38 deletions llmfoundry/utils/builders.py
@@ -88,48 +88,29 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]):
raise ValueError(f'Not sure how to build algorithm: {name}')


def build_optimizer(cfg: DictConfig, model: torch.nn.Module):
if cfg.name == 'decoupled_adamw':
return DecoupledAdamW(model.parameters(),
lr=cfg.lr,
betas=cfg.betas,
eps=cfg.eps,
weight_decay=cfg.weight_decay)
elif cfg.name == 'decoupled_lionw':
return DecoupledLionW(model.parameters(),
lr=cfg.lr,
betas=cfg.betas,
weight_decay=cfg.weight_decay)
elif cfg.name == 'clip_lion':
return DecoupledClipLion(model.parameters(),
lr=cfg.lr,
betas=cfg.betas,
weight_decay=cfg.weight_decay,
outlier_threshold=cfg.outlier_threshold)
elif cfg.name == 'adalr_lion':
return DecoupledAdaLRLion(model.parameters(),
lr=cfg.lr,
betas=cfg.betas,
weight_decay=cfg.weight_decay,
outlier_threshold=cfg.outlier_threshold,
timeout=cfg.timeout,
lr_penalty=cfg.lr_penalty,
min_scale=cfg.min_scale)
def build_optimizer(model: torch.nn.Module, name: str,
optimizer_config: Dict[str, Any]):
if name == 'decoupled_adamw':
return DecoupledAdamW(model.parameters(), **optimizer_config)
elif name == 'decoupled_lionw':
return DecoupledLionW(model.parameters(), **optimizer_config)
elif name == 'clip_lion':
return DecoupledClipLion(model.parameters(), **optimizer_config)
elif name == 'adalr_lion':
return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
else:
raise ValueError(f'Not sure how to build optimizer: {cfg.name}')
raise ValueError(f'Not sure how to build optimizer: {name}')


def build_scheduler(cfg: DictConfig):
if cfg.name == 'constant_with_warmup':
return ConstantWithWarmupScheduler(t_warmup=cfg.t_warmup)
elif cfg.name == 'cosine_with_warmup':
return CosineAnnealingWithWarmupScheduler(t_warmup=cfg.t_warmup,
alpha_f=cfg.alpha_f)
elif cfg.name == 'linear_decay_with_warmup':
return LinearWithWarmupScheduler(t_warmup=cfg.t_warmup,
alpha_f=cfg.alpha_f)
def build_scheduler(name: str, scheduler_config: Dict[str, Any]):
if name == 'constant_with_warmup':
return ConstantWithWarmupScheduler(**scheduler_config)
elif name == 'cosine_with_warmup':
return CosineAnnealingWithWarmupScheduler(**scheduler_config)
elif name == 'linear_decay_with_warmup':
return LinearWithWarmupScheduler(**scheduler_config)
else:
raise ValueError(f'Not sure how to build scheduler: {cfg.name}')
raise ValueError(f'Not sure how to build scheduler: {name}')


def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase:
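For context, a minimal usage sketch of the refactored builders (the toy model and the config values here are placeholders, not taken from the repo's YAMLs):

```python
import torch

from llmfoundry.utils.builders import build_optimizer, build_scheduler

# Stand-in for the real ComposerModel; only .parameters() is needed here.
model = torch.nn.Linear(8, 8)

# After popping 'name', the remaining keys map one-to-one onto
# DecoupledAdamW's signature.
optimizer = build_optimizer(model, 'decoupled_adamw', {
    'lr': 1.0e-4,
    'betas': [0.9, 0.95],
    'eps': 1.0e-8,
    'weight_decay': 0.0,
})

# Scheduler kwargs are forwarded verbatim as well.
scheduler = build_scheduler('cosine_with_warmup', {
    't_warmup': '100ba',
    'alpha_f': 0.1,
})
```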
18 changes: 14 additions & 4 deletions llmfoundry/utils/config_utils.py
@@ -7,7 +7,7 @@
from typing import Any, Dict, Optional, Union

from composer.utils import dist
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.models.utils import init_empty_weights
@@ -16,14 +16,24 @@
def pop_config(cfg: DictConfig,
key: str,
must_exist: bool = True,
default_value: Any = None) -> Any:
default_value: Any = None,
convert: bool = False) -> Any:
"""Pop a value from the main config file and return it.
If the key does not exist, return the default_value or raise a RuntimeError
depending on the must_exist flag.
depending on the must_exist flag. If the convert flag is set to True, then
we will convert the value to a python object using OmegaConf.to_container.
"""
value = cfg.pop(key, None)
if value is not None:
if value is not None and convert:
if not isinstance(value, DictConfig) and not isinstance(
value, ListConfig):
raise ValueError(
f'The key: {key} has a value: {value} that cannot be \
converted to a dict or list. Please check your yaml.'
)
return om.to_container(value)
elif value is not None:
return value
elif must_exist:
raise NameError(
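A quick sketch of the new `convert` flag (the config contents are illustrative): with `convert=True`, the popped `DictConfig`/`ListConfig` node is returned as a plain `dict`/`list` via `OmegaConf.to_container`, so it can later be splatted with `**kwargs`.

```python
from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import pop_config

cfg = om.create({'optimizer': {'name': 'decoupled_adamw', 'lr': 1.0e-4}})

optimizer_config = pop_config(cfg, 'optimizer', must_exist=True, convert=True)
assert isinstance(optimizer_config, dict)  # a plain dict, not a DictConfig

# Popping a scalar with convert=True raises ValueError, since only
# DictConfig/ListConfig nodes can be converted to containers.
```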
94 changes: 51 additions & 43 deletions scripts/train/train.py
@@ -4,7 +4,7 @@
import os
import sys
import warnings
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from composer import Trainer
@@ -104,7 +104,7 @@ def build_composer_model(model_cfg: DictConfig,


def build_composer_peft_model(
model_cfg: DictConfig, lora_cfg: DictConfig,
pretrained_model_name_or_path: str, lora_args: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> ComposerHFCausalLM:
try:
from peft import LoraConfig, get_peft_model
@@ -117,11 +117,11 @@ def build_composer_peft_model(

# 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM.
print('Building Lora config...')
lora_cfg = LoraConfig(**lora_cfg.args)
lora_cfg = LoraConfig(**lora_args)

print('Building model from HuggingFace checkpoint...')
model = MPTForCausalLM.from_pretrained(
model_cfg.pretrained_model_name_or_path, trust_remote_code=True)
model = MPTForCausalLM.from_pretrained(pretrained_model_name_or_path,
trust_remote_code=True)
print('Model built!')

print('Adding Lora modules...')
@@ -212,24 +212,29 @@ def main(cfg: DictConfig):
# Mandatory model training configs
model_config: DictConfig = pop_config(cfg, 'model', must_exist=True)
tokenizer_config: DictConfig = pop_config(cfg, 'tokenizer', must_exist=True)
optimizer_config: DictConfig = pop_config(cfg, 'optimizer', must_exist=True)
scheduler_config: DictConfig = pop_config(cfg, 'scheduler', must_exist=True)
optimizer_config: Dict[str, Any] = pop_config(cfg,
'optimizer',
must_exist=True,
convert=True)
scheduler_config: Dict[str, Any] = pop_config(cfg,
'scheduler',
must_exist=True,
convert=True)
train_loader_config: DictConfig = pop_config(cfg,
'train_loader',
must_exist=True)

# Optional fsdp data, fine-tuning, and eval configs
fsdp_dict_config: Optional[DictConfig] = pop_config(cfg,
'fsdp_config',
must_exist=False,
default_value=None)
fsdp_config: Optional[Dict] = om.to_container(
fsdp_dict_config
) if fsdp_dict_config is not None else None # type: ignore
lora_config: Optional[DictConfig] = pop_config(cfg,
'lora',
must_exist=False,
default_value=None)
fsdp_config: Optional[Dict[str, Any]] = pop_config(cfg,
'fsdp_config',
must_exist=False,
default_value=None,
convert=True)
lora_config: Optional[Dict[str, Any]] = pop_config(cfg,
'lora',
must_exist=False,
default_value=None,
convert=True)
eval_loader_config: Optional[DictConfig] = pop_config(cfg,
'eval_loader',
must_exist=False,
@@ -390,7 +395,8 @@ def main(cfg: DictConfig):
with init_context:
if lora_config is not None: # frozen model + trainable lora modules
model: ComposerHFCausalLM = build_composer_peft_model(
model_config, lora_config, tokenizer)
model_config.pretrained_model_name_or_path, lora_config['args'],
tokenizer)
print_trainable_parameters(model) # should not be 100%
else: # standard model
model = build_composer_model(model_config, tokenizer)
@@ -399,6 +405,32 @@
n_params = sum(p.numel() for p in model.parameters())
logged_cfg.update({'n_params': n_params})

# Optimizer
optimizer_name: str = optimizer_config.pop('name')
optimizer = build_optimizer(model, optimizer_name, optimizer_config)

# Scheduler
scheduler_name: str = scheduler_config.pop('name')
scheduler = build_scheduler(scheduler_name, scheduler_config)

# Loggers
loggers = [
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
] if logger_configs else None

# Callbacks
callbacks = [
build_callback(str(name), callback_cfg)
for name, callback_cfg in callback_configs.items()
] if callback_configs else None

# Algorithms
algorithms = [
build_algorithm(str(name), algorithm_cfg)
for name, algorithm_cfg in algorithm_configs.items()
] if algorithm_configs else None

# Dataloaders
print('Building train loader...')
train_loader = build_dataloader(
@@ -426,30 +458,6 @@ def main(cfg: DictConfig):
device_eval_batch_size)
evaluators.extend(icl_evaluators)

# Optimizer
optimizer = build_optimizer(optimizer_config, model)

# Scheduler
scheduler = build_scheduler(scheduler_config)

# Loggers
loggers = [
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
] if logger_configs else None

# Callbacks
callbacks = [
build_callback(str(name), callback_cfg)
for name, callback_cfg in callback_configs.items()
] if callback_configs else None

# Algorithms
algorithms = [
build_algorithm(str(name), algorithm_cfg)
for name, algorithm_cfg in algorithm_configs.items()
] if algorithm_configs else None

# Build the Trainer
print('Building trainer...')
trainer = Trainer(
28 changes: 28 additions & 0 deletions tests/test_train_inputs.py
@@ -87,3 +87,31 @@ def test_optional_mispelled_params_raise_warning(self,
str(warning.message) for warning in warning_list)
# restore configs.
cfg = copy.deepcopy(old_cfg)

def test_extra_params_in_optimizer_cfg_errors(self,
cfg: DictConfig) -> None:
cfg.optimizer.beta2 = 'extra-parameter'
with pytest.raises(TypeError):
main(cfg)

def test_invalid_name_in_optimizer_cfg_errors(self,
cfg: DictConfig) -> None:
cfg.optimizer.name = 'invalid-optimizer'
with pytest.raises(ValueError) as exception_info:
main(cfg)
assert str(exception_info.value
) == 'Not sure how to build optimizer: invalid-optimizer'

def test_extra_params_in_scheduler_cfg_errors(self,
cfg: DictConfig) -> None:
cfg.scheduler.t_warmup_extra = 'extra-parameter'
with pytest.raises(TypeError):
main(cfg)

def test_invalid_name_in_scheduler_cfg_errors(self,
cfg: DictConfig) -> None:
cfg.scheduler.name = 'invalid-scheduler'
with pytest.raises(ValueError) as exception_info:
main(cfg)
assert str(exception_info.value
) == 'Not sure how to build scheduler: invalid-scheduler'
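
The `TypeError` expected by the extra-parameter tests falls straight out of the `**kwargs` splat; a standalone illustration (the toy model and the misspelled key are placeholders), assuming composer's `DecoupledAdamW` import path:

```python
import torch
from composer.optim import DecoupledAdamW

model = torch.nn.Linear(4, 4)
bad_config = {'lr': 1.0e-4, 'beta2': 0.95}  # 'beta2' is not a DecoupledAdamW argument

try:
    DecoupledAdamW(model.parameters(), **bad_config)
except TypeError as err:
    print(f'Rejected as expected: {err}')
```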
