-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Zero copy initialization of models onto training workers for LLMs (#3469
) Co-authored-by: Geoffrey Angus <[email protected]> Co-authored-by: Travis Addair <[email protected]>
- Loading branch information
1 parent
354627a
commit 8696a72
Showing
11 changed files
with
221 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from collections import OrderedDict | ||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
import torch | ||
|
||
# Mapping from NumPy scalar type classes (plus the builtin ``bool``) to the
# corresponding torch dtype. Used when tensors are rebuilt from their NumPy
# representation so the torch dtype can be pinned explicitly.
NUMPY_TO_TORCH_DTYPE = dict(
    [
        (bool, torch.bool),
        (np.bool_, torch.bool),
        (np.uint8, torch.uint8),
        (np.int8, torch.int8),
        (np.int16, torch.int16),
        (np.int32, torch.int32),
        (np.int64, torch.int64),
        (np.float16, torch.float16),
        (np.float32, torch.float32),
        (np.float64, torch.float64),
        (np.complex64, torch.complex64),
        (np.complex128, torch.complex128),
    ]
)
|
||
|
||
def extract_tensors(model: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
    """Remove the tensors from a PyTorch model, convert them to NumPy arrays, and return the stripped model and
    tensors.

    Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
    ray-8be751a6944c # noqa

    :param model: module to strip; it is modified in place (every parameter and buffer attribute is set to None).
    :return: tuple of the stripped model and a list with one entry per submodule (ordered to match
        ``model.named_modules()``), each a dict with "params" and "buffers" OrderedDicts of NumPy arrays.
    """
    tensors = []
    for _, module in model.named_modules():
        # Store the tensors as numpy arrays in Python dictionaries.
        params = OrderedDict()
        buffers = OrderedDict()
        for name, param in module.named_parameters(recurse=False):
            # detach() before clone() so the copy is not tracked by autograd; .cpu() generalizes this to models
            # living on an accelerator device (.numpy() only works on CPU tensors, and .cpu() is a no-op for
            # tensors already on CPU); clone() gives the NumPy array its own memory instead of aliasing the
            # model's storage.
            params[name] = param.detach().cpu().clone().numpy()
        for name, buf in module.named_buffers(recurse=False):
            buffers[name] = buf.detach().cpu().clone().numpy()
        tensors.append({"params": params, "buffers": buffers})

    # Strip all tensors and buffers out of the original model to reduce memory pressure: after this pass the
    # NumPy arrays collected above are the only copies of the weights.
    for _, module in model.named_modules():
        for name in [name for name, _ in module.named_parameters(recurse=False)] + [
            name for name, _ in module.named_buffers(recurse=False)
        ]:
            setattr(module, name, None)

    return model, tensors
|
||
|
||
def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device: torch.device):
    """Restore the tensors that extract_tensors() stripped out of a PyTorch model. This operation is performed in
    place.

    Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
    ray-8be751a6944c # noqa

    :param m: module previously stripped by extract_tensors(); modified in place.
    :param tensors: list of per-submodule dicts with "params" and "buffers" NumPy arrays, ordered to match
        ``m.named_modules()``.
    :param device: device on which the restored tensors are created.
    """
    modules = [module for _, module in m.named_modules()]
    for module, tensor_dict in zip(modules, tensors):
        # There are separate APIs to set parameters and buffers.
        for name, array in tensor_dict["params"].items():
            # BUG FIX: ``array.dtype`` is an ``np.dtype`` instance, but the mapping is keyed by NumPy scalar
            # *classes* (np.float32, ...); the old lookup missed and silently fell back to dtype inference.
            # ``array.dtype.type`` yields the scalar class, so the explicit mapping is actually used.
            torch_dtype = NUMPY_TO_TORCH_DTYPE.get(array.dtype.type)
            module.register_parameter(
                name,
                torch.nn.Parameter(torch.as_tensor(array, device=device, dtype=torch_dtype)),
            )

        for name, array in tensor_dict["buffers"].items():
            torch_dtype = NUMPY_TO_TORCH_DTYPE.get(array.dtype.type)
            module.register_buffer(
                name,
                torch.as_tensor(array, device=device, dtype=torch_dtype),
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.