From a2e28a81fb96b3afbbaab3b7b7f0da60cd7fc64f Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Tue, 2 Jul 2024 20:56:57 +0200 Subject: [PATCH] Require keyword args for all architectures (#279) * KWOnly for all except ATD * Better error message * ATD --- libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py | 1 + libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py | 1 + .../spandrel/spandrel/architectures/Compact/arch/SRVGG.py | 1 + libs/spandrel/spandrel/architectures/DAT/arch/DAT.py | 1 + .../spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py | 1 + .../spandrel/architectures/DITN/arch/DITN_Real.py | 1 + .../spandrel/architectures/DRCT/arch/drct_arch.py | 1 + .../spandrel/architectures/DRUNet/arch/network_unet.py | 1 + .../spandrel/architectures/DnCNN/arch/network_dncnn.py | 1 + libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py | 1 + libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py | 1 + .../architectures/FFTformer/arch/fftformer_arch.py | 1 + .../architectures/GFPGAN/arch/gfpganv1_clean_arch.py | 1 + libs/spandrel/spandrel/architectures/GRL/arch/grl.py | 1 + libs/spandrel/spandrel/architectures/HAT/arch/HAT.py | 1 + .../spandrel/architectures/HVICIDNet/arch/cidnet.py | 1 + libs/spandrel/spandrel/architectures/IPT/arch/ipt.py | 1 + .../spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py | 1 + .../spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py | 1 + libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py | 2 +- .../spandrel/architectures/MMRealSR/arch/mmrealsr_arch.py | 1 + .../architectures/MixDehazeNet/arch/MixDehazeNet.py | 1 + .../spandrel/architectures/NAFNet/arch/NAFNet_arch.py | 1 + .../spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py | 1 + libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py | 1 + .../spandrel/architectures/PLKSR/arch/RealPLKSR.py | 1 + libs/spandrel/spandrel/architectures/RGT/arch/rgt.py | 1 + .../spandrel/architectures/RealCUGAN/arch/upcunet_v3.py | 8 ++++---- .../RestoreFormer/arch/restoreformer_arch.py | 1 + .../RetinexFormer/arch/retinexformer_arch.py | 1 + libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py | 2 +- .../spandrel/architectures/SAFMNBCIE/arch/safmn_bcie.py | 2 +- .../spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py | 1 + libs/spandrel/spandrel/architectures/SPAN/arch/span.py | 1 + .../spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py | 1 + .../spandrel/architectures/Swin2SR/arch/Swin2SR.py | 1 + .../spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py | 1 + libs/spandrel/spandrel/util/__init__.py | 6 ++++++ .../architectures/CodeFormer/arch/codeformer.py | 1 + .../architectures/DDColor/arch/ddcolor.py | 1 + .../architectures/M3SNet/arch/M3SNet.py | 1 + .../architectures/MIRNet2/arch/mirnet_v2_arch.py | 1 + .../architectures/MPRNet/arch/MPRNet.py | 1 + .../architectures/Restormer/arch/restormer_arch.py | 1 + .../architectures/SRFormer/arch/SRFormer.py | 1 + 45 files changed, 53 insertions(+), 7 deletions(-) diff --git a/libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py b/libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py index 438973bf..8628a9ec 100644 --- a/libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py +++ b/libs/spandrel/spandrel/architectures/ATD/arch/atd_arch.py @@ -909,6 +909,7 @@ class ATD(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3, diff --git a/libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py b/libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py index 4a827398..94a23da4 100644 --- a/libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py +++ b/libs/spandrel/spandrel/architectures/CRAFT/arch/CRAFT.py @@ -748,6 +748,7 @@ class CRAFT(nn.Module): def __init__( self, + *, in_chans=3, # img_size=64, window_size=16, diff --git a/libs/spandrel/spandrel/architectures/Compact/arch/SRVGG.py b/libs/spandrel/spandrel/architectures/Compact/arch/SRVGG.py index bda65e1f..919ff99c 100644 --- a/libs/spandrel/spandrel/architectures/Compact/arch/SRVGG.py +++ b/libs/spandrel/spandrel/architectures/Compact/arch/SRVGG.py @@ -26,6 +26,7 @@ class SRVGGNetCompact(nn.Module): def __init__( self, + *, num_in_ch=3, num_out_ch=3, num_feat=64, diff --git a/libs/spandrel/spandrel/architectures/DAT/arch/DAT.py b/libs/spandrel/spandrel/architectures/DAT/arch/DAT.py index 5cff2f5b..5a78b23d 100644 --- a/libs/spandrel/spandrel/architectures/DAT/arch/DAT.py +++ b/libs/spandrel/spandrel/architectures/DAT/arch/DAT.py @@ -920,6 +920,7 @@ class DAT(nn.Module): def __init__( self, + *, img_size=64, in_chans=3, embed_dim=180, diff --git a/libs/spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py b/libs/spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py index d8d184de..807f026e 100644 --- a/libs/spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py +++ b/libs/spandrel/spandrel/architectures/DCTLSA/arch/dctlsa.py @@ -461,6 +461,7 @@ class DCTLSA(nn.Module): def __init__( self, + *, in_nc=3, nf=55, num_modules=6, diff --git a/libs/spandrel/spandrel/architectures/DITN/arch/DITN_Real.py b/libs/spandrel/spandrel/architectures/DITN/arch/DITN_Real.py index f43c99d5..b1940711 100644 --- a/libs/spandrel/spandrel/architectures/DITN/arch/DITN_Real.py +++ b/libs/spandrel/spandrel/architectures/DITN/arch/DITN_Real.py @@ -265,6 +265,7 @@ class DITN_Real(nn.Module): def __init__( self, + *, inp_channels=3, dim=60, ITL_blocks=4, diff --git a/libs/spandrel/spandrel/architectures/DRCT/arch/drct_arch.py b/libs/spandrel/spandrel/architectures/DRCT/arch/drct_arch.py index 151a818a..037d6219 100644 --- a/libs/spandrel/spandrel/architectures/DRCT/arch/drct_arch.py +++ b/libs/spandrel/spandrel/architectures/DRCT/arch/drct_arch.py @@ -688,6 +688,7 @@ class DRCT(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3, diff --git a/libs/spandrel/spandrel/architectures/DRUNet/arch/network_unet.py b/libs/spandrel/spandrel/architectures/DRUNet/arch/network_unet.py index 7ec20bb7..fd3f6a41 100644 --- a/libs/spandrel/spandrel/architectures/DRUNet/arch/network_unet.py +++ b/libs/spandrel/spandrel/architectures/DRUNet/arch/network_unet.py @@ -21,6 +21,7 @@ class DRUNet(nn.Module): def __init__( self, + *, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], diff --git a/libs/spandrel/spandrel/architectures/DnCNN/arch/network_dncnn.py b/libs/spandrel/spandrel/architectures/DnCNN/arch/network_dncnn.py index 3632830f..ffd754a5 100644 --- a/libs/spandrel/spandrel/architectures/DnCNN/arch/network_dncnn.py +++ b/libs/spandrel/spandrel/architectures/DnCNN/arch/network_dncnn.py @@ -45,6 +45,7 @@ class DnCNN(nn.Module): def __init__( self, + *, in_nc=1, out_nc=1, nc=64, diff --git a/libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py b/libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py index 79488db4..235c46c4 100644 --- a/libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py +++ b/libs/spandrel/spandrel/architectures/ESRGAN/arch/RRDB.py @@ -19,6 +19,7 @@ class RRDBNet(nn.Module): def __init__( self, + *, in_nc: int = 3, out_nc: int = 3, num_filters: int = 64, diff --git a/libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py b/libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py index f8911c07..4c83a2c8 100644 --- a/libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py +++ b/libs/spandrel/spandrel/architectures/FBCNN/arch/FBCNN.py @@ -411,6 +411,7 @@ class FBCNN(nn.Module): def __init__( self, + *, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], diff --git a/libs/spandrel/spandrel/architectures/FFTformer/arch/fftformer_arch.py b/libs/spandrel/spandrel/architectures/FFTformer/arch/fftformer_arch.py index 67fd6287..08f9873a 100644 --- a/libs/spandrel/spandrel/architectures/FFTformer/arch/fftformer_arch.py +++ b/libs/spandrel/spandrel/architectures/FFTformer/arch/fftformer_arch.py @@ -275,6 +275,7 @@ class FFTformer(nn.Module): def __init__( self, + *, inp_channels=3, out_channels=3, dim=48, diff --git a/libs/spandrel/spandrel/architectures/GFPGAN/arch/gfpganv1_clean_arch.py b/libs/spandrel/spandrel/architectures/GFPGAN/arch/gfpganv1_clean_arch.py index 6abf4a93..3df7598a 100644 --- a/libs/spandrel/spandrel/architectures/GFPGAN/arch/gfpganv1_clean_arch.py +++ b/libs/spandrel/spandrel/architectures/GFPGAN/arch/gfpganv1_clean_arch.py @@ -196,6 +196,7 @@ class GFPGANv1Clean(nn.Module): def __init__( self, + *, out_size=512, num_style_feat=512, channel_multiplier=2, diff --git a/libs/spandrel/spandrel/architectures/GRL/arch/grl.py b/libs/spandrel/spandrel/architectures/GRL/arch/grl.py index 434cf1d9..c427f9ae 100644 --- a/libs/spandrel/spandrel/architectures/GRL/arch/grl.py +++ b/libs/spandrel/spandrel/architectures/GRL/arch/grl.py @@ -231,6 +231,7 @@ class GRL(nn.Module): def __init__( self, + *, img_size=64, in_channels: int = 3, out_channels: int = 3, diff --git a/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py b/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py index 51982b00..ad892cfb 100644 --- a/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py +++ b/libs/spandrel/spandrel/architectures/HAT/arch/HAT.py @@ -856,6 +856,7 @@ class HAT(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3, diff --git a/libs/spandrel/spandrel/architectures/HVICIDNet/arch/cidnet.py b/libs/spandrel/spandrel/architectures/HVICIDNet/arch/cidnet.py index f44a7fe8..f5d79d1e 100644 --- a/libs/spandrel/spandrel/architectures/HVICIDNet/arch/cidnet.py +++ b/libs/spandrel/spandrel/architectures/HVICIDNet/arch/cidnet.py @@ -14,6 +14,7 @@ class CIDNet(nn.Module): def __init__( self, + *, channels=[36, 36, 72, 144], heads=[1, 2, 4, 8], norm=False, diff --git a/libs/spandrel/spandrel/architectures/IPT/arch/ipt.py b/libs/spandrel/spandrel/architectures/IPT/arch/ipt.py index 7f6df282..a89d0d8b 100644 --- a/libs/spandrel/spandrel/architectures/IPT/arch/ipt.py +++ b/libs/spandrel/spandrel/architectures/IPT/arch/ipt.py @@ -22,6 +22,7 @@ class IPT(nn.Module): def __init__( self, + *, patch_size: int = 48, patch_dim: int = 3, n_feats: int = 64, diff --git a/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py b/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py index 8884d362..945ad4c7 100644 --- a/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py +++ b/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_l.py @@ -223,6 +223,7 @@ class KBNet_l(nn.Module): def __init__( self, + *, inp_channels=3, out_channels=3, dim=48, diff --git a/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py b/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py index eac0fbb3..6134c7fe 100644 --- a/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py +++ b/libs/spandrel/spandrel/architectures/KBNet/arch/kbnet_s.py @@ -205,6 +205,7 @@ class KBNet_s(nn.Module): def __init__( self, + *, img_channel=3, width=64, middle_blk_num=12, diff --git a/libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py b/libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py index 67613af5..b814d50f 100644 --- a/libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py +++ b/libs/spandrel/spandrel/architectures/LaMa/arch/LaMa.py @@ -667,7 +667,7 @@ def forward(self, image, mask): class LaMa(nn.Module): hyperparameters = {} - def __init__(self, in_nc=3, out_nc=3) -> None: + def __init__(self, *, in_nc=3, out_nc=3) -> None: super().__init__() self.model = FFCResNetGenerator(in_nc, out_nc) diff --git a/libs/spandrel/spandrel/architectures/MMRealSR/arch/mmrealsr_arch.py b/libs/spandrel/spandrel/architectures/MMRealSR/arch/mmrealsr_arch.py index b88b012f..b30216b3 100644 --- a/libs/spandrel/spandrel/architectures/MMRealSR/arch/mmrealsr_arch.py +++ b/libs/spandrel/spandrel/architectures/MMRealSR/arch/mmrealsr_arch.py @@ -554,6 +554,7 @@ class MMRRDBNet_test(nn.Module): def __init__( self, + *, num_in_ch, num_out_ch, scale=4, diff --git a/libs/spandrel/spandrel/architectures/MixDehazeNet/arch/MixDehazeNet.py b/libs/spandrel/spandrel/architectures/MixDehazeNet/arch/MixDehazeNet.py index b4705b8e..7db84a9e 100644 --- a/libs/spandrel/spandrel/architectures/MixDehazeNet/arch/MixDehazeNet.py +++ b/libs/spandrel/spandrel/architectures/MixDehazeNet/arch/MixDehazeNet.py @@ -208,6 +208,7 @@ class MixDehazeNet(nn.Module): def __init__( self, + *, in_chans=3, out_chans=4, embed_dims=[24, 48, 96, 48, 24], diff --git a/libs/spandrel/spandrel/architectures/NAFNet/arch/NAFNet_arch.py b/libs/spandrel/spandrel/architectures/NAFNet/arch/NAFNet_arch.py index 83cf2ef8..c688d03f 100644 --- a/libs/spandrel/spandrel/architectures/NAFNet/arch/NAFNet_arch.py +++ b/libs/spandrel/spandrel/architectures/NAFNet/arch/NAFNet_arch.py @@ -142,6 +142,7 @@ class NAFNet(nn.Module): def __init__( self, + *, img_channel: int = 3, width: int = 16, middle_blk_num: int = 1, diff --git a/libs/spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py b/libs/spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py index fa78de72..77092826 100644 --- a/libs/spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py +++ b/libs/spandrel/spandrel/architectures/OmniSR/arch/OmniSR.py @@ -25,6 +25,7 @@ class OmniSR(nn.Module): def __init__( self, + *, num_in_ch=3, num_out_ch=3, num_feat=64, diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py index 08094412..1cf3c8bc 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py @@ -291,6 +291,7 @@ class PLKSR(nn.Module): def __init__( self, + *, dim: int = 64, n_blocks: int = 28, upscaling_factor: int = 4, diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py index 0501d535..86c4b170 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py @@ -103,6 +103,7 @@ class RealPLKSR(nn.Module): def __init__( self, + *, dim: int = 64, n_blocks: int = 28, upscaling_factor: int = 4, diff --git a/libs/spandrel/spandrel/architectures/RGT/arch/rgt.py b/libs/spandrel/spandrel/architectures/RGT/arch/rgt.py index 8353a11e..366a18b3 100644 --- a/libs/spandrel/spandrel/architectures/RGT/arch/rgt.py +++ b/libs/spandrel/spandrel/architectures/RGT/arch/rgt.py @@ -814,6 +814,7 @@ class RGT(nn.Module): def __init__( self, + *, img_size=64, in_chans=3, embed_dim=180, diff --git a/libs/spandrel/spandrel/architectures/RealCUGAN/arch/upcunet_v3.py b/libs/spandrel/spandrel/architectures/RealCUGAN/arch/upcunet_v3.py index a6859e00..1fc857d9 100644 --- a/libs/spandrel/spandrel/architectures/RealCUGAN/arch/upcunet_v3.py +++ b/libs/spandrel/spandrel/architectures/RealCUGAN/arch/upcunet_v3.py @@ -291,7 +291,7 @@ def forward_d(self, x1, x4): # conv234结尾有se class UpCunet2x(nn.Module): hyperparameters = {} - def __init__(self, in_channels=3, out_channels=3, pro: bool = False): + def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False): super().__init__() self.pro: Tensor | None if pro: @@ -333,7 +333,7 @@ def forward(self, x: Tensor, alpha: float = 1): class UpCunet3x(nn.Module): hyperparameters = {} - def __init__(self, in_channels=3, out_channels=3, pro: bool = False): + def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False): super().__init__() self.pro: Tensor | None if pro: @@ -375,7 +375,7 @@ def forward(self, x: Tensor, alpha: float = 1): class UpCunet4x(nn.Module): hyperparameters = {} - def __init__(self, in_channels=3, out_channels=3, pro: bool = False): + def __init__(self, *, in_channels=3, out_channels=3, pro: bool = False): super().__init__() self.pro: Tensor | None if pro: @@ -425,7 +425,7 @@ def forward(self, x: Tensor, alpha: float = 1): class UpCunet2x_fast(nn.Module): hyperparameters = {} - def __init__(self, in_channels=3, out_channels=3): + def __init__(self, *, in_channels=3, out_channels=3): super().__init__() self.unet1 = UNet1(12, 64, deconv=True) self.unet2 = UNet2(64, 64, deconv=False) diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/arch/restoreformer_arch.py b/libs/spandrel/spandrel/architectures/RestoreFormer/arch/restoreformer_arch.py index 13ca5a50..b1034297 100644 --- a/libs/spandrel/spandrel/architectures/RestoreFormer/arch/restoreformer_arch.py +++ b/libs/spandrel/spandrel/architectures/RestoreFormer/arch/restoreformer_arch.py @@ -674,6 +674,7 @@ class RestoreFormer(nn.Module): def __init__( self, + *, n_embed=1024, embed_dim=256, ch=64, diff --git a/libs/spandrel/spandrel/architectures/RetinexFormer/arch/retinexformer_arch.py b/libs/spandrel/spandrel/architectures/RetinexFormer/arch/retinexformer_arch.py index 5f23990a..9bb19043 100644 --- a/libs/spandrel/spandrel/architectures/RetinexFormer/arch/retinexformer_arch.py +++ b/libs/spandrel/spandrel/architectures/RetinexFormer/arch/retinexformer_arch.py @@ -330,6 +330,7 @@ class RetinexFormer(nn.Module): def __init__( self, + *, in_channels=3, out_channels=3, n_feat=40, diff --git a/libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py b/libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py index 3833eb97..76cabe7c 100644 --- a/libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py +++ b/libs/spandrel/spandrel/architectures/SAFMN/arch/safmn.py @@ -162,7 +162,7 @@ def forward(self, x): class SAFMN(nn.Module): hyperparameters = {} - def __init__(self, dim: int, n_blocks=8, ffn_scale=2.0, upscaling_factor=4): + def __init__(self, *, dim: int, n_blocks=8, ffn_scale=2.0, upscaling_factor=4): super().__init__() self.to_feat = nn.Conv2d(3, dim, 3, 1, 1) diff --git a/libs/spandrel/spandrel/architectures/SAFMNBCIE/arch/safmn_bcie.py b/libs/spandrel/spandrel/architectures/SAFMNBCIE/arch/safmn_bcie.py index affeeb01..7ed9abca 100644 --- a/libs/spandrel/spandrel/architectures/SAFMNBCIE/arch/safmn_bcie.py +++ b/libs/spandrel/spandrel/architectures/SAFMNBCIE/arch/safmn_bcie.py @@ -123,7 +123,7 @@ class SAFMN_BCIE(nn.Module): hyperparameters = {} def __init__( - self, dim: int, n_blocks=6, num_layers=6, ffn_scale=2.0, upscaling_factor=2 + self, *, dim: int, n_blocks=6, num_layers=6, ffn_scale=2.0, upscaling_factor=2 ): super().__init__() diff --git a/libs/spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py b/libs/spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py index 8523fc49..4092ffb1 100644 --- a/libs/spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py +++ b/libs/spandrel/spandrel/architectures/SCUNet/arch/SCUNet.py @@ -279,6 +279,7 @@ class SCUNet(nn.Module): def __init__( self, + *, in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64, diff --git a/libs/spandrel/spandrel/architectures/SPAN/arch/span.py b/libs/spandrel/spandrel/architectures/SPAN/arch/span.py index 503201b1..8291f8bc 100644 --- a/libs/spandrel/spandrel/architectures/SPAN/arch/span.py +++ b/libs/spandrel/spandrel/architectures/SPAN/arch/span.py @@ -243,6 +243,7 @@ class SPAN(nn.Module): def __init__( self, + *, num_in_ch: int, num_out_ch: int, feature_channels=48, diff --git a/libs/spandrel/spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py b/libs/spandrel/spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py index d87329e1..dacaa3f3 100644 --- a/libs/spandrel/spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py +++ b/libs/spandrel/spandrel/architectures/SwiftSRGAN/arch/SwiftSRGAN.py @@ -107,6 +107,7 @@ class Generator(nn.Module): def __init__( self, + *, in_channels: int = 3, num_channels: int = 64, num_blocks: int = 16, diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/arch/Swin2SR.py b/libs/spandrel/spandrel/architectures/Swin2SR/arch/Swin2SR.py index 03abb091..f2099dbc 100644 --- a/libs/spandrel/spandrel/architectures/Swin2SR/arch/Swin2SR.py +++ b/libs/spandrel/spandrel/architectures/Swin2SR/arch/Swin2SR.py @@ -908,6 +908,7 @@ class Swin2SR(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3, diff --git a/libs/spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py b/libs/spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py index 2acd92f4..0153e347 100644 --- a/libs/spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py +++ b/libs/spandrel/spandrel/architectures/SwinIR/arch/SwinIR.py @@ -816,6 +816,7 @@ class SwinIR(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3, diff --git a/libs/spandrel/spandrel/util/__init__.py b/libs/spandrel/spandrel/util/__init__.py index 6527fb5e..34805e18 100644 --- a/libs/spandrel/spandrel/util/__init__.py +++ b/libs/spandrel/spandrel/util/__init__.py @@ -177,6 +177,12 @@ def inner(cls: type[C]) -> type[C]: raise UserWarning( "Class has **kwargs, which is not allowed in combination with @store_hyperparameters" ) + if spec.args != ["self"]: + raise UserWarning( + "@store_hyperparameters requires all arguments of `" + + cls.__name__ + + ".__init__` after `self` to be keyword arguments. Use `def __init__(self, *, a, b, c):`." + ) @functools.wraps(old_init) def new_init(self: C, **kwargs): diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/arch/codeformer.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/arch/codeformer.py index 4a544577..fbc5da73 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/arch/codeformer.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/arch/codeformer.py @@ -610,6 +610,7 @@ class CodeFormer(VQAutoEncoder): def __init__( self, + *, dim_embd=512, n_head=8, n_layers=9, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/arch/ddcolor.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/arch/ddcolor.py index 07a0b0ca..e0ddd713 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/arch/ddcolor.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/arch/ddcolor.py @@ -31,6 +31,7 @@ class DDColor(nn.Module): def __init__( self, + *, encoder_name="convnext-l", decoder_name="MultiScaleColorDecoder", num_input_channels=3, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/arch/M3SNet.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/arch/M3SNet.py index 3a0a32ad..71a7ef99 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/arch/M3SNet.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/arch/M3SNet.py @@ -171,6 +171,7 @@ class M3SNet(nn.Module): def __init__( self, + *, img_channel=3, width=32, middle_blk_num=1, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/arch/mirnet_v2_arch.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/arch/mirnet_v2_arch.py index 8bda96f4..a58ca6a9 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/arch/mirnet_v2_arch.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/arch/mirnet_v2_arch.py @@ -307,6 +307,7 @@ class MIRNet_v2(nn.Module): def __init__( self, + *, inp_channels=3, out_channels=3, n_feat=80, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py index 2ef3e000..6bdd95fd 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/arch/MPRNet.py @@ -398,6 +398,7 @@ class MPRNet(nn.Module): def __init__( self, + *, in_c: int = 3, out_c: int = 3, n_feat: int = 40, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/arch/restormer_arch.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/arch/restormer_arch.py index 13716549..84a279b4 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/arch/restormer_arch.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/arch/restormer_arch.py @@ -222,6 +222,7 @@ class Restormer(nn.Module): def __init__( self, + *, inp_channels=3, out_channels=3, dim=48, diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/arch/SRFormer.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/arch/SRFormer.py index 0a4afb37..b3688880 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/arch/SRFormer.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/arch/SRFormer.py @@ -1024,6 +1024,7 @@ class SRFormer(nn.Module): def __init__( self, + *, img_size=64, patch_size=1, in_chans=3,