Skip to content

Commit

Permalink
Simplified PLKSR (#255)
Browse files (browse the repository at this point in the history)
* Simplified PLKSR detection code

* Fixed tags
  • Loading branch information
RunDevelopment authored May 11, 2024
1 parent 7477a51 commit 7f34701
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 55 deletions.
73 changes: 25 additions & 48 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,12 +7,7 @@

from spandrel.util import KeyCondition, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
ImageModelDescriptor,
SizeRequirements,
StateDict,
)
from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
from .arch.PLKSR import PLKSR
from .arch.RealPLKSR import RealPLKSR

Expand All @@ -24,11 +19,9 @@ def __init__(self) -> None:
super().__init__(
id="PLKSR",
detect=KeyCondition.has_all(
KeyCondition.has_all(
"feats.0.weight",
"feats.1.lk.conv.weight",
"feats.1.refine.weight",
),
"feats.0.weight",
"feats.1.lk.conv.weight",
"feats.1.refine.weight",
KeyCondition.has_any(
"feats.1.channe_mixer.0.weight",
"feats.1.channel_mixer.0.weight",
Expand All @@ -49,49 +42,33 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
norm_groups = 4 # un-detectable
dropout = 0 # un-detectable

tags: list[str] = []

size_requirements = SizeRequirements(minimum=1)

if "feats.0.weight" in state_dict:
dim = state_dict["feats.0.weight"].shape[0]
dim = state_dict["feats.0.weight"].shape[0]

total_feat_layers = get_seq_len(state_dict, "feats")
upscale_key = f"feats.{total_feat_layers - 1}.weight"
if upscale_key in state_dict:
scale_shape = state_dict[upscale_key].shape[0]
scale = math.isqrt(scale_shape // 3)
scale = math.isqrt(
state_dict[f"feats.{total_feat_layers - 1}.weight"].shape[0] // 3
)

if "feats.1.lk.conv.weight" in state_dict:
kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
after_split = state_dict["feats.1.lk.conv.weight"].shape[0]
split_ratio = after_split / dim
kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim

use_ea = "feats.1.attn.f.0.weight" in state_dict

tags.append(f"{dim}dim")
tags.append(f"{n_blocks}nb")
tags.append(f"{kernel_size}ks")

# Yes, the normal version has this typo.
if "feats.1.channe_mixer.0.weight" in state_dict:
n_blocks = total_feat_layers - 2
ccm_type = "CCM"
if "feats.1.channe_mixer.2.weight" in state_dict:
mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2]
mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2]

if mixer_0_shape == 3 and mixer_2_shape == 1:
ccm_type = "CCM"
tags.append("CCM")
elif mixer_0_shape == 3 and mixer_2_shape == 3:
ccm_type = "DCCM"
tags.append("DCCM")
elif mixer_0_shape == 1 and mixer_2_shape == 3:
ccm_type = "ICCM"
tags.append("ICCM")
else:
raise ValueError("Unknown CCM type")

mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2]
mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2]
if mixer_0_shape == 3 and mixer_2_shape == 1:
ccm_type = "CCM"
elif mixer_0_shape == 3 and mixer_2_shape == 3:
ccm_type = "DCCM"
elif mixer_0_shape == 1 and mixer_2_shape == 3:
ccm_type = "ICCM"
else:
raise ValueError("Unknown CCM type")
more_tags = [ccm_type]

model = PLKSR(
dim=dim,
Expand All @@ -104,7 +81,8 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
)
# and RealPLKSR doesn't. This makes it really convenient to detect.
elif "feats.1.channel_mixer.0.weight" in state_dict:
tags.append("Real")
more_tags = ["Real"]

n_blocks = total_feat_layers - 3
model = RealPLKSR(
dim=dim,
Expand All @@ -124,11 +102,10 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
state_dict,
architecture=self,
purpose="Restoration" if scale == 1 else "SR",
tags=tags,
tags=[f"{dim}dim", f"{n_blocks}nb", f"{kernel_size}ks", *more_tags],
supports_half=False,
supports_bfloat16=True,
scale=scale,
input_channels=3,
output_channels=3,
size_requirements=size_requirements,
)
14 changes: 7 additions & 7 deletions tests/__snapshots__/test_PLKSR.ambr
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,12 +9,12 @@
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'64dim',
'28nb',
'12nb',
'13ks',
'DCCM',
]),
Expand All @@ -31,7 +31,7 @@
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -53,7 +53,7 @@
output_channels=3,
purpose='SR',
scale=3,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -75,7 +75,7 @@
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -97,7 +97,7 @@
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -119,7 +119,7 @@
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False),
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down

0 comments on commit 7f34701

Please sign in to comment.