diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py index ea996b3305..1d0e5caced 100644 --- a/llmfoundry/optim/__init__.py +++ b/llmfoundry/optim/__init__.py @@ -3,5 +3,9 @@ from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion from llmfoundry.optim.lion import DecoupledLionW +from llmfoundry.optim.lion8b import DecoupledLionW_8bit -__all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion'] +__all__ = [ + 'DecoupledLionW', 'DecoupledLionW_8bit', 'DecoupledClipLion', + 'DecoupledAdaLRLion' +] diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py new file mode 100644 index 0000000000..806dbdbd14 --- /dev/null +++ b/llmfoundry/optim/lion8b.py @@ -0,0 +1,429 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Iterable, Optional, Tuple + +import torch + + +class DecoupledLionW_8bit(torch.optim.Optimizer): + """LION optimizer with ~8 bits of state per parameter. + + This optimizer is a drop-in replacement for our regular LION optimizer + with decoupled weight decay, but uses less memory, writes smaller + checkpoints, and offers almost-numerically-identical convergence. + + Its state saved per parameter is just an int8, though there are auxiliary + scaling factors that bring the total memory per parameter to ~8.5 bits. + The exact quantization scheme is considered an implementation detail + and may change. + + When training on CPUs, however, no quantization will actually take place. + + See the LION paper (https://arxiv.org/abs/2302.06675) for details about + the algorithm itself. + + Args: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate + betas: two coefficients between 0 and 1 used to combine the current + gradients and the momentum. The first coefficient is the weight + of the gradient when computing the update. The second is the + weight of the gradient when computing the new momentum. + weight_decay: Weights are multiplied by 1 - `weight_decay` after + each optimizer step. Note that we use decoupled weight decay, + meaning that this decay does not contribute to the momentum. + compress_state_dict: if True, this optimizer's `state_dict` will + include quantized optimizer states. Otherwise, the optimizer + states are converted to bfloat16 Tensors matching the shapes of + their corresponding parameters. The former uses ~8.5 bits per + parameter while the latter uses 16 bits per parameter. However, + the former is less thoroughly tested and will not work with + FSDP or other weight sharding approaches. + quantize: If False, optimizer states will not actually be quantized. + This option is available so that one can easily debug whether + the quantization is causing any convergence issues. Because + quantization is only supported for CUDA parameters, attempting to + update a non-CUDA tensor will raise an error. + error_correction: If True, float16 and bfloat16 parameters will be + given an extra state variable, "errors." This tensor will be + of the same shape as the parameter but of dtype uint8. This + auxiliary variable is used to better approximate float32 updates + by retaining information across optimizer steps. + + Raises: + NotImplementedError: If any of `quantize`, `compress_state_dict`, + or `error_correction` are `True` and either a) there is no CUDA + device, or b) step() is executed on a non-CUDA parameter.
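# Illustrative usage sketch for the class documented above (not part of the
# patch): assumes a CUDA device and the mosaicml-turbo dependency added in
# setup.py below; the hyperparameter values are placeholders, not
# recommendations.
import torch
from llmfoundry.optim import DecoupledLionW_8bit

model = torch.nn.Linear(512, 512, device='cuda')
opt = DecoupledLionW_8bit(model.parameters(),
                          lr=1e-4,
                          betas=(0.9, 0.99),
                          weight_decay=1e-5)

x = torch.randn(8, 512, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # int8 momentum now lives in opt.state[param]['exp_avg']
opt.zero_grad()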
+ """ + + def __init__(self, + params: Iterable[torch.Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + quantize: bool = True, + compress_state_dict: bool = False, + error_correction: bool = False, + _fused: bool = True): # XXX this flag is mostly for testing... + if lr < 0.0: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] <= 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] <= 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= weight_decay: + raise ValueError( + 'Invalid weight_decay value: {}'.format(weight_decay)) + + if not torch.cuda.is_available(): + needs_cuda = ' requires a CUDA device.' + if quantize: + raise NotImplementedError('Quantization' + needs_cuda) + if error_correction: + raise NotImplementedError('Error correction' + needs_cuda) + if compress_state_dict: + raise NotImplementedError('Quantized state dict' + needs_cuda) + + _fused = _fused and quantize + self._quantize = quantize + self._error_correction = error_correction + self._compress_state_dict = compress_state_dict + + defaults = { + 'lr': lr, + 'initial_lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'fused': _fused + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Optional[Callable] = None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + self.step_param(p, group) + + return loss + + def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: + if not p.requires_grad or p.grad is None: + return + if self._quantize and not p.is_cuda: + raise NotImplementedError( + f"Can't use quantization with param on {p.device} " + + f'({p.shape}, {p.dtype}). If you need ' + + 'to use DecoupledLionW_8bit without a CUDA device, try ' + + 'creating this optimizer with quantize=False.') + state = self.state[p] # type:ignore using tensor as key + if 'exp_avg' not in state: + mom = torch.zeros_like(p) + state['exp_avg'] = _MaybeQuantizedTensor( + mom, try_quantize=self._quantize) + need_errs = (p.dtype != torch.float32) and self._error_correction + if state.get('errors') is None and need_errs: + state['errors'] = torch.zeros(p.shape, + dtype=torch.uint8, + device=p.device) + decay_factor = hparams['weight_decay'] + decay_factor *= hparams['lr'] / hparams['initial_lr'] + _lion8b_step(momentums=state['exp_avg'], + weights=p, + grads=p.grad, + beta1=hparams['betas'][0], + beta2=hparams['betas'][1], + lr=hparams['lr'], + weight_decay=decay_factor, + fused=hparams['fused'], + errors=state.get('errors')) + + def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: + # we override this function to quantize optimizer states when + # loading a state dict + opt_state, _ = state.values() # other val is param_groups + for param_id in opt_state: + param_state = opt_state[param_id] + new_state = {} + if any(k.startswith('exp_avg') for k in param_state): + # the keys can either be just "exp_avg" or + # "exp_avg::quantized" and "exp_avg::scales", depending on + # whether we saved it as quantized or not. The former case + # gives us interop with regular LION. 
+ qtensor = _MaybeQuantizedTensor(None, + try_quantize=self._quantize) + qtensor.load_state_dict(param_state, name='exp_avg') + new_state['exp_avg'] = qtensor + if 'errors' in param_state: + # we need to cast back to the correct dtype since optimizer + # load_state_dict casts to param dtype for fp params; see + # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa + errs = param_state['errors'].to(dtype=torch.uint8) + new_state['errors'] = errs + opt_state[param_id] = new_state + super().__setstate__(state) + + def state_dict(self): + # If the user hasn't opted into storing compressed state dicts + # we have to make sure our states are regular torch.Tensors. This + # is mostly needed to make FSDP happy in the case that we want to + # resume training with a number of devices where + # (param numel / device count) % quantization group size != 0 + # for any param. + d = super().state_dict() + opt_state, _ = d.values() # other val is param_groups + for param_id in opt_state: + # make a copy so that we don't mutate our self.state; opt_state + # isn't the same as self.state, but its consituent dicts are + # the same as those in self.state + param_state = {k: v for k, v in opt_state[param_id].items()} + if 'exp_avg' in param_state: # true if we've taken any steps + qtensor = param_state.pop('exp_avg') + assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright + param_state.update( + qtensor.state_dict( + name='exp_avg', + allow_quantized=self._compress_state_dict)) + opt_state[param_id] = param_state + return d + + +class _MaybeQuantizedTensor: + """Helper class so 8b LION doesn't have to know quantization details. + + Important points about this class: + * It handles CPU tensors not being quantized + * It knows how to save + load state dicts, handling both the quantized + and not quantized cases + * It implements some parts of the torch.Tensor interface that we need, + but is not intended to be a full torch.Tensor replacement + """ + + def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True): + super().__init__() + self.data: Optional[torch.Tensor] = None + self.quantized: Optional[torch.Tensor] = None + self.scales: Optional[torch.Tensor] = None + self._try_quantize = try_quantize and torch.cuda.is_available() + + # conditionally import CUDA kernels + self._f_encode = None + self._f_decode = None + if self._try_quantize: + from turbo import dequantize8b, quantize8b + self._f_encode = quantize8b + self._f_decode = dequantize8b + + if data is not None: + self.set_data(data) + + def state_dict(self, + name: str, + allow_quantized: bool = False) -> Dict[str, torch.Tensor]: + if self.is_quantized() and allow_quantized: + assert self.quantized is not None # pyright + assert self.scales is not None # pyright + return { + f'{name}::quantized': self.quantized, + f'{name}::scales': self.scales + } + return {name: self.materialize().to(dtype=torch.bfloat16)} + + def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None: + # we allow other keys in the state dict for convenience, so you can + # just pass this the whole opt state for a parameters + d = {k: v for k, v in d.items() if k.startswith(name)} + if name in d: + if len(d) != 1: + raise ValueError( + f'If state dict specifies {name}, it must not ' + + f'specify other keys. 
Got {list(d.keys())}') + self.set_data(d[name]) + return + + self.quantized = d[f'{name}::quantized'].to(dtype=torch.int8) + self.scales = d[f'{name}::scales'].to(dtype=torch.float16) + + def set_data(self, data: torch.Tensor) -> None: + if self._try_quantize: + if not data.is_cuda: + raise NotImplementedError( + f'Attempting to quantize a non-CUDA {data.dtype} tensor ' + + f'on device {data.device} with shape {data.shape}.') + self.data = None + assert self._f_encode is not None # pyright + self.quantized, self.scales = self._f_encode(data) + else: + self.data = data.to(dtype=torch.float32) + self.quantized = None + self.scales = None + + def is_quantized(self) -> bool: + return self.data is None + + def materialize(self) -> torch.Tensor: + if not self.is_quantized(): + assert self.data is not None # pyright + return self.data + assert self._f_decode is not None # pyright + assert self.quantized is not None # pyright + assert self.scales is not None # pyright + return self._f_decode(self.quantized, self.scales) + + @property # property to mirror Tensor interface + def is_cuda(self) -> bool: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.is_cuda + assert self.data is not None # pyright + return self.data.is_cuda + + @property # property to mirror Tensor interface + def shape(self) -> Tuple[int]: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.shape + assert self.data is not None # pyright + return self.data.shape + + def numel(self) -> int: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.numel() + assert self.data is not None # pyright + return self.data.numel() + + def __repr__(self): + return (f'{self.__class__.__name__} quantized={self.is_quantized()} ' + + f'shape={self.shape}') + + +def lion_step_unfused(grads: torch.Tensor, + weights: torch.Tensor, + momentums: torch.Tensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float = 0) -> torch.Tensor: + # f32 cast to match fused impl + for compatibility with f32 grads or weights + momentums = momentums.to(dtype=torch.float32) + grads = grads.to(dtype=torch.float32) + + update = momentums.lerp(grads, 1 - beta1).sign_() + if weight_decay > 0: + weights.mul_(1. - weight_decay) + + weights.add_(update, alpha=-lr) + momentums.lerp_(grads, 1. 
- beta2) + return momentums # f32 upcast means not necessarily modified in place + + +def lion8b_step_fused(grads: torch.Tensor, + weights: torch.Tensor, + momentums: torch.Tensor, + scales: torch.Tensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float, + errors: Optional[torch.Tensor] = None) -> None: + # just to save space in lists of allowed dtypes + f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32 + + use_errors = (errors is not None) and (weights.dtype in (f16, bf16)) + orig_shape = weights.shape + + # ------------------------------------------------ wall of error checking + quantize_group_size = 32 + num_groups = (weights.numel() + quantize_group_size - + 1) // quantize_group_size + if (num_groups != scales.numel()): + raise ValueError(f'Expected {num_groups} quantization scales but ' + + f' received {scales.numel()}') + + for name, tensor, allowed_dtypes in [('grad', grads, (f16, bf16, f32)), + ('param', weights, (f16, bf16, f32)), + ('momentum', momentums, [torch.int8]), + ('scales', scales, [f16]), + ('errors', errors, [torch.uint8])]: + if name == 'errors' and not use_errors: + continue + if not tensor.is_cuda: + raise ValueError( + f'{name} must be on a CUDA device, not {tensor.device}') + if not tensor.is_contiguous(): + raise ValueError(f'{name} is not contiguous!') + strides_unequal = tensor.stride() != weights.stride() + if name not in ('scales', 'errors') and strides_unequal: + raise ValueError(f'{name} stride {tensor.stride()} != ' + + f'param stride {weights.stride()}') + if tensor.dtype not in allowed_dtypes: + raise ValueError(f'{name} must have dtype {allowed_dtypes}, not ' + + f'{tensor.dtype}') + if (name != 'scales') and (orig_shape != tensor.shape): + raise ValueError(f'Param shape {orig_shape} != ' + + f'{name} shape {tensor.shape}') + + if grads.dtype in (torch.float16, torch.bfloat16): + allowed_dtypes = (grads.dtype, torch.float32) + if weights.dtype not in allowed_dtypes: + raise ValueError( + f'Weights must be f32 or match grad dtype {grads.dtype}') + + # ------------------------------------------------ actual function call + from turbo import lion8b_step_cuda + return lion8b_step_cuda(grads=grads, + weights=weights, + momentums=momentums, + scales=scales, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + errors=errors) + + +def _lion8b_step(grads: torch.Tensor, + weights: torch.Tensor, + momentums: _MaybeQuantizedTensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float = 0, + errors: Optional[torch.Tensor] = None, + fused: bool = True) -> None: + + if fused and not momentums.is_quantized(): + raise NotImplementedError( + 'Fused LION step only implemented with quantization.') + + if momentums.is_quantized() and fused: + assert momentums.quantized is not None # pyright + assert momentums.scales is not None # pyright + return lion8b_step_fused(grads=grads, + weights=weights, + momentums=momentums.quantized, + scales=momentums.scales, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + errors=errors) + + momentums_float = momentums.materialize() + new_momentums = lion_step_unfused(grads=grads, + weights=weights, + momentums=momentums_float, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay) + momentums.set_data(new_momentums) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 7230ae6656..8bc6316edf 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -26,7 +26,7 @@ LayerFreezing, 
MonolithicCheckpointSaver, ScheduledGarbageCollector) from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, - DecoupledLionW) + DecoupledLionW, DecoupledLionW_8bit) def build_callback(name: str, kwargs: Dict[str, Any]): @@ -98,6 +98,8 @@ def build_optimizer(model: torch.nn.Module, name: str, return DecoupledClipLion(model.parameters(), **optimizer_config) elif name == 'adalr_lion': return DecoupledAdaLRLion(model.parameters(), **optimizer_config) + elif name == 'decoupled_lionw_8b': + return DecoupledLionW_8bit(model.parameters(), **optimizer_config) else: raise ValueError(f'Not sure how to build optimizer: {name}') diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 09dc5e7e6e..7e8156e34c 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -4,7 +4,7 @@ import contextlib import math import warnings -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Mapping, Optional, Union from composer.utils import dist from omegaconf import DictConfig, ListConfig @@ -116,6 +116,25 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set defaults for mixed initialization fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) + + # no mixed precision needed for weights when they're already 16 bits + master_dtype = model_cfg.get('master_weights_dtype') + small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', + 'amp_bf16') + if fsdp_config and master_dtype in small_dtypes: + reduce_dtype = None + buffer_dtype = None + mixed_precision = fsdp_config.get('mixed_precision') + if isinstance(mixed_precision, Mapping): + reduce_dtype = mixed_precision.get('reduce_dtype') + buffer_dtype = mixed_precision.get('buffer_dtype') + fsdp_config['mixed_precision'] = { + 'param_dtype': None, + 'reduce_dtype': reduce_dtype, + 'buffer_dtype': buffer_dtype, + 'keep_low_precision_grads': True, + } + return init_context diff --git a/scripts/train/train.py b/scripts/train/train.py index 138ed3b8bc..0d9e4e9d10 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -400,6 +400,10 @@ def main(cfg: DictConfig): print_trainable_parameters(model) # should not be 100% else: # standard model model = build_composer_model(model_config, tokenizer) + if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'): + model = model.to(dtype=torch.bfloat16) + elif model_config.get('master_weights_dtype') in ('f16', 'float16'): + model = model.to(dtype=torch.float16) # Log number of parameters n_params = sum(p.numel() for p in model.parameters()) @@ -515,5 +519,6 @@ def main(cfg: DictConfig): yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) cfg = om.merge(yaml_cfg, cli_cfg) + om.resolve(cfg) assert isinstance(cfg, DictConfig) main(cfg) diff --git a/setup.py b/setup.py index f32876a943..631c910051 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ extra_deps['gpu'] = [ 'flash-attn==v1.0.3.post0', + 'mosaicml-turbo>=0.0.2,<0.1', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy', ] diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py new file mode 100644 index 0000000000..2852d99b8b --- /dev/null +++ b/tests/test_lion8b.py @@ -0,0 +1,548 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +import warnings 
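# The pieces above wire the optimizer into configs: build_optimizer now accepts
# the name 'decoupled_lionw_8b', train.py optionally casts master weights to
# bf16/fp16, and process_init_device then relaxes FSDP mixed precision so
# already-16-bit params are not cast again. A sketch of that last transform
# (dict values here are illustrative, not a recommended recipe):
fsdp_config = {'mixed_precision': {'reduce_dtype': 'bf16'}}
# ... after process_init_device() runs with model_cfg.master_weights_dtype == 'bf16':
fsdp_config['mixed_precision'] = {
    'param_dtype': None,                 # params are already stored in bf16
    'reduce_dtype': 'bf16',              # user-provided value is preserved
    'buffer_dtype': None,
    'keep_low_precision_grads': True,
}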
+ +import numpy as np +import packaging.version as version +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import fsdp +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +if version.parse(torch.__version__) >= version.parse('2.0.1'): + from torch.distributed.fsdp.api import ( # type:ignore .api not in public API + FullOptimStateDictConfig, LocalOptimStateDictConfig, + ShardedOptimStateDictConfig) +else: + from unittest.mock import MagicMock # for pyright so vars aren't None + FullOptimStateDictConfig = MagicMock() + LocalOptimStateDictConfig = MagicMock() + ShardedOptimStateDictConfig = MagicMock() + +from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit + +warnings.filterwarnings('ignore') + +_MANY_PARAM_SHAPES = [(1, 1), (1, 2), (17, 23), (64, 32)] +_FLOAT_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + +np.set_printoptions(linewidth=160, formatter={'float': lambda f: f'{f:5.3f}'}) + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, + fused: bool, use_errors: bool) -> None: + device = 'cuda' + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + error_correction=use_errors) + + Y = X @ W + loss = Y.sum() + loss.backward() + torch.testing.assert_close(W_orig, W) # no weight modification yet + opt.step() + opt.zero_grad() + + with pytest.raises(AssertionError): # opt step modified the weights + torch.testing.assert_close(W_orig, W) + + # Every momentum should be nonzero with infinite precision, but + # might be zero after quantization. We turn the _MaybeQuantizedTensor + # instance into a regular torch Tensor to simplify this check. 
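# Aside on materialize(): wrapping a CUDA tensor in _MaybeQuantizedTensor
# quantizes it to int8 plus per-group scales, and materialize() dequantizes it
# back, so values may shift slightly. A sketch (requires CUDA and the turbo
# kernels; the shape is arbitrary):
import torch
from llmfoundry.optim.lion8b import _MaybeQuantizedTensor

x = torch.randn(64, 32, device='cuda')
mx = _MaybeQuantizedTensor(x, try_quantize=True)
assert mx.is_quantized()
x_hat = mx.materialize()                        # approximate reconstruction of x
max_err = (x - x_hat.to(x.dtype)).abs().max()   # small but generally nonzero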
+ param_state = opt.state[W] # type:ignore using tensor as key + momentum = param_state['exp_avg'].materialize() + assert momentum.shape == (D, D) + momentum = momentum.ravel() + if momentum.numel() == 1: + assert momentum.item() != 0 + else: + assert torch.std(momentum).item() > 0 + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) +@pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), + ('cuda', torch.bfloat16), + ('cuda', torch.float16), + ('cuda', torch.float32)]) +@pytest.mark.parametrize('weight_decay', [0, .1]) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_changes_with_zero_grads(N: int, D: int, device: str, + dtype: torch.dtype, weight_decay: float, + fused: bool, use_errors: bool) -> None: + if (device == 'cpu') and (fused or use_errors): + return + + torch.manual_seed(123) + W = torch.rand((D, D), device=device, requires_grad=True) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) + + zeros_grad = torch.zeros_like(W) + for _ in range(5): + W.grad = zeros_grad + opt.step() + opt.zero_grad() + + mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key + assert torch.all(mom.materialize() == 0) + if mom.is_quantized(): + assert torch.all(mom.quantized == 0) + + if weight_decay: + assert torch.all(W_orig.abs() > W.abs()) + else: + torch.testing.assert_close(W_orig, W) # no weight modification + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', [(1, 8), (17, 23), (32, 32)]) +@pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), + ('cuda', torch.bfloat16), + ('cuda', torch.float16), + ('cuda', torch.float32)]) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, + use_errors: bool) -> None: + if (device == 'cpu') and (fused or use_errors): + return + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + + # we use tiny beta1 so we move almost entirely in the gradient direction + opt = Lion8bit([W], + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) + + prev_loss = np.inf + prev_momentum = None + num_iters = 10 if device == 'cuda' else 2 # keep test fast + for _ in range(num_iters): + Y = X @ W + loss = (Y * Y).mean() + loss.backward() + opt.step() + opt.zero_grad() + + loss_val = loss.item() + assert loss_val < prev_loss + prev_loss = loss_val + + # since we're getting the same batch every time and have a small + # learning rate, our gradients should point in the same direction + # at each step. Consequently, our momentum should grow each step. 
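# Quick check of the recursion behind the comment above: with a constant
# gradient g and m_0 = 0, the update m <- m.lerp(g, 1 - beta2) gives
# m_t = (1 - beta2**t) * g, so momentum magnitudes grow toward |g| each step.
import torch

g = torch.full((3,), 2.0)
m = torch.zeros(3)
beta2 = 0.5
for t in range(1, 5):
    m = m.lerp(g, 1 - beta2)
    assert torch.allclose(m, (1 - beta2**t) * g)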
+ state_for_param = opt.state[W] # type:ignore using tensor as key + momentum = state_for_param['exp_avg'].materialize() + assert momentum is not None and momentum.shape == W.shape + if prev_momentum is not None: + momentum_abs_changes = (momentum - prev_momentum).abs() + assert torch.all(momentum_abs_changes >= 0) + assert momentum_abs_changes.max() > 0 + prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place + + +def _nmse(vals_true: torch.Tensor, + vals_hat: torch.Tensor, + norm_how: str = 'l2_sq'): + diffs = vals_true - vals_hat + mse = (diffs * diffs).mean() + if norm_how == 'var': + return mse / vals_true.var() + return mse / (vals_true * vals_true).mean() + + +@pytest.mark.gpu +@pytest.mark.parametrize('w_init', ['cyclic', 'rand']) +@pytest.mark.parametrize('grad_strategy', ['zero', 'ones', 'const', 'rand']) +@pytest.mark.parametrize('D', [4, 12]) # vectorized and unvectorized impls +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, + D: int, + dtype: torch.dtype) -> None: + torch.manual_seed(123) + device = 'cuda' + + # each optimizer gets a different copy of the weight matrix to optimize + if w_init == 'cyclic': + W0 = torch.arange(D * D, + device=device, + requires_grad=False, + dtype=dtype).reshape(D, D) + W0 = ((W0 // 2 % 3) - 1).to(dtype=dtype) + elif w_init == 'rand': + W0 = torch.rand( + size=(D, D), device=device, requires_grad=False, + dtype=dtype) * 2 - 1 + W0 += .01 * torch.sign(W0) # bound away from 0 to cap rel errors + W0 = W0.to(dtype=dtype) + else: # here for pyright + raise ValueError('Unrecognized w_init: ', w_init) + W0.add_(W0.sign()) # bound away from zero so decay won't flip sign + W_true = torch.empty_like(W0, requires_grad=True, + dtype=torch.float32) # ground truth + W_uq = torch.empty_like(W0, requires_grad=True) # unquantized + W_uf = torch.empty_like(W0, requires_grad=True) # unfused + W_fq = torch.empty_like(W0, requires_grad=True) # fused and quantized + W_fqe = torch.empty_like(W0, requires_grad=True) # fused, quantized, ecc + W_sgd = torch.empty_like(W0, requires_grad=True) + with torch.no_grad(): + W_true.copy_(W0.to(W_true.dtype)) + W_uq.copy_(W0) + W_uf.copy_(W0) + W_fq.copy_(W0) + W_fqe.copy_(W0) + W_sgd.copy_(W0) + + # we use a high LR, low betas, and regularization so that there will + # hopefully be differences if *any* of the logic is wrong + lr = .1 + weight_decay = .01 + kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (.5, .75)} + opt_true = Lion8bit([W_true], quantize=False, **kwargs) + opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) + opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) + opt_fq = Lion8bit([W_fq], _fused=True, **kwargs) + opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) + opt_sgd = torch.optim.SGD([W_sgd], lr=lr) + + W_list = [W_true, W_uq, W_uf, W_fq, W_fqe, W_sgd] + opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_fqe, opt_sgd] + + if grad_strategy == 'zero': + grads = torch.zeros_like(W0) + elif grad_strategy == 'ones': + grads = ((torch.arange(W0.numel()) % 2) * 2 - 1).reshape(W0.shape) + elif grad_strategy == 'const': + # arange makes blocks have different distros, so we can't + # get away with bugs like always using the first scale_scale + grads = torch.arange(W0.numel(), + device=device, + requires_grad=False, + dtype=W0.dtype).view(W0.shape) + # next two conditions are just here for pyright + elif grad_strategy == 'rand': + grads = torch.tensor([-1]) + else: + raise ValueError('bad 
grad_strategy: ', grad_strategy) + + for _ in range(4): + if grad_strategy == 'rand': # type:ignore (reportUnnecessaryComparison) + grads = torch.rand(W0.shape, + device=device, + requires_grad=False, + dtype=W0.dtype) + for W, opt in zip(W_list, opt_list): + W.grad = grads.clone().to(dtype=W.dtype, device=W.device) + opt.step() + opt.zero_grad() + + W0_f = W0.float() + diffs_true = (W_true.detach().float() - W0_f).ravel() + diffs_uq = (W_uq.detach().float() - W0_f).ravel() + diffs_uf = (W_uf.detach().float() - W0_f).ravel() + diffs_fq = (W_fq.detach().float() - W0_f).ravel() + diffs_fqe = (W_fqe.detach().float() - W0_f).ravel() + diffs_sgd = (W_sgd.detach().float() - W0_f).ravel() + + # a bunch of made-up numbers; should be tight enough to detect + # regressions, but aren't enough to 100% guarantee correct numerics + if dtype != torch.bfloat16: + min_cossim = .99 + max_nmse = .01 + else: + min_cossim = .98 + max_nmse = .04 + + cossim = torch.cosine_similarity # avoid ugly linewraps + + assert cossim(diffs_true, diffs_uq, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_uq) < max_nmse + + assert cossim(diffs_true, diffs_uf, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_uf) < max_nmse + + # fused and unfused should be almost identical; the only differences + # are intermediate upcasting in the fused impl + assert cossim(diffs_uf, diffs_fq, dim=-1) > min_cossim + assert _nmse(diffs_uf, diffs_fq) < max_nmse + + # fused impl should be close to unfused version with no quantization + # at all; latter is "ground truth" + assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fq) < max_nmse + + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse + + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + + # if sgd weights aren't different than LION weights, we haven't + # changed them enough to meaningfully test the LION logic + if grad_strategy not in ('zero', 'ones'): + assert torch.cosine_similarity( + diffs_true, # type:ignore (reportUnboundVariable) + diffs_sgd, # type:ignore (reportUnboundVariable) + dim=-1) < .99 + assert _nmse( + diffs_true, # type:ignore (reportUnboundVariable) + diffs_sgd # type:ignore (reportUnboundVariable) + ) > .01 + + +@pytest.mark.gpu +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('quantized_state', [False, True]) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('use_errors', [False, True]) +def test_state_dict_save_load(device: str, quantized_state: bool, + dtype: torch.dtype, use_errors: bool): + torch.manual_seed(123) + params = [] + for shape in _MANY_PARAM_SHAPES: + p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) + p.grad = torch.rand_like(p) + params.append(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + if device == 'cpu': + with pytest.raises(NotImplementedError): + opt.step() + return + else: + opt.step() + opt.zero_grad() + + # copy state dict into a new instance + state_dict = opt.state_dict() + opt_new = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + opt_new.load_state_dict(state_dict) + + for p in params: + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert 
list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + if quantized_state: + # Optimizer load_state_dict insists on converting scales to + # dtype of param, which is lossy for bf16 params. + # Ideally we'd require == for everything but it's less complexity + # to just relax the bf16 test + assert torch.all(mom_orig.quantized == mom_new.quantized) + if dtype == torch.bfloat16: + torch.testing.assert_close(mom_orig.scales, + mom_new.scales, + atol=1e-3, + rtol=1e-2) + else: + assert torch.all(mom_orig.scales == mom_new.scales) + + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) + + +class _DummyModule(nn.Module): + + def __init__(self, device: str, dtype: torch.dtype): + super().__init__() + self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + return self.linear1(self.linear0(x)) + + +_FULL_STATE = fsdp.StateDictType.FULL_STATE_DICT +_SHARDED_STATE = fsdp.StateDictType.SHARDED_STATE_DICT +_LOCAL_STATE = fsdp.StateDictType.LOCAL_STATE_DICT + + +# run just this test with: +# python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa +@pytest.mark.gpu +@pytest.mark.world_size(2) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('use_errors', [False, True]) +@pytest.mark.parametrize('state_sharding', + [_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE]) +def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, + state_sharding: fsdp.StateDictType): + device = 'cuda' + if torch.cuda.device_count() < 2: + pytest.skip(f'This test requires 2+ GPUs.') + if version.parse(torch.__version__) < version.parse('2.0.1'): + pytest.skip(f'This test requires torch 2.0.1 or greater.') + + torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp + if not dist.is_initialized(): + dist.init_process_group() + assert dist.get_world_size() >= 2, 'Misconfigured test run!' 
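# For orientation, the full (FULL_STATE_DICT) optimizer state gathered below is
# keyed by parameter paths from _DummyModule, roughly like the following sketch
# (not asserted verbatim by the test):
#
#   {'state': {'linear0.weight': {'exp_avg': <tensor>, 'errors': <uint8, if enabled>},
#              'linear0.bias':   {...},
#              'linear1.weight': {...},
#              'linear1.bias':   {...}},
#    'param_groups': [...]}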
+ + mod = FSDP(_DummyModule(device=device, dtype=dtype)) + + # actual forward pass instead of setting p.grad to avoid FSDP issues + X = torch.rand(size=(5, 4), device=device, dtype=dtype) + Y = mod(X) + Y.sum().backward() + for p in mod.parameters(): + p.grad = torch.rand_like(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + opt.step() + opt.zero_grad() + + def _set_state_dict_type(model: nn.Module): + # for mapping between state dict types and optim state dict types, see: + # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa + state_dict_cfg = { + _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.LocalStateDictConfig(), + }[state_sharding] + optim_cfg = { + _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: ShardedOptimStateDictConfig(), + _LOCAL_STATE: LocalOptimStateDictConfig(), + }[state_sharding] + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, + optim_cfg) + + # load FSDP state dict + _set_state_dict_type(mod) + opt_state_dict = FSDP.optim_state_dict(mod, opt) + + # make a new model and optimizer + mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + _set_state_dict_type(mod_new) + + # load state dict into the new optimizer + opt_state_dict_slice = FSDP.optim_state_dict_to_load( + opt_state_dict, mod_new, opt_new) + opt_new.load_state_dict(opt_state_dict_slice) + + new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) + + orig_state = opt_state_dict['state'] + orig_param_groups = opt_state_dict['param_groups'] + new_state = new_opt_state_dict['state'] + new_param_groups = new_opt_state_dict['param_groups'] + + all_keys = set(orig_state.keys()) | set(new_state.keys()) + assert orig_param_groups == new_param_groups # works since strs, not ptrs + for k in all_keys: # keys are param paths in module as strings + d_orig = orig_state[k] + d_new = new_state[k] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + + assert mom_orig.shape == mom_new.shape + assert mom_orig.dtype == mom_new.dtype + if use_errors: + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype + + if state_sharding != _FULL_STATE: + continue # more detailed checks lean on FSDP impl details + + # momentums may not be bit-for-bit identical because Optimizer upcasts + # to f32 and we convert back to bf16, possibly with different rounding + torch.testing.assert_close(mom_orig, mom_new) + # errors not bit-for-bit identical because scales get upcast too + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', [(32, 32), (256, 256), (1024, 1024), + (4096, 4096), [16384, 16384]]) +def test_fused_as_fast_as_unfused(N: int, + D: int, + min_elems_traversed: int = 1000000): + W = torch.randn((N, D), device='cuda', requires_grad=True) + W.grad = torch.randn((N, D), device='cuda', requires_grad=False) + + num_iters = int(np.ceil(min_elems_traversed / W.grad.numel())) + num_iters = min(100, num_iters) # don't take all day when overhead-bound + + times = {} + kwargs = {'weight_decay': .01} + combos = 
[(True, False), (True, True), (False, False), ('NA', False)] + for fused, use_errors in combos: + if fused == 'NA': + opt = Lion8bit([W], quantize=False, + **kwargs) # type:ignore (reportGeneralTypeIssues) + else: + opt = Lion8bit([W], + _fused=fused, + error_correction=use_errors, + **kwargs) # type:ignore (reportGeneralTypeIssues) + for _ in range(3): + opt.step() # warmup iters + torch.cuda.synchronize() + t_start = time.time() + for _ in range(num_iters): + opt.step() + torch.cuda.synchronize() + t_end = time.time() + dur = (t_end - t_start) / num_iters + if use_errors: + times['ecc'] = dur + else: + times[fused] = dur + + atol = 20e-6 # should always be faster, but avoids rare flakiness + assert times[True] < times[False] + atol + assert times[True] < times['NA'] + atol + assert times['ecc'] < times['NA'] + atol + + print('') + print('time fused (ms): ', times[True] * 1e3) + print('time fused+ecc (ms): ', times['ecc'] * 1e3) + print('time unfused (ms): ', times[False] * 1e3) + print('time unquantized (ms): ', times['NA'] * 1e3)
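# Back-of-the-envelope motivation for the numbers in the docstring: at ~8.5
# bits of optimizer state per parameter (8-bit momentum plus one 16-bit scale
# per 32 values), state memory drops roughly 4x versus fp32 LION momentum.
# Figures below assume a hypothetical 7B-parameter model.
n_params = 7e9
fp32_lion_gb = n_params * 4 / 1e9        # one float32 momentum per parameter
lion8b_gb = n_params * (8.5 / 8) / 1e9   # int8 momentum + shared fp16 scales
print(f'fp32 LION state: {fp32_lion_gb:.0f} GB; 8-bit LION state: {lion8b_gb:.1f} GB')
# -> roughly 28 GB vs ~7.4 GB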