Skip to content

Commit

Permalink
Merge branch 'zero_copy_load' of https://github.com/ludwig-ai/ludwig
Browse files Browse the repository at this point in the history
…into zero_copy_load
  • Loading branch information
tgaddair committed Jul 21, 2023
2 parents e6987e9 + 3ddc403 commit 6fff831
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ludwig/backend/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def train(

# re-register the weights of the model object in the main process
self.model = dist_strategy.replace_model_from_serialization(ray.get(model_ref))

# ensure module is initialized exactly as it is in the trainer process
# so that the state dict can be loaded back into the model correctly.
self.model.prepare_for_training()
Expand Down
4 changes: 2 additions & 2 deletions ludwig/distributed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def create_checkpoint_handle(
from ludwig.utils.checkpoint_utils import MultiNodeCheckpoint

return MultiNodeCheckpoint(self, model, optimizer, scheduler)

@classmethod
def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]:
    """Return the object that should be serialized in place of ``model``.

    This implementation performs no transformation: the module passed in is
    handed back as-is (the very same object, not a copy).

    Args:
        model: The module to prepare for serialization.

    Returns:
        ``model`` itself, unchanged.
    """
    # NOTE(review): presumably paired with ``replace_model_from_serialization``,
    # which asserts it receives an ``nn.Module`` -- confirm against callers.
    return model

@classmethod
def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module:
assert isinstance(state, nn.Module)
Expand Down
4 changes: 2 additions & 2 deletions ludwig/distributed/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def get_state_for_inference(self, save_path: str, device: Optional[torch.device]
save_path, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True
)
return self.model.module.cpu().state_dict()

@classmethod
def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]:
    """Return the serializable form of ``model`` produced by ``extract_tensors``.

    Args:
        model: The module to prepare for serialization.

    Returns:
        Whatever ``extract_tensors(model)`` returns, passed through untouched.
    """
    # NOTE(review): presumably a (module, tensors) tuple, since the matching
    # ``replace_model_from_serialization`` asserts a tuple -- confirm against
    # the definition of ``extract_tensors``.
    serializable_state = extract_tensors(model)
    return serializable_state

@classmethod
def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module:
assert isinstance(state, tuple)
Expand Down

0 comments on commit 6fff831

Please sign in to comment.