Require keyword args for all architectures (#279)
* KWOnly for all except ATD

* Better error message

* ATD

RunDevelopment authored Jul 2, 2024
1 parent 067df7b commit a2e28a8

Showing 45 changed files with 53 additions and 7 deletions.
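
Every hunk below makes the same change: a bare `*` is inserted after `self` in the architecture's `__init__`, so all hyperparameters become keyword-only. A minimal sketch of the effect, using a toy module rather than one of the spandrel classes (assumes only torch):

import torch.nn as nn


# Toy module, not taken from spandrel; it only illustrates the signature change.
class Example(nn.Module):
    def __init__(self, *, in_nc=3, out_nc=3):
        super().__init__()
        self.conv = nn.Conv2d(in_nc, out_nc, 3, padding=1)


Example(in_nc=1, out_nc=1)  # OK: keyword arguments
# Example(1, 1)             # raises TypeError: takes 1 positional argument but 3 were given

This lines up with the `new_init(self: C, **kwargs)` wrapper in libs/spandrel/spandrel/util/__init__.py further down, which forwards and records hyperparameters by name.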
1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py
@@ -909,6 +909,7 @@ class ATD(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py
@@ -748,6 +748,7 @@ class CRAFT(nn.Module):

     def __init__(
         self,
+        *,
         in_chans=3,
         # img_size=64,
         window_size=16,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/Compact/arch/SRVGG.py
@@ -26,6 +26,7 @@ class SRVGGNetCompact(nn.Module):

     def __init__(
         self,
+        *,
         num_in_ch=3,
         num_out_ch=3,
         num_feat=64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/DAT/arch/DAT.py
@@ -920,6 +920,7 @@ class DAT(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         in_chans=3,
         embed_dim=180,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py
@@ -461,6 +461,7 @@ class DCTLSA(nn.Module):

     def __init__(
         self,
+        *,
         in_nc=3,
         nf=55,
         num_modules=6,

@@ -265,6 +265,7 @@ class DITN_Real(nn.Module):

     def __init__(
         self,
+        *,
         inp_channels=3,
         dim=60,
         ITL_blocks=4,

@@ -688,6 +688,7 @@ class DRCT(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

@@ -21,6 +21,7 @@ class DRUNet(nn.Module):

     def __init__(
         self,
+        *,
         in_nc=1,
         out_nc=1,
         nc=[64, 128, 256, 512],

@@ -45,6 +45,7 @@ class DnCNN(nn.Module):

     def __init__(
         self,
+        *,
         in_nc=1,
         out_nc=1,
         nc=64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py
@@ -19,6 +19,7 @@ class RRDBNet(nn.Module):

     def __init__(
         self,
+        *,
         in_nc: int = 3,
         out_nc: int = 3,
         num_filters: int = 64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py
@@ -411,6 +411,7 @@ class FBCNN(nn.Module):

     def __init__(
         self,
+        *,
         in_nc=3,
         out_nc=3,
         nc=[64, 128, 256, 512],

@@ -275,6 +275,7 @@ class FFTformer(nn.Module):

     def __init__(
         self,
+        *,
         inp_channels=3,
         out_channels=3,
         dim=48,

@@ -196,6 +196,7 @@ class GFPGANv1Clean(nn.Module):

     def __init__(
         self,
+        *,
         out_size=512,
         num_style_feat=512,
         channel_multiplier=2,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/GRL/arch/grl.py
@@ -231,6 +231,7 @@ class GRL(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         in_channels: int = 3,
         out_channels: int = 3,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/HAT/arch/HAT.py
@@ -856,6 +856,7 @@ class HAT(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

@@ -14,6 +14,7 @@ class CIDNet(nn.Module):

     def __init__(
         self,
+        *,
         channels=[36, 36, 72, 144],
         heads=[1, 2, 4, 8],
         norm=False,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/IPT/arch/ipt.py
@@ -22,6 +22,7 @@ class IPT(nn.Module):

     def __init__(
         self,
+        *,
         patch_size: int = 48,
         patch_dim: int = 3,
         n_feats: int = 64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py
@@ -223,6 +223,7 @@ class KBNet_l(nn.Module):

     def __init__(
         self,
+        *,
         inp_channels=3,
         out_channels=3,
         dim=48,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py
@@ -205,6 +205,7 @@ class KBNet_s(nn.Module):

     def __init__(
         self,
+        *,
         img_channel=3,
         width=64,
         middle_blk_num=12,

2 changes: 1 addition & 1 deletion libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py
@@ -667,7 +667,7 @@ def forward(self, image, mask):
 class LaMa(nn.Module):
     hyperparameters = {}

-    def __init__(self, in_nc=3, out_nc=3) -> None:
+    def __init__(self, *, in_nc=3, out_nc=3) -> None:
         super().__init__()
         self.model = FFCResNetGenerator(in_nc, out_nc)

@@ -554,6 +554,7 @@ class MMRRDBNet_test(nn.Module):

     def __init__(
         self,
+        *,
         num_in_ch,
         num_out_ch,
         scale=4,

@@ -208,6 +208,7 @@ class MixDehazeNet(nn.Module):

     def __init__(
         self,
+        *,
         in_chans=3,
         out_chans=4,
         embed_dims=[24, 48, 96, 48, 24],

@@ -142,6 +142,7 @@ class NAFNet(nn.Module):

     def __init__(
         self,
+        *,
         img_channel: int = 3,
         width: int = 16,
         middle_blk_num: int = 1,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py
@@ -25,6 +25,7 @@ class OmniSR(nn.Module):

     def __init__(
         self,
+        *,
         num_in_ch=3,
         num_out_ch=3,
         num_feat=64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py
@@ -291,6 +291,7 @@ class PLKSR(nn.Module):

     def __init__(
         self,
+        *,
         dim: int = 64,
         n_blocks: int = 28,
         upscaling_factor: int = 4,

@@ -103,6 +103,7 @@ class RealPLKSR(nn.Module):

     def __init__(
         self,
+        *,
         dim: int = 64,
         n_blocks: int = 28,
         upscaling_factor: int = 4,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/RGT/arch/rgt.py
@@ -814,6 +814,7 @@ class RGT(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         in_chans=3,
         embed_dim=180,

@@ -291,7 +291,7 @@ def forward_d(self, x1, x4): # conv234结尾有se
 class UpCunet2x(nn.Module):
     hyperparameters = {}

-    def __init__(self, in_channels=3, out_channels=3, pro: bool = False):
+    def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False):
         super().__init__()
         self.pro: Tensor | None
         if pro:

@@ -333,7 +333,7 @@ def forward(self, x: Tensor, alpha: float = 1):
 class UpCunet3x(nn.Module):
     hyperparameters = {}

-    def __init__(self, in_channels=3, out_channels=3, pro: bool = False):
+    def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False):
         super().__init__()
         self.pro: Tensor | None
         if pro:

@@ -375,7 +375,7 @@ def forward(self, x: Tensor, alpha: float = 1):
 class UpCunet4x(nn.Module):
     hyperparameters = {}

-    def __init__(self, in_channels=3, out_channels=3, pro: bool = False):
+    def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False):
         super().__init__()
         self.pro: Tensor | None
         if pro:

@@ -425,7 +425,7 @@ def forward(self, x: Tensor, alpha: float = 1):
 class UpCunet2x_fast(nn.Module):
     hyperparameters = {}

-    def __init__(self, in_channels=3, out_channels=3):
+    def __init__(self, *, in_channels=3, out_channels=3):
         super().__init__()
         self.unet1 = UNet1(12, 64, deconv=True)
         self.unet2 = UNet2(64, 64, deconv=False)

@@ -674,6 +674,7 @@ class RestoreFormer(nn.Module):

     def __init__(
         self,
+        *,
         n_embed=1024,
         embed_dim=256,
         ch=64,

@@ -330,6 +330,7 @@ class RetinexFormer(nn.Module):

     def __init__(
         self,
+        *,
         in_channels=3,
         out_channels=3,
         n_feat=40,

2 changes: 1 addition & 1 deletion libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py
@@ -162,7 +162,7 @@ def forward(self, x):
 class SAFMN(nn.Module):
     hyperparameters = {}

-    def __init__(self, dim: int, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
+    def __init__(self, *, dim: int, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
         super().__init__()
         self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)

@@ -123,7 +123,7 @@ class SAFMN_BCIE(nn.Module):
     hyperparameters = {}

     def __init__(
-        self, dim: int, n_blocks=6, num_layers=6, ffn_scale=2.0, upscaling_factor=2
+        self, *, dim: int, n_blocks=6, num_layers=6, ffn_scale=2.0, upscaling_factor=2
     ):
         super().__init__()

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py
@@ -279,6 +279,7 @@ class SCUNet(nn.Module):

     def __init__(
         self,
+        *,
         in_nc=3,
         config=[4, 4, 4, 4, 4, 4, 4],
         dim=64,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/SPAN/arch/span.py
@@ -243,6 +243,7 @@ class SPAN(nn.Module):

     def __init__(
         self,
+        *,
         num_in_ch: int,
         num_out_ch: int,
         feature_channels=48,

@@ -107,6 +107,7 @@ class Generator(nn.Module):

     def __init__(
         self,
+        *,
         in_channels: int = 3,
         num_channels: int = 64,
         num_blocks: int = 16,

@@ -908,6 +908,7 @@ class Swin2SR(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

1 change: 1 addition & 0 deletions libs/spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py
@@ -816,6 +816,7 @@ class SwinIR(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

6 changes: 6 additions & 0 deletions libs/spandrel/spandrel/util/__init__.py
@@ -177,6 +177,12 @@ def inner(cls: type[C]) -> type[C]:
             raise UserWarning(
                 "Class has **kwargs, which is not allowed in combination with @store_hyperparameters"
             )
+        if spec.args != ["self"]:
+            raise UserWarning(
+                "@store_hyperparameters requires all arguments of `"
+                + cls.__name__
+                + ".__init__` after `self` to be keyword arguments. Use `def __init__(self, *, a, b, c):`."
+            )

         @functools.wraps(old_init)
         def new_init(self: C, **kwargs):

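For context, this check runs when `@store_hyperparameters` decorates an architecture class: `inspect.getfullargspec` reports any parameters that are still positional after `self`, and the decorator rejects the class with the new message. A standalone sketch of the same idea (the decorator and class names below are hypothetical, not spandrel's API):

import inspect


def require_keyword_only(cls):
    # Same condition as the check added in this commit: after `self`,
    # no positional parameters may remain in __init__.
    spec = inspect.getfullargspec(cls.__init__)
    if spec.args != ["self"]:
        raise UserWarning(
            f"All arguments of `{cls.__name__}.__init__` after `self` must be "
            "keyword-only. Use `def __init__(self, *, a, b, c):`."
        )
    return cls


@require_keyword_only
class Good:
    def __init__(self, *, a=1, b=2): ...


# @require_keyword_only  # would raise UserWarning: `a` and `b` are still positional
class Bad:
    def __init__(self, a=1, b=2): ...
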
@@ -610,6 +610,7 @@ class CodeFormer(VQAutoEncoder):

     def __init__(
         self,
+        *,
         dim_embd=512,
         n_head=8,
         n_layers=9,

@@ -31,6 +31,7 @@ class DDColor(nn.Module):

     def __init__(
         self,
+        *,
         encoder_name="convnext-l",
         decoder_name="MultiScaleColorDecoder",
         num_input_channels=3,

@@ -171,6 +171,7 @@ class M3SNet(nn.Module):

     def __init__(
         self,
+        *,
         img_channel=3,
         width=32,
         middle_blk_num=1,

@@ -307,6 +307,7 @@ class MIRNet_v2(nn.Module):

     def __init__(
         self,
+        *,
         inp_channels=3,
         out_channels=3,
         n_feat=80,

@@ -398,6 +398,7 @@ class MPRNet(nn.Module):

     def __init__(
         self,
+        *,
         in_c: int = 3,
         out_c: int = 3,
         n_feat: int = 40,

@@ -222,6 +222,7 @@ class Restormer(nn.Module):

     def __init__(
         self,
+        *,
         inp_channels=3,
         out_channels=3,
         dim=48,

@@ -1024,6 +1024,7 @@ class SRFormer(nn.Module):

     def __init__(
         self,
+        *,
         img_size=64,
         patch_size=1,
         in_chans=3,

