Skip to content

Commit

Permalink
Allow direct loading of checkpoints from the Hugging Face Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
abmazitov committed Nov 13, 2024
1 parent 7de66e9 commit b69db89
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
23 changes: 15 additions & 8 deletions src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,24 @@ def forward(
def load_checkpoint(cls, path: Union[str, Path]) -> "PET":

checkpoint = torch.load(path, weights_only=False, map_location="cpu")
hypers = checkpoint["hypers"]
if "checkpoint" in checkpoint:
# This is the case when the checkpoint was saved with the Trainer
state_dict = checkpoint["checkpoint"]["model_state_dict"]
model_hypers = checkpoint["hypers"]["ARCHITECTURAL_HYPERS"]
self_contributions = checkpoint["self_contributions"]
elif "model_state_dict" in checkpoint:
# This is the case when the checkpoint was saved for the
# HuggingFace API
state_dict = checkpoint["model_state_dict"]
model_hypers = checkpoint["model_hypers"]
self_contributions = state_dict.pop("pet.self_contributions").numpy()
else:
raise ValueError("Invalid checkpoint format")
dataset_info = checkpoint["dataset_info"]
model = cls(
model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
)

checkpoint = torch.load(path, weights_only=False)
state_dict = checkpoint["checkpoint"]["model_state_dict"]
model = cls(model_hypers=model_hypers, dataset_info=dataset_info)

ARCHITECTURAL_HYPERS = Hypers(model.hypers)
ARCHITECTURAL_HYPERS = Hypers(model_hypers)
raw_pet = RawPET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
if ARCHITECTURAL_HYPERS.USE_LORA_PEFT:
lora_rank = ARCHITECTURAL_HYPERS.LORA_RANK
Expand All @@ -160,7 +168,6 @@ def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
dtype = next(iter(new_state_dict.values())).dtype
raw_pet.to(dtype).load_state_dict(new_state_dict)

self_contributions = checkpoint["self_contributions"]
wrapper = SelfContributionsWrapper(raw_pet, self_contributions)

model.to(dtype).set_trained_model(wrapper)
Expand Down
5 changes: 4 additions & 1 deletion src/metatrain/experimental/pet/utils/update_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ def update_state_dict(state_dict: Dict) -> Dict:
"""
new_state_dict = {}
for name, value in state_dict.items():
name = name.split("pet_model.")[1]
if "pet_model" in name:
name = name.split("pet_model.")[1]
else:
name = name.replace("pet.model.", "")
new_state_dict[name] = value
return new_state_dict

0 comments on commit b69db89

Please sign in to comment.