Skip to content

Commit

Permalink
Add parameter detection for RestoreFormer (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Nov 22, 2023
1 parent 71fc7e2 commit dcee7ed
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 16 deletions.
58 changes: 52 additions & 6 deletions src/spandrel/architectures/RestoreFormer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,62 @@
SizeRequirements,
StateDict,
)
from ..__arch_helpers.state import get_seq_len
from .arch.restoreformer_arch import RestoreFormer


def load(state_dict: StateDict) -> FaceSRModelDescriptor[RestoreFormer]:
    """Build a RestoreFormer model with hyperparameters detected from *state_dict*.

    Every architecture parameter that leaves a footprint in the tensor shapes
    (channel counts, codebook size, multipliers, block counts, mid-block
    presence) is read back from the state dict; the remaining parameters fall
    back to the defaults of the official RestoreFormer release.
    """
    # Defaults for parameters that cannot be recovered from tensor shapes.
    attn_resolutions = (16,)
    dropout = 0.0
    resolution = 512
    head_size = 8  # cannot be deduced from the shape of tensors in state_dict

    # VQ codebook: rows = number of code vectors, columns = embedding width.
    n_embed = state_dict["quantize.embedding.weight"].shape[0]
    embed_dim = state_dict["quantize.embedding.weight"].shape[1]

    # Latent width, and whether the encoder emits mean+logvar (double_z).
    z_channels = state_dict["quant_conv.weight"].shape[1]
    double_z = state_dict["encoder.conv_out.weight"].shape[0] == 2 * z_channels

    # Mid blocks only leave weights behind when they were enabled.
    enable_mid = "encoder.mid.block_1.norm1.weight" in state_dict

    ch = state_dict["encoder.conv_in.weight"].shape[0]
    in_channels = state_dict["encoder.conv_in.weight"].shape[1]
    out_ch = state_dict["decoder.conv_out.weight"].shape[0]

    num_res_blocks = get_seq_len(state_dict, "encoder.down.0.block")

    # Per-level channel multiplier: output channels of each down level
    # relative to the base channel count `ch`.
    ch_mult_len = get_seq_len(state_dict, "encoder.down")
    ch_mult = tuple(
        state_dict[f"encoder.down.{i}.block.0.conv2.weight"].shape[0] // ch
        for i in range(ch_mult_len)
    )

    model = RestoreFormer(
        n_embed=n_embed,
        embed_dim=embed_dim,
        ch=ch,
        out_ch=out_ch,
        ch_mult=ch_mult,
        num_res_blocks=num_res_blocks,
        attn_resolutions=attn_resolutions,
        dropout=dropout,
        in_channels=in_channels,
        resolution=resolution,
        z_channels=z_channels,
        double_z=double_z,
        enable_mid=enable_mid,
        head_size=head_size,
    )


return FaceSRModelDescriptor(
Expand All @@ -23,7 +69,7 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[RestoreFormer]:
supports_half=False,
supports_bfloat16=True,
scale=8,
input_channels=in_nc,
output_channels=out_nc,
input_channels=in_channels,
output_channels=out_ch,
size_requirements=SizeRequirements(minimum=16),
)
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ def __init__(
double_z=True,
enable_mid=True,
head_size=1,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
Expand Down Expand Up @@ -429,7 +428,6 @@ def __init__(
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs,
):
super().__init__()
self.ch = ch
Expand Down Expand Up @@ -556,7 +554,6 @@ def __init__(
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs,
):
super().__init__()
self.ch = ch
Expand All @@ -572,11 +569,11 @@ def __init__(
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )

# z to block_in
self.conv_in = torch.nn.Conv2d(
Expand Down
29 changes: 27 additions & 2 deletions tests/test_RestoreFormer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
from spandrel import ModelLoader
from spandrel.architectures.RestoreFormer import RestoreFormer
from spandrel.architectures.RestoreFormer import RestoreFormer, load

from .util import ModelFile, disallowed_props
from .util import ModelFile, assert_loads_correctly, disallowed_props


def test_RestoreFormer_load():
    """Round-trip check: models built with varied hyperparameters must be
    reconstructed identically by ``load`` from their state dicts."""

    def _params_match(a, b):
        # Compare exactly the hyperparameters the loader is expected to detect.
        return (
            a.encoder.ch == b.encoder.ch
            and a.encoder.num_resolutions == b.encoder.num_resolutions
            and a.encoder.num_res_blocks == b.encoder.num_res_blocks
            and a.encoder.resolution == b.encoder.resolution
            and a.encoder.in_channels == b.encoder.in_channels
            and a.encoder.enable_mid == b.encoder.enable_mid
            and a.decoder.z_shape == b.decoder.z_shape
        )

    assert_loads_correctly(
        load,
        lambda: RestoreFormer(),
        lambda: RestoreFormer(n_embed=256, embed_dim=32),
        lambda: RestoreFormer(ch=32),
        lambda: RestoreFormer(in_channels=1, out_ch=1),
        lambda: RestoreFormer(in_channels=1, out_ch=3),
        lambda: RestoreFormer(in_channels=4, out_ch=4),
        lambda: RestoreFormer(num_res_blocks=3),
        lambda: RestoreFormer(ch_mult=(1, 3, 6)),
        lambda: RestoreFormer(z_channels=64, double_z=True),
        lambda: RestoreFormer(enable_mid=False),
        condition=_params_match,
    )


def test_RestoreFormer(snapshot):
Expand Down

0 comments on commit dcee7ed

Please sign in to comment.