Skip to content

Commit

Permalink
Add method to load model from already loaded state dict
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Nov 18, 2023
1 parent f3e1abd commit fde9a38
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/spandrel/__helpers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def load_from_file(self, path: str | Path) -> ModelDescriptor:
"""

state_dict = self.load_state_dict_from_file(path)
return self.registry.load(state_dict).to(self.device)
return self.load_from_state_dict(state_dict)

def load_state_dict_from_file(self, path: str | Path) -> StateDict:
"""
Expand All @@ -65,6 +65,15 @@ def load_state_dict_from_file(self, path: str | Path) -> StateDict:
f"Unsupported model file extension {extension}. Please try a supported model type."
)

def load_from_state_dict(self, state_dict: StateDict) -> ModelDescriptor:
"""
Load a model from the given state dict.
Throws an `UnsupportedModelError` if the model architecture is not supported.
"""

return self.registry.load(state_dict).to(self.device)

def _load_pth(self, path: str | Path) -> StateDict:
return torch.load(
path,
Expand Down

0 comments on commit fde9a38

Please sign in to comment.