[low-bit optim] Update docs on supported platforms and caveats (#971)
update
gau-nernst authored Sep 29, 2024
1 parent 96e8fee commit c0a81f9
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions torchao/prototype/low_bit_optim/README.md
@@ -6,7 +6,7 @@ This folder implements:
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)

The implementation is done fully in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.
The implementation is done fully in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there may be bugs or errors on other platforms.

## Usage
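The Usage section itself is collapsed in this diff; as a rough orientation, here is a minimal sketch of dropping one of these optimizers into a standard training loop. It assumes the `AdamW8bit` class exported by `torchao.prototype.low_bit_optim` and a CUDA device, per the platform note above.

```python
# Minimal sketch (assumed API): use the 8-bit AdamW in a standard training loop.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed import path

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    model(x).mean().backward()
    optim.step()       # optimizer states are kept in 8-bit; the step is fused via torch.compile()
    optim.zero_grad()
```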

@@ -58,7 +58,7 @@ NOTE: lpmm's 4-bit AdamW does not support BF16 weights.

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single-GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single-GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.

```python
import torch
@@ -87,6 +87,17 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
```

`CPUOffloadOptimizer` is not compatible with PyTorch's built-in LR schedulers because it only acts as a wrapper around the actual optimizers (plus extra logic for moving data around). To adjust the LR, you have to update it manually as follows (in fact, the code below works for all PyTorch optimizers too):

```python
lr = ... # compute your desired LR value
for param_group in optim.param_groups:
    if isinstance(param_group["lr"], torch.Tensor):
        param_group["lr"].fill_(lr)
    else:
        param_group["lr"] = lr
```
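For example, a schedule can be computed by hand and applied with the loop above on every iteration; the `warmup_cosine_lr` helper below is purely illustrative and not part of torchao.

```python
import math
import torch

def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr):
    # Illustrative schedule: linear warmup followed by cosine decay to zero.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

total_steps = 10_000
for step in range(total_steps):
    lr = warmup_cosine_lr(step, total_steps, warmup_steps=500, base_lr=1e-3)
    for param_group in optim.param_groups:   # optim is the CPUOffloadOptimizer created above
        if isinstance(param_group["lr"], torch.Tensor):
            param_group["lr"].fill_(lr)
        else:
            param_group["lr"] = lr
    # ... forward, backward, optim.step(), optim.zero_grad() as usual
```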

NOTE:
- Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try wrapping their optimizer step in `torch.compile()` (see the sketch after this list).
- To minimize the amount of CPU<->GPU data transfer, we keep a copy of the parameters and pre-allocate gradient memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
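A rough sketch of the first note, shown on a stand-alone CPU optimizer for simplicity; none of the names below are torchao APIs, and the speed-up you get will depend on the optimizer and model size.

```python
import torch

# Illustrative only: wrap the step of a CPU optimizer that has no fused implementation
# in torch.compile(), so the per-parameter update loop is compiled instead of run eagerly.
model = torch.nn.Linear(256, 256)                      # parameters live on CPU here
cpu_optim = torch.optim.Adagrad(model.parameters(), lr=1e-2)

@torch.compile(fullgraph=False)
def compiled_step():
    cpu_optim.step()

model(torch.randn(4, 256)).sum().backward()
compiled_step()        # first call compiles; subsequent calls reuse the compiled step
cpu_optim.zero_grad()
```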
