Skip to content

Commit

Permalink
Fix scale detection
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed May 11, 2024
1 parent 590d9d7 commit efa4585
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __init__(self) -> None:
"feats.0.weight",
"feats.1.lk.conv.weight",
"feats.1.refine.weight",
)
and KeyCondition.has_any(
).has_any(
"feats.1.channe_mixer.0.weight",
"feats.1.channel_mixer.0.weight",
),
Expand All @@ -55,8 +54,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
dim = state_dict["feats.0.weight"].shape[0]

total_feat_layers = get_seq_len(state_dict, "feats")

upscale_key = list(state_dict.keys())[-2]
upscale_key = f"feats.{total_feat_layers - 1}.weight"
if upscale_key in state_dict:
scale_shape = state_dict[upscale_key].shape[0]
scale = math.isqrt(scale_shape // 3)
Expand Down

0 comments on commit efa4585

Please sign in to comment.