Skip to content

Commit

Permalink
Bugfix: precise device in torch.load
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Rieutord committed Dec 5, 2023
1 parent 56cf264 commit 1e49f4f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmt/inference/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def load_old_pytorch_model(xp_name, lc_in="esawc", lc_out="esgp"):

return model

def load_pytorch_model(xp_name, lc_in="esawc", lc_out="esgp", train_mode = False):
def load_pytorch_model(xp_name, lc_in="esawc", lc_out="esgp", train_mode = False, device = "cpu"):
"""Return the pre-trained Pytorch model from the experiment `xp_name`"""

if os.path.isabs(xp_name):
Expand All @@ -223,7 +223,7 @@ def load_pytorch_model(xp_name, lc_in="esawc", lc_out="esgp", train_mode = False

assert os.path.isfile(checkpoint_path), f"No checkpoint found at {checkpoint_path}"#

checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)
config = utilconf.get_config(config_path)

if config.model.type == "transformer_embedding":
Expand Down
2 changes: 1 addition & 1 deletion mmt/inference/translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __init__(
self.esawc.n_labels + 1, device=self.device
)
self.encoder_decoder = io.load_pytorch_model(
checkpoint_path, lc_in="esawc", lc_out="esgp"
checkpoint_path, lc_in="esawc", lc_out="esgp", device = device
)
self.encoder_decoder.to(self.device)
self.landcover = self.esawc
Expand Down

0 comments on commit 1e49f4f

Please sign in to comment.