diff --git a/ludwig/utils/model_utils.py b/ludwig/utils/model_utils.py
index 0a9e9cfa438..e7661cf812f 100644
--- a/ludwig/utils/model_utils.py
+++ b/ludwig/utils/model_utils.py
@@ -4,23 +4,31 @@
 import torch
 
 
-def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
+def extract_tensors(model: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
     """Remove the tensors from a PyTorch model, convert them to NumPy arrays, and return the stripped model and
     tensors.
 
     Reference implementation: https://medium.com/ibm-data-ai/how-to-load-pytorch-models-340-times-faster-with-
     ray-8be751a6944c # noqa
     """
+
     tensors = []
-    for _, module in m.named_modules():
-        # Store the tensors in Python dictionaries
-        params = {name: torch.clone(param).detach().numpy() for name, param in module.named_parameters(recurse=False)}
-        buffers = {name: torch.clone(buf).detach().numpy() for name, buf in module.named_buffers(recurse=False)}
+    for _, module in model.named_modules():
+        # Store the tensors as numpy arrays in Python dictionaries
+        # Move the same tensors to a meta device since we no longer need them and we want to reduce memory pressure.
+        # This ensures that throughout this process, we keep memory nearly linear w.r.t model parameters.
+        params = {}
+        buffers = {}
+        for name, param in module.named_parameters(recurse=False):
+            params[name] = torch.clone(param).detach().numpy()
+            del param
+        for name, buf in module.named_buffers(recurse=False):
+            buffers[name] = torch.clone(buf).detach().numpy()
+            del buf
         tensors.append({"params": params, "buffers": buffers})
 
-    # Make a copy of the original model and strip all tensors and
-    # buffers out of the copy.
-    m_copy = copy.deepcopy(m)
+    # Make a copy of the original model and strip all tensors and buffers out of the copy.
+    m_copy = copy.deepcopy(model)
     for _, module in m_copy.named_modules():
         for name in [name for name, _ in module.named_parameters(recurse=False)] + [
             name for name, _ in module.named_buffers(recurse=False)
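
For orientation, a minimal usage sketch of the refactored helper follows (not part of the diff). It assumes `extract_tensors` is imported from the `ludwig/utils/model_utils.py` module shown above; the `replace_tensors` counterpart it mentions is an assumption based on the referenced Medium article and is not shown in this hunk.

    from ludwig.utils.model_utils import extract_tensors
    import torch

    # Any torch.nn.Module works; a tiny stand-in model keeps the sketch self-contained.
    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

    # `skeleton` is a deep copy with all parameters and buffers stripped out;
    # `tensors` is a list of {"params": ..., "buffers": ...} dicts of NumPy arrays,
    # one entry per submodule, in named_modules() order.
    skeleton, tensors = extract_tensors(model)

    # The skeleton and the NumPy arrays can now be serialized cheaply (e.g. placed in
    # Ray's object store) and re-attached on the receiving side by a counterpart such
    # as replace_tensors(skeleton, tensors) (assumed here; not shown in this hunk).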