diff --git a/src/spandrel/architectures/CodeFormer/__init__.py b/src/spandrel/architectures/CodeFormer/__init__.py
index 752b8f89..3018fc2c 100644
--- a/src/spandrel/architectures/CodeFormer/__init__.py
+++ b/src/spandrel/architectures/CodeFormer/__init__.py
@@ -3,35 +3,28 @@
     SizeRequirements,
     StateDict,
 )
+from ..__arch_helpers.state import get_seq_len
 from .arch.codeformer import CodeFormer
 
 
 def load(state_dict: StateDict) -> FaceSRModelDescriptor[CodeFormer]:
     dim_embd = 512
-    n_head = 8
+    n_head = 8  # cannot be deduced from state dict
     n_layers = 9
     codebook_size = 1024
     latent_size = 256
     connect_list = ["32", "64", "128", "256"]
     fix_modules = ["quantize", "generator"]
 
-    # This is just a guess as I only have one model to look at
-    position_emb = state_dict["position_emb"]
-    dim_embd = position_emb.shape[1]
-    latent_size = position_emb.shape[0]
+    dim_embd = state_dict["position_emb"].shape[1]
+    latent_size = state_dict["position_emb"].shape[0]
+    codebook_size = state_dict["idx_pred_layer.1.weight"].shape[0]
+    n_layers = get_seq_len(state_dict, "ft_layers")
 
-    try:
-        n_layers = len(
-            set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
-        )
-    except:  # noqa: E722
-        pass
-
-    codebook_size = state_dict["quantize.embedding.weight"].shape[0]
-
-    # This is also just another guess
-    n_head_exp = state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
-    n_head = 2**n_head_exp
+    keys = ["16", "32", "64", "128", "256", "512"]
+    connect_list = list(
+        filter(lambda k: f"fuse_convs_dict.{k}.scale.0.weight" in state_dict, keys)
+    )
 
     in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
diff --git a/tests/test_CodeFormer.py b/tests/test_CodeFormer.py
index 40d2a176..68abca52 100644
--- a/tests/test_CodeFormer.py
+++ b/tests/test_CodeFormer.py
@@ -1,7 +1,23 @@
 from spandrel import ModelLoader
-from spandrel.architectures.CodeFormer import CodeFormer
+from spandrel.architectures.CodeFormer import CodeFormer, load
 
-from .util import ModelFile, disallowed_props
+from .util import ModelFile, assert_loads_correctly, disallowed_props
+
+
+def test_CodeFormer_load():
+    assert_loads_correctly(
+        load,
+        lambda: CodeFormer(),
+        lambda: CodeFormer(dim_embd=256, n_head=4),
+        lambda: CodeFormer(n_layers=5, codebook_size=512, latent_size=64),
+        lambda: CodeFormer(connect_list=["16", "32", "64"]),
+        condition=lambda a, b: (
+            a.connect_list == b.connect_list
+            and a.dim_embd == b.dim_embd
+            and a.n_layers == b.n_layers
+            and a.codebook_size == b.codebook_size
+        ),
+    )
 
 
 def test_CodeFormer(snapshot):