Skip to content

Commit

Permalink
Simplified detection code for sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Nov 21, 2023
1 parent d53db13 commit cb12089
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ...__helpers.model_descriptor import SRModelDescriptor, StateDict
from ..__arch_helpers.state import get_max_seq_index, get_scale_and_output_channels
from ..__arch_helpers.state import get_scale_and_output_channels, get_seq_len
from .arch.SRVGG import SRVGGNetCompact


def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]:
state = state_dict

highest_num = get_max_seq_index(state, "body.{}.weight")
highest_num = get_seq_len(state, "body") - 1

in_nc = state["body.0.weight"].shape[1]
num_feat = state["body.0.weight"].shape[0]
Expand Down
6 changes: 2 additions & 4 deletions src/spandrel/architectures/FBCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
RestorationModelDescriptor,
StateDict,
)
from ..__arch_helpers.state import get_seq_len
from .arch.FBCNN import FBCNN


Expand All @@ -17,10 +18,7 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[FBCNN]:
in_nc = state_dict["m_head.weight"].shape[1]
out_nc = state_dict["m_tail.weight"].shape[0]

for i in range(0, 20):
if f"m_down1.{i}.weight" in state_dict:
nb = i
break
nb = get_seq_len(state_dict, "m_body_encoder")

nc[0] = state_dict["m_head.weight"].shape[0]
nc[1] = state_dict[f"m_down1.{nb}.weight"].shape[0]
Expand Down
28 changes: 11 additions & 17 deletions src/spandrel/architectures/KBNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
SizeRequirements,
StateDict,
)
from ..__arch_helpers.state import get_max_seq_index
from ..__arch_helpers.state import get_seq_len
from .arch.kbnet_l import KBNet_l
from .arch.kbnet_s import KBNet_s

Expand All @@ -27,14 +27,12 @@ def load_l(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_l]:

dim = state_dict["patch_embed.proj.weight"].shape[0]

num_blocks[0] = get_max_seq_index(state_dict, "encoder_level1.{}.norm1.weight") + 1
num_blocks[1] = get_max_seq_index(state_dict, "encoder_level2.{}.norm1.weight") + 1
num_blocks[2] = get_max_seq_index(state_dict, "encoder_level3.{}.norm1.weight") + 1
num_blocks[3] = get_max_seq_index(state_dict, "latent.{}.norm1.weight") + 1
num_blocks[0] = get_seq_len(state_dict, "encoder_level1")
num_blocks[1] = get_seq_len(state_dict, "encoder_level2")
num_blocks[2] = get_seq_len(state_dict, "encoder_level3")
num_blocks[3] = get_seq_len(state_dict, "latent")

num_refinement_blocks = (
get_max_seq_index(state_dict, "refinement.{}.norm1.weight") + 1
)
num_refinement_blocks = get_seq_len(state_dict, "refinement")

heads[0] = state_dict["encoder_level1.0.ffn.temperature"].shape[0]
heads[1] = state_dict["encoder_level2.0.ffn.temperature"].shape[0]
Expand Down Expand Up @@ -87,21 +85,17 @@ def load_s(state_dict: StateDict) -> RestorationModelDescriptor[KBNet_s]:
img_channel = state_dict["intro.weight"].shape[1]
width = state_dict["intro.weight"].shape[0]

middle_blk_num = get_max_seq_index(state_dict, "middle_blks.{}.w") + 1
middle_blk_num = get_seq_len(state_dict, "middle_blks")

enc_count = get_max_seq_index(state_dict, "encoders.{}.0.w") + 1
enc_count = get_seq_len(state_dict, "encoders")
enc_blk_nums = [1] * enc_count
for i in range(enc_count):
enc_blk_nums[i] = (
get_max_seq_index(state_dict, "encoders." + str(i) + ".{}.w") + 1
)
enc_blk_nums[i] = get_seq_len(state_dict, "encoders." + str(i))

dec_count = get_max_seq_index(state_dict, "decoders.{}.0.w") + 1
dec_count = get_seq_len(state_dict, "decoders")
dec_blk_nums = [1] * dec_count
for i in range(dec_count):
dec_blk_nums[i] = (
get_max_seq_index(state_dict, "decoders." + str(i) + ".{}.w") + 1
)
dec_blk_nums[i] = get_seq_len(state_dict, "decoders." + str(i))

# in code: ffn_ch = int(c * ffn_scale)
temp_c = state_dict["middle_blks.0.conv4.weight"].shape[1]
Expand Down

0 comments on commit cb12089

Please sign in to comment.