diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 4e985dc7..0542fcc0 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -149,7 +149,7 @@ def forward( def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel": # Load the checkpoint - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") model_hypers = checkpoint["model_hypers"] model_state_dict = checkpoint["model_state_dict"] diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 3ed190c0..4ea4e02d 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -395,7 +395,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer": # Load the checkpoint - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") model_hypers = checkpoint["model_hypers"] model_state_dict = checkpoint["model_state_dict"] epoch = checkpoint["epoch"] diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index d0d21d73..deda99c0 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -137,7 +137,7 @@ def forward( @classmethod def load_checkpoint(cls, path: Union[str, Path]) -> "PET": - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") hypers = checkpoint["hypers"] dataset_info = checkpoint["dataset_info"] model = cls( diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py index 63c06afd..f089544b 100644 --- a/src/metatrain/experimental/pet/trainer.py +++ b/src/metatrain/experimental/pet/trainer.py @@ -662,7 +662,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): # together with the hypers inside a file that will act as a metatrain # checkpoint checkpoint_path = self.pet_dir / "checkpoint" # type: ignore - checkpoint = torch.load(checkpoint_path, weights_only=False) + checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu") torch.save( { "checkpoint": checkpoint, @@ -680,7 +680,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer": # This function loads a metatrain PET checkpoint and returns a Trainer # instance with the hypers, while also saving the checkpoint in the # class - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") trainer = cls(train_hypers) trainer.pet_checkpoint = checkpoint["checkpoint"] return trainer diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 556f3ef5..b726f738 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -299,7 +299,7 @@ def forward( def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn": # Load the checkpoint - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") model_hypers = checkpoint["model_hypers"] model_state_dict = checkpoint["model_state_dict"] diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index aed858bc..f4d74b81 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -432,7 +432,7 @@ def save_checkpoint(self, model, path: Union[str, Path]): def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer": # Load the checkpoint - checkpoint = torch.load(path, weights_only=False) + checkpoint = torch.load(path, weights_only=False, map_location="cpu") model_hypers = checkpoint["model_hypers"] model_state_dict = checkpoint["model_state_dict"] epoch = checkpoint["epoch"]