Commit e4284a0: Reduce CPU memory usage during extract
arnavgarg1 committed Jul 14, 2023
Parent: 5fe168a
Showing 1 changed file with 16 additions and 8 deletions:

ludwig/utils/model_utils.py
@@ -4,23 +4,31 @@
 import torch


-def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
+def extract_tensors(model: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
     """Remove the tensors from a PyTorch model, convert them to NumPy arrays, and return the stripped model and
     tensors.
     Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
     ray-8be751a6944c # noqa
     """

     tensors = []
-    for _, module in m.named_modules():
-        # Store the tensors in Python dictionaries
-        params = {name: torch.clone(param).detach().numpy() for name, param in module.named_parameters(recurse=False)}
-        buffers = {name: torch.clone(buf).detach().numpy() for name, buf in module.named_buffers(recurse=False)}
+    for _, module in model.named_modules():
+        # Store the tensors as numpy arrays in Python dictionaries.
+        # Drop each tensor reference as soon as it is copied, since we no longer need it and we want to reduce
+        # memory pressure. This ensures that throughout this process, memory stays nearly linear w.r.t. the
+        # model parameters.
+        params = {}
+        buffers = {}
+        for name, param in module.named_parameters(recurse=False):
+            params[name] = torch.clone(param).detach().numpy()
+            del param
+        for name, buf in module.named_buffers(recurse=False):
+            buffers[name] = torch.clone(buf).detach().numpy()
+            del buf
         tensors.append({"params": params, "buffers": buffers})

-    # Make a copy of the original model and strip all tensors and
-    # buffers out of the copy.
-    m_copy = copy.deepcopy(m)
+    # Make a copy of the original model and strip all tensors and buffers out of the copy.
+    m_copy = copy.deepcopy(model)
     for _, module in m_copy.named_modules():
         for name in [name for name, _ in module.named_parameters(recurse=False)] + [
             name for name, _ in module.named_buffers(recurse=False)
(remaining lines of the hunk collapsed in this view)
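For context, the Medium post referenced in the docstring pairs extract_tensors with a companion function that loads the saved arrays back into the stripped copy. Below is a minimal sketch of that round trip, following the article's replace_tensors helper; it is not part of this commit, and the Linear model is only a stand-in for illustration.

from typing import Dict, List

import torch

from ludwig.utils.model_utils import extract_tensors


def replace_tensors(m: torch.nn.Module, tensors: List[Dict]) -> None:
    # Hypothetical helper based on the referenced article, not part of this diff:
    # load the extracted NumPy arrays back into a stripped model. Modules are
    # walked in the same order extract_tensors uses, so zip() pairs each module
    # with its saved params and buffers.
    modules = [module for _, module in m.named_modules()]
    for module, tensor_dict in zip(modules, tensors):
        # Parameters and buffers have separate registration APIs.
        for name, array in tensor_dict["params"].items():
            module.register_parameter(name, torch.nn.Parameter(torch.as_tensor(array)))
        for name, array in tensor_dict["buffers"].items():
            module.register_buffer(name, torch.as_tensor(array))


model = torch.nn.Linear(4, 2)  # stand-in model for illustration
stripped, tensors = extract_tensors(model)
replace_tensors(stripped, tensors)  # stripped now has its weights back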
