diff --git a/src/spandrel/architectures/Compact/__init__.py b/src/spandrel/architectures/Compact/__init__.py index 9fe634d3..a1caf114 100644 --- a/src/spandrel/architectures/Compact/__init__.py +++ b/src/spandrel/architectures/Compact/__init__.py @@ -1,12 +1,12 @@ from ...__helpers.model_descriptor import SRModelDescriptor, StateDict -from ..__arch_helpers.state import get_max_seq_index, get_scale_and_output_channels +from ..__arch_helpers.state import get_scale_and_output_channels, get_seq_len from .arch.SRVGG import SRVGGNetCompact def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]: state = state_dict - highest_num = get_max_seq_index(state, "body.{}.weight") + highest_num = get_seq_len(state, "body") - 1 in_nc = state["body.0.weight"].shape[1] num_feat = state["body.0.weight"].shape[0] diff --git a/src/spandrel/architectures/FBCNN/__init__.py b/src/spandrel/architectures/FBCNN/__init__.py index 1ed53276..7ade5724 100644 --- a/src/spandrel/architectures/FBCNN/__init__.py +++ b/src/spandrel/architectures/FBCNN/__init__.py @@ -2,6 +2,7 @@ RestorationModelDescriptor, StateDict, ) +from ..__arch_helpers.state import get_seq_len from .arch.FBCNN import FBCNN @@ -17,10 +18,7 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[FBCNN]: in_nc = state_dict["m_head.weight"].shape[1] out_nc = state_dict["m_tail.weight"].shape[0] - for i in range(0, 20): - if f"m_down1.{i}.weight" in state_dict: - nb = i - break + nb = get_seq_len(state_dict, "m_body_encoder") nc[0] = state_dict["m_head.weight"].shape[0] nc[1] = state_dict[f"m_down1.{nb}.weight"].shape[0] diff --git a/src/spandrel/architectures/KBNet/__init__.py b/src/spandrel/architectures/KBNet/__init__.py index 61e0cc01..5f01082f 100644 --- a/src/spandrel/architectures/KBNet/__init__.py +++ b/src/spandrel/architectures/KBNet/__init__.py @@ -5,7 +5,7 @@ SizeRequirements, StateDict, ) -from ..__arch_helpers.state import get_max_seq_index +from ..__arch_helpers.state import get_seq_len from .arch.kbnet_l import KBNet_l from .arch.kbnet_s import KBNet_s @@ -27,14 +27,12 @@ def load_l(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_l]: dim = state_dict["patch_embed.proj.weight"].shape[0] - num_blocks[0] = get_max_seq_index(state_dict, "encoder_level1.{}.norm1.weight") + 1 - num_blocks[1] = get_max_seq_index(state_dict, "encoder_level2.{}.norm1.weight") + 1 - num_blocks[2] = get_max_seq_index(state_dict, "encoder_level3.{}.norm1.weight") + 1 - num_blocks[3] = get_max_seq_index(state_dict, "latent.{}.norm1.weight") + 1 + num_blocks[0] = get_seq_len(state_dict, "encoder_level1") + num_blocks[1] = get_seq_len(state_dict, "encoder_level2") + num_blocks[2] = get_seq_len(state_dict, "encoder_level3") + num_blocks[3] = get_seq_len(state_dict, "latent") - num_refinement_blocks = ( - get_max_seq_index(state_dict, "refinement.{}.norm1.weight") + 1 - ) + num_refinement_blocks = get_seq_len(state_dict, "refinement") heads[0] = state_dict["encoder_level1.0.ffn.temperature"].shape[0] heads[1] = state_dict["encoder_level2.0.ffn.temperature"].shape[0] @@ -87,21 +85,17 @@ def load_s(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_s]: img_channel = state_dict["intro.weight"].shape[1] width = state_dict["intro.weight"].shape[0] - middle_blk_num = get_max_seq_index(state_dict, "middle_blks.{}.w") + 1 + middle_blk_num = get_seq_len(state_dict, "middle_blks") - enc_count = get_max_seq_index(state_dict, "encoders.{}.0.w") + 1 + enc_count = get_seq_len(state_dict, "encoders") enc_blk_nums = [1] * enc_count for i in range(enc_count): - enc_blk_nums[i] = ( - get_max_seq_index(state_dict, "encoders." + str(i) + ".{}.w") + 1 - ) + enc_blk_nums[i] = get_seq_len(state_dict, "encoders." + str(i)) - dec_count = get_max_seq_index(state_dict, "decoders.{}.0.w") + 1 + dec_count = get_seq_len(state_dict, "decoders") dec_blk_nums = [1] * dec_count for i in range(dec_count): - dec_blk_nums[i] = ( - get_max_seq_index(state_dict, "decoders." + str(i) + ".{}.w") + 1 - ) + dec_blk_nums[i] = get_seq_len(state_dict, "decoders." + str(i)) # in code: ffn_ch = int(c * ffn_scale) temp_c = state_dict["middle_blks.0.conv4.weight"].shape[1]