Skip to content

Commit

Permalink
Make CUGAN input/output channels configurable for all versions
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Jul 8, 2024
1 parent 64d3c2b commit f3a3b7c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False):
self.unet1 = UNet1(in_channels, 64, deconv=True)
self.unet2 = UNet2(64, 64, deconv=False)
self.ps = nn.PixelShuffle(2)
self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)
self.conv_final = nn.Conv2d(64, 4 * out_channels, 3, 1, padding=0, bias=True)

@property
def is_pro(self):
Expand Down Expand Up @@ -427,10 +427,10 @@ class UpCunet2x_fast(nn.Module):

def __init__(self, *, in_channels=3, out_channels=3):
super().__init__()
self.unet1 = UNet1(12, 64, deconv=True)
self.unet1 = UNet1(4 * in_channels, 64, deconv=True)
self.unet2 = UNet2(64, 64, deconv=False)
self.ps = nn.PixelShuffle(2)
self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)
self.conv_final = nn.Conv2d(64, 4 * out_channels, 3, 1, padding=0, bias=True)
self.inv = nn.PixelUnshuffle(2)

def forward(self, x: Tensor):
Expand Down
36 changes: 21 additions & 15 deletions libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,27 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]:
in_channels = state_dict["unet1.conv1.conv.0.weight"].shape[1]
size_requirements = SizeRequirements(minimum=32)

if "conv_final.weight" in state_dict and in_channels == 12:
# UpCunet2x_fast
scale = 2
in_channels = 3 # hard coded in UpCunet2x_fast
out_channels = 3 # hard coded in UpCunet2x_fast
model = UpCunet2x_fast(in_channels=in_channels, out_channels=out_channels)
size_requirements = SizeRequirements(minimum=40, multiple_of=4)
tags.append("fast")
elif "conv_final.weight" in state_dict:
# UpCunet4x
scale = 4
out_channels = 3 # hard coded in UpCunet4x
model = UpCunet4x(
in_channels=in_channels, out_channels=out_channels, pro=pro
)
if "conv_final.weight" in state_dict:
# UpCunet4x or UpCunet2x_fast
# This is kinda wonky, because a UpCunet4x(in=4, out=k) and a
# UpCunet2x_fast(in=1, out=k) have the same state dict shapes,
# so we'll just assume that it's a UpCunet2x_fast
out_channels = state_dict["conv_final.weight"].shape[0] // 4
if out_channels * 4 == in_channels:
# UpCunet2x_fast
scale = 2
in_channels //= 4
model = UpCunet2x_fast(
in_channels=in_channels, out_channels=out_channels
)
size_requirements = SizeRequirements(minimum=40, multiple_of=4)
tags.append("fast")
else:
# UpCunet4x
scale = 4
model = UpCunet4x(
in_channels=in_channels, out_channels=out_channels, pro=pro
)
elif state_dict["unet1.conv_bottom.weight"].shape[2] == 5:
# UpCunet3x
scale = 3
Expand Down
3 changes: 2 additions & 1 deletion tests/test_RealCUGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def test_load():
lambda: UpCunet3x(in_channels=1, out_channels=4),
lambda: UpCunet3x(pro=True),
lambda: UpCunet4x(in_channels=3, out_channels=3),
lambda: UpCunet4x(in_channels=1, out_channels=3),
lambda: UpCunet4x(in_channels=1, out_channels=4),
lambda: UpCunet4x(pro=True),
lambda: UpCunet2x_fast(in_channels=3, out_channels=3),
lambda: UpCunet2x_fast(in_channels=1, out_channels=1),
)


Expand Down

0 comments on commit f3a3b7c

Please sign in to comment.