diff --git a/ludwig/backend/ray.py b/ludwig/backend/ray.py index 0364b335642..910e4efa02c 100644 --- a/ludwig/backend/ray.py +++ b/ludwig/backend/ray.py @@ -502,7 +502,7 @@ def train( # re-register the weights of the model object in the main process self.model = dist_strategy.replace_model_from_serialization(ray.get(model_ref)) - + # ensure module is initialized exactly as it is in the trainer process # so that the state dict can be loaded back into the model correctly. self.model.prepare_for_training() diff --git a/ludwig/distributed/base.py b/ludwig/distributed/base.py index 7f1f8b520f7..cbe1de0fd5b 100644 --- a/ludwig/distributed/base.py +++ b/ludwig/distributed/base.py @@ -181,11 +181,11 @@ def create_checkpoint_handle( from ludwig.utils.checkpoint_utils import MultiNodeCheckpoint return MultiNodeCheckpoint(self, model, optimizer, scheduler) - + @classmethod def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]: return model - + @classmethod def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module: assert isinstance(state, nn.Module) diff --git a/ludwig/distributed/deepspeed.py b/ludwig/distributed/deepspeed.py index 30691a97365..f92577f1753 100644 --- a/ludwig/distributed/deepspeed.py +++ b/ludwig/distributed/deepspeed.py @@ -220,11 +220,11 @@ def get_state_for_inference(self, save_path: str, device: Optional[torch.device] save_path, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True ) return self.model.module.cpu().state_dict() - + @classmethod def extract_model_for_serialization(cls, model: nn.Module) -> Union[nn.Module, Tuple[nn.Module, List[Dict]]]: return extract_tensors(model) - + @classmethod def replace_model_from_serialization(cls, state: Union[nn.Module, Tuple[nn.Module, List[Dict]]]) -> nn.Module: assert isinstance(state, tuple)