Improved CodeFormer parameter detection
RunDevelopment committed Nov 21, 2023
1 parent d53db13 commit 3bfb845
Showing 2 changed files with 28 additions and 19 deletions.
src/spandrel/architectures/CodeFormer/__init__.py (27 changes: 10 additions & 17 deletions)
@@ -3,35 +3,28 @@
     SizeRequirements,
     StateDict,
 )
+from ..__arch_helpers.state import get_seq_len
 from .arch.codeformer import CodeFormer
 
 
 def load(state_dict: StateDict) -> FaceSRModelDescriptor[CodeFormer]:
     dim_embd = 512
-    n_head = 8
+    n_head = 8  # cannot be deduced from state dict
     n_layers = 9
     codebook_size = 1024
     latent_size = 256
     connect_list = ["32", "64", "128", "256"]
     fix_modules = ["quantize", "generator"]
 
-    # This is just a guess as I only have one model to look at
-    position_emb = state_dict["position_emb"]
-    dim_embd = position_emb.shape[1]
-    latent_size = position_emb.shape[0]
+    dim_embd = state_dict["position_emb"].shape[1]
+    latent_size = state_dict["position_emb"].shape[0]
+    codebook_size = state_dict["idx_pred_layer.1.weight"].shape[0]
+    n_layers = get_seq_len(state_dict, "ft_layers")
 
-    try:
-        n_layers = len(
-            set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
-        )
-    except:  # noqa: E722
-        pass
-
-    codebook_size = state_dict["quantize.embedding.weight"].shape[0]
-
-    # This is also just another guess
-    n_head_exp = state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
-    n_head = 2**n_head_exp
+    keys = ["16", "32", "64", "128", "256", "512"]
+    connect_list = list(
+        filter(lambda k: f"fuse_convs_dict.{k}.scale.0.weight" in state_dict, keys)
+    )
 
     in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
 
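The get_seq_len helper imported at the top replaces the old try/except key counting. Its implementation lives in __arch_helpers/state.py and is not part of this diff; a minimal sketch of what such a helper presumably does, assuming sequential layers are stored under keys of the form "ft_layers.<index>.<param>":

from typing import Any, Mapping


def get_seq_len(state_dict: Mapping[str, Any], seq_key: str) -> int:
    # Length of an nn.Sequential / nn.ModuleList recovered from state dict
    # keys of the form f"{seq_key}.{index}.<rest>": one greater than the
    # largest index that occurs.
    prefix = seq_key + "."
    last_index = -1
    for key in state_dict:
        if key.startswith(prefix):
            index = int(key[len(prefix):].split(".", 1)[0])
            last_index = max(last_index, index)
    return last_index + 1

The remaining detections read tensor shapes and key presence directly: position_emb has shape (latent_size, dim_embd), idx_pred_layer.1.weight has one row per codebook entry, and a resolution belongs in connect_list exactly when its fuse_convs_dict.{k}.scale.0.weight tensor is present. Only n_head keeps its default, as the new comment notes: PyTorch's nn.MultiheadAttention stores in_proj_weight with shape (3 * embed_dim, embed_dim) regardless of the head count, so the removed 2**n_head_exp guess always evaluated to 2**3 = 8 and carried no information.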
tests/test_CodeFormer.py (20 changes: 18 additions & 2 deletions)
@@ -1,7 +1,23 @@
 from spandrel import ModelLoader
-from spandrel.architectures.CodeFormer import CodeFormer
+from spandrel.architectures.CodeFormer import CodeFormer, load
 
-from .util import ModelFile, disallowed_props
+from .util import ModelFile, assert_loads_correctly, disallowed_props
 
 
+def test_CodeFormer_load():
+    assert_loads_correctly(
+        load,
+        lambda: CodeFormer(),
+        lambda: CodeFormer(dim_embd=256, n_head=4),
+        lambda: CodeFormer(n_layers=5, codebook_size=512, latent_size=64),
+        lambda: CodeFormer(connect_list=["16", "32", "64"]),
+        condition=lambda a, b: (
+            a.connect_list == b.connect_list
+            and a.dim_embd == b.dim_embd
+            and a.n_layers == b.n_layers
+            and a.codebook_size == b.codebook_size
+        ),
+    )
+
+
 def test_CodeFormer(snapshot):
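assert_loads_correctly and the util module it comes from belong to the test suite and are not shown in this diff. A plausible sketch of what the helper does, assuming load returns a descriptor whose model attribute holds the constructed network:

def assert_loads_correctly(load, *factories, condition=lambda a, b: True):
    for factory in factories:
        expected = factory()  # reference model with known hyperparameters
        descriptor = load(expected.state_dict())  # run parameter detection
        model = descriptor.model
        # The detected architecture must accept the reference weights verbatim...
        model.load_state_dict(expected.state_dict())
        # ...and the hyperparameters singled out by `condition` must match.
        assert condition(model, expected)

Under that reading, the new test fails whenever load cannot recover dim_embd, n_layers, codebook_size, or connect_list from the weights alone. latent_size is exercised indirectly, since a wrong value would change the shape of position_emb and break load_state_dict, while n_head is deliberately absent from the condition, matching the "cannot be deduced from state dict" comment in the loader.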
