diff --git a/libs/spandrel/spandrel/__helpers/main_registry.py b/libs/spandrel/spandrel/__helpers/main_registry.py index 2e4f29e8..60afd944 100644 --- a/libs/spandrel/spandrel/__helpers/main_registry.py +++ b/libs/spandrel/spandrel/__helpers/main_registry.py @@ -35,6 +35,7 @@ Swin2SR, SwinIR, Uformer, + USRNet, ) from .registry import ArchRegistry, ArchSupport @@ -80,4 +81,5 @@ ArchSupport.from_architecture(DRCT.DRCTArch()), ArchSupport.from_architecture(ESRGAN.ESRGANArch()), ArchSupport.from_architecture(PLKSR.PLKSRArch()), + ArchSupport.from_architecture(USRNet.USRNetArch()), ) diff --git a/libs/spandrel/spandrel/architectures/USRNet/__init__.py b/libs/spandrel/spandrel/architectures/USRNet/__init__.py new file mode 100644 index 00000000..17cabde1 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/USRNet/__init__.py @@ -0,0 +1,104 @@ +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .arch.usrnet import USRNet + + +class USRNetArch(Architecture[USRNet]): + def __init__(self) -> None: + super().__init__( + id="USRNet", + detect=KeyCondition.has_all( + "p.m_head.weight", + "p.m_down1.0.res.0.weight", + "p.m_down1.0.res.2.weight", + "p.m_down2.0.res.0.weight", + "p.m_down2.0.res.2.weight", + "p.m_down3.0.res.0.weight", + "p.m_down3.0.res.2.weight", + "p.m_body.0.res.0.weight", + "p.m_body.0.res.2.weight", + "p.m_tail.weight", + "h.mlp.0.weight", + "h.mlp.0.bias", + "h.mlp.2.weight", + "h.mlp.2.bias", + "h.mlp.4.weight", + "h.mlp.4.bias", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[USRNet]: + # n_iter = 8 + # h_nc = 64 + # in_nc = 4 + # out_nc = 3 + # nc = [64, 128, 256, 512] + # nb = 2 + act_mode = "R" + # downsample_mode = "strideconv" + # upsample_mode = "convtranspose" + + # detect parameters + n_iter = state_dict["h.mlp.4.weight"].shape[0] // 2 + h_nc = state_dict["h.mlp.0.weight"].shape[0] + + in_nc = state_dict["p.m_head.weight"].shape[1] + out_nc = state_dict["p.m_tail.weight"].shape[0] + + nc = [ + state_dict["p.m_down1.0.res.0.weight"].shape[0], + state_dict["p.m_down2.0.res.0.weight"].shape[0], + state_dict["p.m_down3.0.res.0.weight"].shape[0], + state_dict["p.m_body.0.res.0.weight"].shape[0], + ] + nb = get_seq_len(state_dict, "p.m_body") + + if f"p.m_down1.{nb}.weight" in state_dict: + downsample_mode = "strideconv" + else: + # we cannot distinguish between avgpool and maxpool + downsample_mode = "maxpool" + + if "p.m_up3.1.res.0.weight" in state_dict: + upsample_mode = "convtranspose" + elif "p.m_up3.0.weight" in state_dict: + upsample_mode = "pixelshuffle" + elif "p.m_up3.1.weight" in state_dict: + upsample_mode = "upconv" + else: + raise ValueError("Unknown upsample mode") + + model = USRNet( + n_iter=n_iter, + h_nc=h_nc, + in_nc=in_nc, + out_nc=out_nc, + nc=nc, + nb=nb, + act_mode=act_mode, + downsample_mode=downsample_mode, + upsample_mode=upsample_mode, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # TODO: check + supports_bfloat16=True, + scale=1, + input_channels=in_nc, + output_channels=out_nc, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) diff --git a/libs/spandrel/spandrel/architectures/USRNet/arch/LICENSE b/libs/spandrel/spandrel/architectures/USRNet/arch/LICENSE new file mode 100644 index 00000000..ddd784fe --- /dev/null +++ b/libs/spandrel/spandrel/architectures/USRNet/arch/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2019 Kai Zhang + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/libs/spandrel/spandrel/architectures/USRNet/arch/usrnet.py b/libs/spandrel/spandrel/architectures/USRNet/arch/usrnet.py new file mode 100644 index 00000000..e41f7570 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/USRNet/arch/usrnet.py @@ -0,0 +1,399 @@ +import numpy as np +import torch +import torch.nn as nn + +from spandrel.util import store_hyperparameters + +from ...__arch_helpers import dpir_basic_block as B + +""" +# -------------------------------------------- +# Kai Zhang (cskaizhang@gmail.com) +@inproceedings{zhang2020deep, + title={Deep unfolding network for image super-resolution}, + author={Zhang, Kai and Van Gool, Luc and Timofte, Radu}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={0--0}, + year={2020} +} +# -------------------------------------------- +""" + + +""" +# -------------------------------------------- +# basic functions +# -------------------------------------------- +""" + + +def splits(a, sf): + """split a into sfxsf distinct blocks + + Args: + a: NxCxWxHx2 + sf: split factor + + Returns: + b: NxCx(W/sf)x(H/sf)x2x(sf^2) + """ + b = torch.stack(torch.chunk(a, sf, dim=2), dim=5) + b = torch.cat(torch.chunk(b, sf, dim=3), dim=5) + return b + + +def r2c(x): + # convert real to complex + return torch.stack([x, torch.zeros_like(x)], -1) + + +def cdiv(x, y): + # complex division + a, b = x[..., 0], x[..., 1] + c, d = y[..., 0], y[..., 1] + cd2 = c**2 + d**2 + return torch.stack([(a * c + b * d) / cd2, (b * c - a * d) / cd2], -1) + + +def csum(x, y): + # complex + real + return torch.stack([x[..., 0] + y, x[..., 1]], -1) + + +def cabs2(x): + return x[..., 0] ** 2 + x[..., 1] ** 2 + + +def cmul(t1, t2): + """complex multiplication + + Args: + t1: NxCxHxWx2, complex tensor + t2: NxCxHxWx2 + + Returns: + output: NxCxHxWx2 + """ + real1, imag1 = t1[..., 0], t1[..., 1] + real2, imag2 = t2[..., 0], t2[..., 1] + return torch.stack( + [real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1 + ) + + +def cconj(t, inplace=False): + """complex's conjugation + + Args: + t: NxCxHxWx2 + + Returns: + output: NxCxHxWx2 + """ + c = t.clone() if not inplace else t + c[..., 1] *= -1 + return c + + +def p2o(psf, shape): + """ + Convert point-spread function to optical transfer function. + otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the + point-spread function (PSF) array and creates the optical transfer + function (OTF) array that is not influenced by the PSF off-centering. + + Args: + psf: NxCxhxw + shape: [H, W] + + Returns: + otf: NxCxHxWx2 + """ + otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) + otf[..., : psf.shape[2], : psf.shape[3]].copy_(psf) + for axis, axis_size in enumerate(psf.shape[2:]): + otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2) + otf = torch.fft.rfft(otf, 2, onesided=False) + n_ops = torch.sum( + torch.tensor(psf.shape).type_as(psf) + * torch.log2(torch.tensor(psf.shape).type_as(psf)) + ) + otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as( + psf + ) + return otf + + +def upsample(x, sf=3): + """s-fold upsampler + + Upsampling the spatial size by filling the new entries with zeros + + x: tensor image, NxCxWxH + """ + st = 0 + z = torch.zeros((x.shape[0], x.shape[1], x.shape[2] * sf, x.shape[3] * sf)).type_as( + x + ) + z[..., st::sf, st::sf].copy_(x) + return z + + +def downsample(x, sf=3): + """s-fold downsampler + + Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others + + x: tensor image, NxCxWxH + """ + st = 0 + return x[..., st::sf, st::sf] + + +""" +# -------------------------------------------- +# (1) Prior module; ResUNet: act as a non-blind denoiser +# x_k = P(z_k, beta_k) +# -------------------------------------------- +""" + + +class ResUNet(nn.Module): + def __init__( + self, + in_nc=4, + out_nc=3, + nc=[64, 128, 256, 512], + nb=2, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ): + super().__init__() + + self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C") + + # downsample + if downsample_mode == "avgpool": + downsample_block = B.downsample_avgpool + elif downsample_mode == "maxpool": + downsample_block = B.downsample_maxpool + elif downsample_mode == "strideconv": + downsample_block = B.downsample_strideconv + else: + raise NotImplementedError( + f"downsample mode [{downsample_mode:s}] is not found" + ) + + self.m_down1 = B.sequential( + *[ + B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + downsample_block(nc[0], nc[1], bias=False, mode="2"), + ) + self.m_down2 = B.sequential( + *[ + B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + downsample_block(nc[1], nc[2], bias=False, mode="2"), + ) + self.m_down3 = B.sequential( + *[ + B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + downsample_block(nc[2], nc[3], bias=False, mode="2"), + ) + + self.m_body = B.sequential( + *[ + B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ] + ) + + # upsample + if upsample_mode == "upconv": + upsample_block = B.upsample_upconv + elif upsample_mode == "pixelshuffle": + upsample_block = B.upsample_pixelshuffle + elif upsample_mode == "convtranspose": + upsample_block = B.upsample_convtranspose + else: + raise NotImplementedError(f"upsample mode [{upsample_mode:s}] is not found") + + self.m_up3 = B.sequential( + upsample_block(nc[3], nc[2], bias=False, mode="2"), + *[ + B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + ) + self.m_up2 = B.sequential( + upsample_block(nc[2], nc[1], bias=False, mode="2"), + *[ + B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + ) + self.m_up1 = B.sequential( + upsample_block(nc[1], nc[0], bias=False, mode="2"), + *[ + B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") + for _ in range(nb) + ], + ) + + self.m_tail = B.conv(nc[0], out_nc, bias=False, mode="C") + + def forward(self, x): + h, w = x.size()[-2:] + paddingBottom = int(np.ceil(h / 8) * 8 - h) + paddingRight = int(np.ceil(w / 8) * 8 - w) + x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x) + + x1 = self.m_head(x) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + return x + + +""" +# -------------------------------------------- +# (2) Data module, closed-form solution +# It is a trainable-parameter-free module ^_^ +# z_k = D(x_{k-1}, s, k, y, alpha_k) +# some can be pre-calculated +# -------------------------------------------- +""" + + +class DataNet(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf): + FR = FBFy + torch.fft.rfft(alpha * x, 2, onesided=False) + x1 = cmul(FB, FR) + FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False) + invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False) + invWBR = cdiv(FBR, csum(invW, alpha)) + FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1)) + FX = (FR - FCBinvWBR) / alpha.unsqueeze(-1) + Xest = torch.fft.irfft(FX, 2, onesided=False) + + return Xest + + +""" +# -------------------------------------------- +# (3) Hyper-parameter module +# -------------------------------------------- +""" + + +class HyPaNet(nn.Module): + def __init__(self, in_nc=2, out_nc=8, channel=64): + super().__init__() + self.mlp = nn.Sequential( + nn.Conv2d(in_nc, channel, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel, channel, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel, out_nc, 1, padding=0, bias=True), + nn.Softplus(), + ) + + def forward(self, x): + x = self.mlp(x) + 1e-6 + return x + + +""" +# -------------------------------------------- +# main USRNet +# deep unfolding super-resolution network +# -------------------------------------------- +""" + + +@store_hyperparameters() +class USRNet(nn.Module): + hyperparameters = {} + + def __init__( + self, + n_iter=8, + h_nc=64, + in_nc=4, + out_nc=3, + nc=[64, 128, 256, 512], + nb=2, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ): + super().__init__() + + self.d = DataNet() + self.p = ResUNet( + in_nc=in_nc, + out_nc=out_nc, + nc=nc, + nb=nb, + act_mode=act_mode, + downsample_mode=downsample_mode, + upsample_mode=upsample_mode, + ) + self.h = HyPaNet(in_nc=2, out_nc=n_iter * 2, channel=h_nc) + self.n = n_iter + + def forward(self, x, k, sf, sigma): + """ + x: tensor, NxCxWxH + k: tensor, Nx(1,3)xwxh + sf: integer, 1 + sigma: tensor, Nx1x1x1 + """ + + # initialization & pre-calculation + w, h = x.shape[-2:] + FB = p2o(k, (w * sf, h * sf)) + FBC = cconj(FB, inplace=False) + F2B = r2c(cabs2(FB)) + STy = upsample(x, sf=sf) + FBFy = cmul(FBC, torch.fft.rfft(STy, 2, onesided=False)) + x = nn.functional.interpolate(x, scale_factor=sf, mode="nearest") + + # hyper-parameter, alpha & beta + ab = self.h( + torch.cat((sigma, torch.tensor(sf).type_as(sigma).expand_as(sigma)), dim=1) + ) + + # unfolding + for i in range(self.n): + x = self.d(x, FB, FBC, F2B, FBFy, ab[:, i : i + 1, ...], sf) + x = self.p( + torch.cat( + ( + x, + ab[:, i + self.n : i + self.n + 1, ...].repeat( + 1, 1, x.size(2), x.size(3) + ), + ), + dim=1, + ) + ) + + return x diff --git a/tests/test_USRNet.py b/tests/test_USRNet.py new file mode 100644 index 00000000..b45df18d --- /dev/null +++ b/tests/test_USRNet.py @@ -0,0 +1,58 @@ +from spandrel.architectures.USRNet import USRNet, USRNetArch + +from .util import ( + ModelFile, + TestImage, + assert_image_inference, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, + skip_if_unchanged, +) + +skip_if_unchanged(__file__) + + +def test_load(): + assert_loads_correctly( + USRNetArch(), + lambda: USRNet(), + lambda: USRNet(in_nc=2, out_nc=1), + lambda: USRNet(nb=5), + lambda: USRNet(nc=[16, 32, 32, 8]), + lambda: USRNet(h_nc=16), + lambda: USRNet(n_iter=4), + lambda: USRNet(downsample_mode="maxpool"), + lambda: USRNet(downsample_mode="strideconv"), + lambda: USRNet(upsample_mode="upconv"), + lambda: USRNet(upsample_mode="pixelshuffle"), + lambda: USRNet(upsample_mode="convtranspose"), + ) + + +def test_size_requirements(): + return + file = ModelFile.from_url( + "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth" + ) + assert_size_requirements(file.load_model()) + + file = ModelFile.from_url( + "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth" + ) + assert_size_requirements(file.load_model()) + + +def test_SwinIR_M_s64w8_2x(snapshot): + return + file = ModelFile.from_url( + "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth" + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, USRNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + )