diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py index 1cea7d4b..5b7ea009 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py @@ -7,12 +7,7 @@ from spandrel.util import KeyCondition, get_seq_len -from ...__helpers.model_descriptor import ( - Architecture, - ImageModelDescriptor, - SizeRequirements, - StateDict, -) +from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict from .arch.PLKSR import PLKSR from .arch.RealPLKSR import RealPLKSR @@ -24,11 +19,9 @@ def __init__(self) -> None: super().__init__( id="PLKSR", detect=KeyCondition.has_all( - KeyCondition.has_all( - "feats.0.weight", - "feats.1.lk.conv.weight", - "feats.1.refine.weight", - ), + "feats.0.weight", + "feats.1.lk.conv.weight", + "feats.1.refine.weight", KeyCondition.has_any( "feats.1.channe_mixer.0.weight", "feats.1.channel_mixer.0.weight", @@ -49,49 +42,33 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: norm_groups = 4 # un-detectable dropout = 0 # un-detectable - tags: list[str] = [] - - size_requirements = SizeRequirements(minimum=1) - - if "feats.0.weight" in state_dict: - dim = state_dict["feats.0.weight"].shape[0] + dim = state_dict["feats.0.weight"].shape[0] total_feat_layers = get_seq_len(state_dict, "feats") - upscale_key = f"feats.{total_feat_layers - 1}.weight" - if upscale_key in state_dict: - scale_shape = state_dict[upscale_key].shape[0] - scale = math.isqrt(scale_shape // 3) + scale = math.isqrt( + state_dict[f"feats.{total_feat_layers - 1}.weight"].shape[0] // 3 + ) - if "feats.1.lk.conv.weight" in state_dict: - kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] - after_split = state_dict["feats.1.lk.conv.weight"].shape[0] - split_ratio = after_split / dim + kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] + split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim use_ea = "feats.1.attn.f.0.weight" in state_dict - tags.append(f"{dim}dim") - tags.append(f"{n_blocks}nb") - tags.append(f"{kernel_size}ks") - # Yes, the normal version has this typo. if "feats.1.channe_mixer.0.weight" in state_dict: n_blocks = total_feat_layers - 2 - ccm_type = "CCM" - if "feats.1.channe_mixer.2.weight" in state_dict: - mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2] - mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2] - - if mixer_0_shape == 3 and mixer_2_shape == 1: - ccm_type = "CCM" - tags.append("CCM") - elif mixer_0_shape == 3 and mixer_2_shape == 3: - ccm_type = "DCCM" - tags.append("DCCM") - elif mixer_0_shape == 1 and mixer_2_shape == 3: - ccm_type = "ICCM" - tags.append("ICCM") - else: - raise ValueError("Unknown CCM type") + + mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2] + mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2] + if mixer_0_shape == 3 and mixer_2_shape == 1: + ccm_type = "CCM" + elif mixer_0_shape == 3 and mixer_2_shape == 3: + ccm_type = "DCCM" + elif mixer_0_shape == 1 and mixer_2_shape == 3: + ccm_type = "ICCM" + else: + raise ValueError("Unknown CCM type") + more_tags = [ccm_type] model = PLKSR( dim=dim, @@ -104,7 +81,8 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: ) # and RealPLKSR doesn't. This makes it really convenient to detect. elif "feats.1.channel_mixer.0.weight" in state_dict: - tags.append("Real") + more_tags = ["Real"] + n_blocks = total_feat_layers - 3 model = RealPLKSR( dim=dim, @@ -124,11 +102,10 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: state_dict, architecture=self, purpose="Restoration" if scale == 1 else "SR", - tags=tags, + tags=[f"{dim}dim", f"{n_blocks}nb", f"{kernel_size}ks", *more_tags], supports_half=False, supports_bfloat16=True, scale=scale, input_channels=3, output_channels=3, - size_requirements=size_requirements, ) diff --git a/tests/__snapshots__/test_PLKSR.ambr b/tests/__snapshots__/test_PLKSR.ambr index bc2e315f..89494faa 100644 --- a/tests/__snapshots__/test_PLKSR.ambr +++ b/tests/__snapshots__/test_PLKSR.ambr @@ -9,12 +9,12 @@ output_channels=3, purpose='SR', scale=4, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([ '64dim', - '28nb', + '12nb', '13ks', 'DCCM', ]), @@ -31,7 +31,7 @@ output_channels=3, purpose='SR', scale=2, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -53,7 +53,7 @@ output_channels=3, purpose='SR', scale=3, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -75,7 +75,7 @@ output_channels=3, purpose='SR', scale=4, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -97,7 +97,7 @@ output_channels=3, purpose='SR', scale=2, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -119,7 +119,7 @@ output_channels=3, purpose='SR', scale=4, - size_requirements=SizeRequirements(minimum=1, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), supports_bfloat16=True, supports_half=False, tags=list([