8-bit LION, take 2 (#514)
Adds an 8-bit version of the LION optimizer. Also features 1 byte of (optional) auxiliary error correction state for each parameter to make pure bf16 training work.
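For intuition, here is a minimal sketch of the two ideas (illustrative only, not the `lion8b.py` implementation; the quantization scheme, function names, and defaults below are placeholders): the LION momentum is stored as int8 with a per-tensor scale, and when the master weights are bf16, an auxiliary error state carries each step's rounding error forward.

```python
import torch


def quantize_int8(x: torch.Tensor):
    """Symmetric per-tensor int8 quantization (placeholder scheme; the real
    CUDA kernels live in mosaicml-turbo and use their own format)."""
    scale = x.abs().max().clamp(min=1e-12) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale


@torch.no_grad()
def lion8b_step(param, grad, momentum_q, momentum_scale, error,
                lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    """One LION step with int8 momentum plus error feedback for bf16 master
    weights. `error` stands in for the optional auxiliary state (kept in fp32
    here for simplicity; the real thing packs it into 1 byte per parameter)."""
    beta1, beta2 = betas
    m = momentum_q.float() * momentum_scale              # dequantize momentum
    update = torch.sign(beta1 * m + (1 - beta1) * grad)  # LION's signed update

    # Recover higher-precision weights using last step's rounding error, apply
    # decoupled weight decay and the update, then round back down and remember
    # what was lost so it isn't dropped forever.
    exact = param.float() + error
    exact = exact * (1 - lr * weight_decay) - lr * update
    param.copy_(exact.to(param.dtype))
    error.copy_(exact - param.float())

    # Update the momentum and re-quantize it to 8 bits for storage.
    m = beta2 * m + (1 - beta2) * grad
    return quantize_int8(m)
```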

Code changes:
* Adds `lion8b.py` to `llm-foundry/optim`
* Adds `DecoupledLionW_8bit` to `llm-foundry/optim/__init__.py`
* Adds `lion8b` as an option in `llm-foundry/optim/builders.py`
* Adds `test_lion8b.py` to the tests.
* Adds `mosaicml-turbo` to the GPU dependencies in `setup.py`. This is the repo that currently holds all the CUDA kernels; they live in a separate repo for now to avoid complicating LLM Foundry's install, dependencies, and source code.
* Adds an optional `master_weights_dtype` model config field read by `train.py`. If set to bf16 or fp16, the script does `model.to(dtype=<that dtype>)` before training. This works when error correction is turned on. (A sketch of a config using it follows this list.)
* Tweaks `config_utils.py` to set FSDP's param_dtype to None if the master weights are already fp16/bf16.
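Putting the pieces together, a training config would opt in roughly like this. This is a hypothetical fragment: only `optimizer.name` and `model.master_weights_dtype` come from this PR; every other key and value is a placeholder.

```python
from omegaconf import OmegaConf as om

# Hypothetical fragment of a train.py YAML, expressed via omegaconf.
cfg = om.create("""
model:
  master_weights_dtype: bf16     # cast master weights to bf16 before training
optimizer:
  name: decoupled_lionw_8b       # dispatches to DecoupledLionW_8bit in builders.py
  lr: 1.0e-4                     # placeholder hyperparameters
  weight_decay: 0.0
""")
print(om.to_yaml(cfg))
```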
dblalock authored Aug 24, 2023
1 parent db2a8d9 commit 795ab4a
Showing 7 changed files with 1,011 additions and 3 deletions.
6 changes: 5 additions & 1 deletion llmfoundry/optim/__init__.py
@@ -3,5 +3,9 @@

 from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion
 from llmfoundry.optim.lion import DecoupledLionW
+from llmfoundry.optim.lion8b import DecoupledLionW_8bit

-__all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion']
+__all__ = [
+    'DecoupledLionW', 'DecoupledLionW_8bit', 'DecoupledClipLion',
+    'DecoupledAdaLRLion'
+]
429 changes: 429 additions & 0 deletions llmfoundry/optim/lion8b.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion llmfoundry/utils/builders.py
@@ -26,7 +26,7 @@
                                   LayerFreezing, MonolithicCheckpointSaver,
                                   ScheduledGarbageCollector)
 from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
-                              DecoupledLionW)
+                              DecoupledLionW, DecoupledLionW_8bit)


 def build_callback(name: str, kwargs: Dict[str, Any]):
@@ -98,6 +98,8 @@ def build_optimizer(model: torch.nn.Module, name: str,
         return DecoupledClipLion(model.parameters(), **optimizer_config)
     elif name == 'adalr_lion':
         return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
+    elif name == 'decoupled_lionw_8b':
+        return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
     else:
         raise ValueError(f'Not sure how to build optimizer: {name}')

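With the branch above in place, the optimizer is reachable through the same builder call as the other LION variants. A hedged usage sketch: it assumes `DecoupledLionW_8bit` accepts the usual LION-style kwargs and that the third argument is the `optimizer_config` dict used in the body above; check `lion8b.py` for the real signature and any 8-bit-specific options.

```python
import torch

from llmfoundry.utils.builders import build_optimizer

model = torch.nn.Linear(8, 8)  # stand-in for a real ComposerModel
optimizer = build_optimizer(
    model,
    'decoupled_lionw_8b',
    # assumed kwargs, mirroring the other LION variants
    {'lr': 1e-4, 'betas': (0.9, 0.99), 'weight_decay': 0.0},
)
```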
21 changes: 20 additions & 1 deletion llmfoundry/utils/config_utils.py
@@ -4,7 +4,7 @@
 import contextlib
 import math
 import warnings
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Mapping, Optional, Union

 from composer.utils import dist
 from omegaconf import DictConfig, ListConfig
@@ -116,6 +116,25 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
             # Set defaults for mixed initialization
             fsdp_config.setdefault('use_orig_params', False)
             fsdp_config.setdefault('load_monolith_rank0_only', True)
+
+    # no mixed precision needed for weights when they're already 16 bits
+    master_dtype = model_cfg.get('master_weights_dtype')
+    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
+                    'amp_bf16')
+    if fsdp_config and master_dtype in small_dtypes:
+        reduce_dtype = None
+        buffer_dtype = None
+        mixed_precision = fsdp_config.get('mixed_precision')
+        if isinstance(mixed_precision, Mapping):
+            reduce_dtype = mixed_precision.get('reduce_dtype')
+            buffer_dtype = mixed_precision.get('buffer_dtype')
+        fsdp_config['mixed_precision'] = {
+            'param_dtype': None,
+            'reduce_dtype': reduce_dtype,
+            'buffer_dtype': buffer_dtype,
+            'keep_low_precision_grads': True,
+        }
+
     return init_context

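For context, the dict this writes into `fsdp_config['mixed_precision']` corresponds roughly to the torch FSDP policy below. This is a sketch of intent only; the translation from `fsdp_config` into an actual `MixedPrecision` object happens downstream (in Composer), not in this file.

```python
from torch.distributed.fsdp import MixedPrecision

# param_dtype=None: don't cast parameters -- they are already 16-bit masters.
# keep_low_precision_grads=True: gradients stay 16-bit for the 8-bit optimizer.
# reduce/buffer dtypes are only carried over if the user had already set them.
policy = MixedPrecision(
    param_dtype=None,
    reduce_dtype=None,
    buffer_dtype=None,
    keep_low_precision_grads=True,
)
```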
5 changes: 5 additions & 0 deletions scripts/train/train.py
@@ -400,6 +400,10 @@ def main(cfg: DictConfig):
         print_trainable_parameters(model)  # should not be 100%
     else:  # standard model
         model = build_composer_model(model_config, tokenizer)
+    if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'):
+        model = model.to(dtype=torch.bfloat16)
+    elif model_config.get('master_weights_dtype') in ('f16', 'float16'):
+        model = model.to(dtype=torch.float16)

     # Log number of parameters
     n_params = sum(p.numel() for p in model.parameters())
@@ -515,5 +519,6 @@ def main(cfg: DictConfig):
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
+    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
1 change: 1 addition & 0 deletions setup.py
@@ -84,6 +84,7 @@

 extra_deps['gpu'] = [
     'flash-attn==v1.0.3.post0',
+    'mosaicml-turbo>=0.0.2,<0.1',
     # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
     'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
 ]