Skip to content

Commit

Permalink
Map location when using torch.load for checkpoints (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Nov 11, 2024
1 parent a69d9ab commit f6d4b64
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def forward(
def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]

Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]
epoch = checkpoint["epoch"]
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def forward(
@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "PET":

checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
hypers = checkpoint["hypers"]
dataset_info = checkpoint["dataset_info"]
model = cls(
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/pet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
# together with the hypers inside a file that will act as a metatrain
# checkpoint
checkpoint_path = self.pet_dir / "checkpoint" # type: ignore
checkpoint = torch.load(checkpoint_path, weights_only=False)
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu")
torch.save(
{
"checkpoint": checkpoint,
Expand All @@ -749,7 +749,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
# This function loads a metatrain PET checkpoint and returns a Trainer
# instance with the hypers, while also saving the checkpoint in the
# class
checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
trainer = cls(train_hypers)
trainer.pet_checkpoint = checkpoint["checkpoint"]
return trainer
2 changes: 1 addition & 1 deletion src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def forward(
def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]

Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False)
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_hypers = checkpoint["model_hypers"]
model_state_dict = checkpoint["model_state_dict"]
epoch = checkpoint["epoch"]
Expand Down

0 comments on commit f6d4b64

Please sign in to comment.