diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 4bd63ece..0a451823 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -139,16 +139,24 @@ def forward( def load_checkpoint(cls, path: Union[str, Path]) -> "PET": checkpoint = torch.load(path, weights_only=False, map_location="cpu") - hypers = checkpoint["hypers"] + if "checkpoint" in checkpoint: + # This is the case when the checkpoint was saved with the Trainer + state_dict = checkpoint["checkpoint"]["model_state_dict"] + model_hypers = checkpoint["hypers"]["ARCHITECTURAL_HYPERS"] + self_contributions = checkpoint["self_contributions"] + elif "model_state_dict" in checkpoint: + # This is the case when the checkpoint was saved for the + # HuggingFace API + state_dict = checkpoint["model_state_dict"] + model_hypers = checkpoint["model_hypers"] + self_contributions = state_dict.pop("pet.self_contributions").numpy() + else: + raise ValueError("Invalid checkpoint format") dataset_info = checkpoint["dataset_info"] - model = cls( - model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info - ) - checkpoint = torch.load(path, weights_only=False) - state_dict = checkpoint["checkpoint"]["model_state_dict"] + model = cls(model_hypers=model_hypers, dataset_info=dataset_info) - ARCHITECTURAL_HYPERS = Hypers(model.hypers) + ARCHITECTURAL_HYPERS = Hypers(model_hypers) raw_pet = RawPET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) if ARCHITECTURAL_HYPERS.USE_LORA_PEFT: lora_rank = ARCHITECTURAL_HYPERS.LORA_RANK @@ -160,7 +168,6 @@ def load_checkpoint(cls, path: Union[str, Path]) -> "PET": dtype = next(iter(new_state_dict.values())).dtype raw_pet.to(dtype).load_state_dict(new_state_dict) - self_contributions = checkpoint["self_contributions"] wrapper = SelfContributionsWrapper(raw_pet, self_contributions) model.to(dtype).set_trained_model(wrapper) diff --git a/src/metatrain/experimental/pet/utils/update_state_dict.py b/src/metatrain/experimental/pet/utils/update_state_dict.py index 3f3039e9..677207cc 100644 --- a/src/metatrain/experimental/pet/utils/update_state_dict.py +++ b/src/metatrain/experimental/pet/utils/update_state_dict.py @@ -7,6 +7,9 @@ def update_state_dict(state_dict: Dict) -> Dict: """ new_state_dict = {} for name, value in state_dict.items(): - name = name.split("pet_model.")[1] + if "pet_model" in name: + name = name.split("pet_model.")[1] + else: + name = name.replace("pet.model.", "") new_state_dict[name] = value return new_state_dict