From d8c22562296ca8b20802b0c1da73a3e004958cbc Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Mon, 8 Jul 2024 17:38:44 +0200 Subject: [PATCH] Define the public API by what's documented --- .../architectures/ATD/__arch/__init__.py | 0 .../spandrel/architectures/ATD/__init__.py | 3 + .../architectures/CRAFT/__arch/__init__.py | 0 .../spandrel/architectures/CRAFT/__init__.py | 3 + .../architectures/Compact/__arch/__init__.py | 0 .../architectures/DAT/__arch/__init__.py | 0 .../spandrel/architectures/DAT/__init__.py | 3 + .../architectures/DCTLSA/__arch/__init__.py | 0 .../spandrel/architectures/DCTLSA/__init__.py | 3 + .../architectures/DITN/__arch/__init__.py | 0 .../spandrel/architectures/DITN/__init__.py | 3 + .../architectures/DRCT/__arch/__init__.py | 0 .../spandrel/architectures/DRCT/__init__.py | 3 + .../architectures/DRUNet/__arch/__init__.py | 0 .../spandrel/architectures/DRUNet/__init__.py | 3 + .../architectures/DnCNN/__arch/__init__.py | 0 .../spandrel/architectures/DnCNN/__init__.py | 3 + .../architectures/ESRGAN/__arch/__init__.py | 0 .../architectures/FBCNN/__arch/__init__.py | 0 .../spandrel/architectures/FBCNN/__init__.py | 3 + .../FFTformer/__arch/__init__.py | 0 .../architectures/FFTformer/__init__.py | 3 + .../architectures/GFPGAN/__arch/__init__.py | 0 .../architectures/GRL/__arch/__init__.py | 0 .../spandrel/architectures/GRL/__init__.py | 3 + .../architectures/HAT/__arch/__init__.py | 0 .../spandrel/architectures/HAT/__init__.py | 3 + .../HVICIDNet/__arch/__init__.py | 0 .../architectures/HVICIDNet/__init__.py | 3 + .../architectures/IPT/__arch/__init__.py | 0 .../spandrel/architectures/IPT/__init__.py | 3 + .../architectures/KBNet/__arch/__init__.py | 0 .../spandrel/architectures/KBNet/__init__.py | 3 + .../architectures/LaMa/__arch/__init__.py | 0 .../spandrel/architectures/LaMa/__init__.py | 3 + .../architectures/MMRealSR/__arch/__init__.py | 0 .../architectures/MMRealSR/__init__.py | 3 + .../MixDehazeNet/__arch/__init__.py | 0 
.../architectures/MixDehazeNet/__init__.py | 3 + .../architectures/NAFNet/__arch/__init__.py | 0 .../spandrel/architectures/NAFNet/__init__.py | 3 + .../architectures/OmniSR/__arch/__init__.py | 0 .../spandrel/architectures/OmniSR/__init__.py | 3 + .../architectures/PLKSR/__arch/__init__.py | 0 .../spandrel/architectures/PLKSR/__init__.py | 3 + .../architectures/RGT/__arch/__init__.py | 0 .../spandrel/architectures/RGT/__init__.py | 3 + .../RealCUGAN/__arch/__init__.py | 0 .../architectures/RealCUGAN/__init__.py | 3 + .../RestoreFormer/__arch/__init__.py | 0 .../architectures/RestoreFormer/__init__.py | 3 + .../RetinexFormer/__arch/__init__.py | 0 .../architectures/RetinexFormer/__init__.py | 3 + .../architectures/SAFMN/__arch/__init__.py | 0 .../spandrel/architectures/SAFMN/__init__.py | 3 + .../SAFMNBCIE/__arch/__init__.py | 0 .../architectures/SCUNet/__arch/__init__.py | 0 .../spandrel/architectures/SCUNet/__init__.py | 3 + .../architectures/SPAN/__arch/__init__.py | 0 .../spandrel/architectures/SPAN/__init__.py | 3 + .../SwiftSRGAN/__arch/__init__.py | 0 .../architectures/SwiftSRGAN/__init__.py | 3 + .../architectures/Swin2SR/__arch/__init__.py | 0 .../architectures/Swin2SR/__init__.py | 3 + .../architectures/SwinIR/__arch/__init__.py | 0 .../spandrel/architectures/SwinIR/__init__.py | 3 + .../architectures/Uformer/__arch/__init__.py | 0 .../architectures/Uformer/__init__.py | 3 + .../spandrel_extra_arches/__helper.py | 5 + .../spandrel_extra_arches/__init__.py | 6 + .../architectures/AdaCode/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/AdaCode/__init__.py | 3 + .../CodeFormer/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/CodeFormer/__init__.py | 3 + .../architectures/DDColor/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/DDColor/__init__.py | 3 + .../architectures/FeMaSR/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/FeMaSR/__init__.py | 3 + .../architectures/M3SNet/__arch/__init__.py | 146 
++++++++++++++++++ .../architectures/M3SNet/__init__.py | 3 + .../architectures/MAT/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/MAT/__init__.py | 3 + .../architectures/MIRNet2/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/MIRNet2/__init__.py | 3 + .../architectures/MPRNet/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/MPRNet/__init__.py | 3 + .../Restormer/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/Restormer/__init__.py | 3 + .../architectures/SRFormer/__arch/__init__.py | 146 ++++++++++++++++++ .../architectures/SRFormer/__init__.py | 3 + .../architectures/__init__.py | 3 + pyproject.toml | 6 +- 92 files changed, 1605 insertions(+), 1 deletion(-) create mode 100644 libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py create mode 100644 
libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py create mode 100644 libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py create mode 100644 
libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py create mode 100644 libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py diff --git a/libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py b/libs/spandrel/spandrel/architectures/ATD/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/ATD/__init__.py b/libs/spandrel/spandrel/architectures/ATD/__init__.py index 1603368d..3ce04198 100644 --- a/libs/spandrel/spandrel/architectures/ATD/__init__.py +++ b/libs/spandrel/spandrel/architectures/ATD/__init__.py @@ -168,3 +168,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=8), ) + + +__all__ = ["ATDArch", "ATD"] diff --git a/libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/CRAFT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/CRAFT/__init__.py b/libs/spandrel/spandrel/architectures/CRAFT/__init__.py index afb2f061..51fef69e 100644 --- a/libs/spandrel/spandrel/architectures/CRAFT/__init__.py +++ 
b/libs/spandrel/spandrel/architectures/CRAFT/__init__.py @@ -126,3 +126,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CRAFT]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16, multiple_of=16), ) + + +__all__ = ["CRAFTArch", "CRAFT"] diff --git a/libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Compact/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DAT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DAT/__init__.py b/libs/spandrel/spandrel/architectures/DAT/__init__.py index 6ed57566..2a545fec 100644 --- a/libs/spandrel/spandrel/architectures/DAT/__init__.py +++ b/libs/spandrel/spandrel/architectures/DAT/__init__.py @@ -177,3 +177,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["DATArch", "DAT"] diff --git a/libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DCTLSA/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py b/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py index b420227e..f6266c81 100644 --- a/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py +++ b/libs/spandrel/spandrel/architectures/DCTLSA/__init__.py @@ -82,3 +82,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]: output_channels=out_nc, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["DCTLSAArch", "DCTLSA"] diff --git a/libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DITN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/libs/spandrel/spandrel/architectures/DITN/__init__.py b/libs/spandrel/spandrel/architectures/DITN/__init__.py index d03c1d03..b343eda6 100644 --- a/libs/spandrel/spandrel/architectures/DITN/__init__.py +++ b/libs/spandrel/spandrel/architectures/DITN/__init__.py @@ -85,3 +85,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DITN]: output_channels=3, # hard-coded in the architecture size_requirements=SizeRequirements(multiple_of=patch_size), ) + + +__all__ = ["DITNArch", "DITN"] diff --git a/libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DRCT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DRCT/__init__.py b/libs/spandrel/spandrel/architectures/DRCT/__init__.py index 64d7bae1..7d59e489 100644 --- a/libs/spandrel/spandrel/architectures/DRCT/__init__.py +++ b/libs/spandrel/spandrel/architectures/DRCT/__init__.py @@ -155,3 +155,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]: output_channels=in_chans, size_requirements=SizeRequirements(multiple_of=16), ) + + +__all__ = ["DRCTArch", "DRCT"] diff --git a/libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DRUNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DRUNet/__init__.py b/libs/spandrel/spandrel/architectures/DRUNet/__init__.py index 44d47c6e..5ee074ab 100644 --- a/libs/spandrel/spandrel/architectures/DRUNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/DRUNet/__init__.py @@ -99,3 +99,6 @@ def call(model: DRUNet, image: torch.Tensor) -> torch.Tensor: size_requirements=SizeRequirements(multiple_of=8), call_fn=call, ) + + +__all__ = ["DRUNetArch", "DRUNet"] diff --git a/libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/DnCNN/__arch/__init__.py new file mode 100644 index 
00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/DnCNN/__init__.py b/libs/spandrel/spandrel/architectures/DnCNN/__init__.py index fa62bd10..e383ba7d 100644 --- a/libs/spandrel/spandrel/architectures/DnCNN/__init__.py +++ b/libs/spandrel/spandrel/architectures/DnCNN/__init__.py @@ -125,3 +125,6 @@ def call(model: DnCNN, image: torch.Tensor) -> torch.Tensor: size_requirements=SizeRequirements(), call_fn=call, ) + + +__all__ = ["DnCNNArch", "DnCNN"] diff --git a/libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/ESRGAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/FBCNN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/FBCNN/__init__.py b/libs/spandrel/spandrel/architectures/FBCNN/__init__.py index 46f2cff4..dca55516 100644 --- a/libs/spandrel/spandrel/architectures/FBCNN/__init__.py +++ b/libs/spandrel/spandrel/architectures/FBCNN/__init__.py @@ -79,3 +79,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FBCNN]: output_channels=out_nc, call_fn=lambda model, image: model(image)[0], ) + + +__all__ = ["FBCNNArch", "FBCNN"] diff --git a/libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/FFTformer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/FFTformer/__init__.py b/libs/spandrel/spandrel/architectures/FFTformer/__init__.py index eb69bda7..b7109f73 100644 --- a/libs/spandrel/spandrel/architectures/FFTformer/__init__.py +++ b/libs/spandrel/spandrel/architectures/FFTformer/__init__.py @@ -98,3 +98,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]: output_channels=out_channels, size_requirements=SizeRequirements(multiple_of=32), ) + + +__all__ = 
["FFTformerArch", "FFTformer"] diff --git a/libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/GFPGAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py b/libs/spandrel/spandrel/architectures/GRL/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/GRL/__init__.py b/libs/spandrel/spandrel/architectures/GRL/__init__.py index 851be55e..a115d610 100644 --- a/libs/spandrel/spandrel/architectures/GRL/__init__.py +++ b/libs/spandrel/spandrel/architectures/GRL/__init__.py @@ -359,3 +359,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GRL]: input_channels=in_channels, output_channels=out_channels, ) + + +__all__ = ["GRLArch", "GRL"] diff --git a/libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/HAT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/HAT/__init__.py b/libs/spandrel/spandrel/architectures/HAT/__init__.py index 2a219d26..686d91d6 100644 --- a/libs/spandrel/spandrel/architectures/HAT/__init__.py +++ b/libs/spandrel/spandrel/architectures/HAT/__init__.py @@ -225,3 +225,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["HATArch", "HAT"] diff --git a/libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/HVICIDNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py b/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py index 55abcd13..b30bb5e5 100644 --- a/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py @@ -92,3 
+92,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]: size_requirements=SizeRequirements(multiple_of=8), tiling=ModelTiling.DISCOURAGED, ) + + +__all__ = ["HVICIDNetArch", "HVICIDNet"] diff --git a/libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/IPT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/IPT/__init__.py b/libs/spandrel/spandrel/architectures/IPT/__init__.py index aa511863..0e2f5eed 100644 --- a/libs/spandrel/spandrel/architectures/IPT/__init__.py +++ b/libs/spandrel/spandrel/architectures/IPT/__init__.py @@ -152,3 +152,6 @@ def call(model: IPT, x: torch.Tensor): size_requirements=SizeRequirements(minimum=patch_size), call_fn=call, ) + + +__all__ = ["IPTArch", "IPT"] diff --git a/libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/KBNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/KBNet/__init__.py b/libs/spandrel/spandrel/architectures/KBNet/__init__.py index 9a574a5f..81f3b646 100644 --- a/libs/spandrel/spandrel/architectures/KBNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/KBNet/__init__.py @@ -166,3 +166,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_KBNet]: return self._load_l(state_dict) else: return self._load_s(state_dict) + + +__all__ = ["KBNetArch", "KBNet_s", "KBNet_l"] diff --git a/libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/LaMa/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__init__.py index b3008936..c03ba393 100644 --- a/libs/spandrel/spandrel/architectures/LaMa/__init__.py +++ b/libs/spandrel/spandrel/architectures/LaMa/__init__.py @@ -53,3 +53,6 @@ def load(self, 
state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]: output_channels=out_nc, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["LaMaArch", "LaMa"] diff --git a/libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/MMRealSR/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py b/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py index ca519a75..161146f2 100644 --- a/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py +++ b/libs/spandrel/spandrel/architectures/MMRealSR/__init__.py @@ -193,3 +193,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MMRealSR]: size_requirements=SizeRequirements(minimum=16), call_fn=lambda model, image: model(image)[0], ) + + +__all__ = ["MMRealSRArch", "MMRealSR"] diff --git a/libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/MixDehazeNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py b/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py index 601314d3..d6b98858 100644 --- a/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py @@ -112,3 +112,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MixDehazeNet]: tiling=ModelTiling.DISCOURAGED, call_fn=lambda model, image: model(image) * 0.5 + 0.5, ) + + +__all__ = ["MixDehazeNetArch", "MixDehazeNet"] diff --git a/libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/NAFNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/NAFNet/__init__.py b/libs/spandrel/spandrel/architectures/NAFNet/__init__.py index c9aae804..7f1e3cbe 100644 --- 
a/libs/spandrel/spandrel/architectures/NAFNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/NAFNet/__init__.py @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[NAFNet]: input_channels=img_channel, output_channels=img_channel, ) + + +__all__ = ["NAFNetArch", "NAFNet"] diff --git a/libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/OmniSR/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/OmniSR/__init__.py b/libs/spandrel/spandrel/architectures/OmniSR/__init__.py index 97904a86..808d02a7 100644 --- a/libs/spandrel/spandrel/architectures/OmniSR/__init__.py +++ b/libs/spandrel/spandrel/architectures/OmniSR/__init__.py @@ -96,3 +96,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[OmniSR]: output_channels=num_out_ch, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["OmniSRArch", "OmniSR"] diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py index c8f3cd1b..5fd6f3ab 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py @@ -143,3 +143,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: input_channels=3, output_channels=3, ) + + +__all__ = ["PLKSRArch", "PLKSR", "RealPLKSR"] diff --git a/libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RGT/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/RGT/__init__.py b/libs/spandrel/spandrel/architectures/RGT/__init__.py index df6981be..3fc8a357 100644 --- 
a/libs/spandrel/spandrel/architectures/RGT/__init__.py +++ b/libs/spandrel/spandrel/architectures/RGT/__init__.py @@ -169,3 +169,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["RGTArch", "RGT"] diff --git a/libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RealCUGAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py b/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py index 0255b81a..5d978482 100644 --- a/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py @@ -104,3 +104,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]: output_channels=out_channels, size_requirements=size_requirements, ) + + +__all__ = ["RealCUGANArch", "UpCunet2x", "UpCunet3x", "UpCunet4x", "UpCunet2x_fast"] diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RestoreFormer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py b/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py index ee572d38..5952a926 100644 --- a/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py +++ b/libs/spandrel/spandrel/architectures/RestoreFormer/__init__.py @@ -101,3 +101,6 @@ def call(model: RestoreFormer, x: torch.Tensor) -> torch.Tensor: size_requirements=SizeRequirements(multiple_of=32), call_fn=call, ) + + +__all__ = ["RestoreFormerArch", "RestoreFormer"] diff --git a/libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/RetinexFormer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py b/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py index afe22e18..048efda6 100644 --- a/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py +++ b/libs/spandrel/spandrel/architectures/RetinexFormer/__init__.py @@ -105,3 +105,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RetinexFormer]: tiling=ModelTiling.DISCOURAGED, call_fn=_call_fn, ) + + +__all__ = ["RetinexFormerArch", "RetinexFormer"] diff --git a/libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SAFMN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SAFMN/__init__.py b/libs/spandrel/spandrel/architectures/SAFMN/__init__.py index c629399a..77ee625f 100644 --- a/libs/spandrel/spandrel/architectures/SAFMN/__init__.py +++ b/libs/spandrel/spandrel/architectures/SAFMN/__init__.py @@ -67,3 +67,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN]: output_channels=3, # hard-coded in the arch size_requirements=SizeRequirements(multiple_of=8), ) + + +__all__ = ["SAFMNArch", "SAFMN"] diff --git a/libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SAFMNBCIE/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SCUNet/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SCUNet/__init__.py b/libs/spandrel/spandrel/architectures/SCUNet/__init__.py index 3e9d3778..9753ddad 100644 --- a/libs/spandrel/spandrel/architectures/SCUNet/__init__.py +++ b/libs/spandrel/spandrel/architectures/SCUNet/__init__.py @@ -63,3 +63,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SCUNet]: size_requirements=SizeRequirements(minimum=40), 
tiling=ModelTiling.DISCOURAGED, ) + + +__all__ = ["SCUNetArch", "SCUNet"] diff --git a/libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SPAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SPAN/__init__.py b/libs/spandrel/spandrel/architectures/SPAN/__init__.py index ccce5590..2eabec00 100644 --- a/libs/spandrel/spandrel/architectures/SPAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/SPAN/__init__.py @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SPAN]: input_channels=num_in_ch, output_channels=num_out_ch, ) + + +__all__ = ["SPANArch", "SPAN"] diff --git a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py index f0ae8255..c55618f0 100644 --- a/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py +++ b/libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py @@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwiftSRGAN]: input_channels=in_channels, output_channels=in_channels, ) + + +__all__ = ["SwiftSRGANArch", "SwiftSRGAN"] diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Swin2SR/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py index f755d938..cbe37dad 100644 --- a/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py +++ b/libs/spandrel/spandrel/architectures/Swin2SR/__init__.py @@ -184,3 +184,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]: output_channels=in_chans, 
size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["Swin2SRArch", "Swin2SR"] diff --git a/libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py b/libs/spandrel/spandrel/architectures/SwinIR/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py index e60f0f10..94fc03f0 100644 --- a/libs/spandrel/spandrel/architectures/SwinIR/__init__.py +++ b/libs/spandrel/spandrel/architectures/SwinIR/__init__.py @@ -189,3 +189,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]: output_channels=out_nc, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["SwinIRArch", "SwinIR"] diff --git a/libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py b/libs/spandrel/spandrel/architectures/Uformer/__arch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/spandrel/spandrel/architectures/Uformer/__init__.py b/libs/spandrel/spandrel/architectures/Uformer/__init__.py index 2b4b0c2a..6540556d 100644 --- a/libs/spandrel/spandrel/architectures/Uformer/__init__.py +++ b/libs/spandrel/spandrel/architectures/Uformer/__init__.py @@ -141,3 +141,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: output_channels=dd_in, size_requirements=SizeRequirements(multiple_of=128, square=True), ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py index d759ec6b..f56b814b 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py @@ -14,6 +14,11 @@ ) EXTRA_REGISTRY = ArchRegistry() +""" +The registry of all architectures in this library. + +Use ``MAIN_REGISTRY.add(*EXTRA_REGISTRY)`` to add all architectures to the main registry of `spandrel`. 
+""" EXTRA_REGISTRY.add( ArchSupport.from_architecture(SRFormer.SRFormerArch()), diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py index 2a639290..9ebb2503 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py @@ -1,3 +1,9 @@ +""" +Spandrel extra arches contains more architectures for `spandrel`. + +All architectures in this library are registered in the `EXTRA_REGISTRY` registry. +""" + from .__helper import EXTRA_REGISTRY __version__ = "0.1.1" diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + 
drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + 
qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py index 
557a2a59..91decbe1 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py @@ -148,3 +148,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[AdaCode]: output_channels=in_channel, size_requirements=SizeRequirements(multiple_of=multiple_of), ) + + +__all__ = ["AdaCodeArch", "AdaCode"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + 
token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = 
"decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py index e7fbf21b..cfcbe067 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py +++ 
b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/CodeFormer/__init__.py @@ -73,3 +73,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CodeFormer]: size_requirements=SizeRequirements(multiple_of=512, square=True), call_fn=lambda model, image: model(image)[0], ) + + +__all__ = ["CodeFormerArch", "CodeFormer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + 
modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + 
size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py index 90fb28f9..8f9aabef 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/DDColor/__init__.py @@ -191,3 +191,6 @@ def load(self, state_dict: StateDict) -> 
ImageModelDescriptor[DDColor]: tiling=ModelTiling.INTERNAL, call_fn=_call, ) + + +__all__ = ["DDColorArch", "DDColor"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + 
+ depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or 
"ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py index a51c82dd..75ce1950 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/FeMaSR/__init__.py @@ -152,3 +152,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FeMaSR]: output_channels=in_channel, size_requirements=SizeRequirements(multiple_of=multiple_of), ) + + +__all__ = ["FeMaSRArch", "FeMaSR"] diff --git 
a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + 
depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif 
state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py index fa2c37da..de5006ea 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/M3SNet/__init__.py @@ -97,3 +97,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[M3SNet]: output_channels=img_channel, size_requirements=SizeRequirements(multiple_of=16), ) + + +__all__ = ["M3SNetArch", "M3SNet"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- 
/dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, 
"decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp 
= "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py index ca416fd6..bb4cd2b9 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MAT/__init__.py @@ -48,3 +48,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[MAT]: minimum=512, multiple_of=512, square=True ), ) + + +__all__ = ["MATArch", "MAT"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from 
...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = 
"blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + 
num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py index e7ea5f42..6bcc8e30 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MIRNet2/__init__.py @@ -95,3 +95,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MIRNet2]: output_channels=out_channels, size_requirements=SizeRequirements(multiple_of=4), ) + + +__all__ = ["MIRNet2Arch", "MIRNet2"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + 
def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = 
state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, 
+ token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py index b9e01552..abefdf6e 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/MPRNet/__init__.py @@ -106,3 +106,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MPRNet]: size_requirements=SizeRequirements(multiple_of=8), call_fn=lambda model, x: model(x)[0], ) + + +__all__ = ["MPRNetArch", "MPRNet"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + 
"encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] 
= state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", 
+ tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py index 8ccf0974..7eb38a49 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/Restormer/__init__.py @@ -120,3 +120,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Restormer]: output_channels=out_channels, size_requirements=SizeRequirements(multiple_of=8), ) + + +__all__ = ["RestormerArch", "Restormer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py new file mode 100644 index 00000000..6540556d --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__arch/__init__.py @@ -0,0 +1,146 @@ +import math + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import ( + Architecture, + ImageModelDescriptor, + SizeRequirements, + StateDict, +) +from .__arch.Uformer import Uformer + + +class UformerArch(Architecture[Uformer]): + def __init__(self) -> None: + super().__init__( + id="Uformer", + detect=KeyCondition.has_all( + "input_proj.proj.0.weight", + "output_proj.proj.0.weight", + "encoderlayer_0.blocks.0.norm1.weight", + "encoderlayer_2.blocks.0.norm1.weight", + "conv.blocks.0.norm1.weight", + "decoderlayer_0.blocks.0.norm1.weight", + "decoderlayer_2.blocks.0.norm1.weight", + ), + ) + + @override + def load(self, 
state_dict: StateDict) -> ImageModelDescriptor[Uformer]: + img_size = 256 # cannot be deduced from state_dict + in_chans = 3 + dd_in = 3 + embed_dim = 32 + depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] + num_heads = [1, 2, 4, 8, 16, 16, 8, 4, 2] + win_size = 8 + mlp_ratio = 4.0 + qkv_bias = True + drop_rate = 0.0 # cannot be deduced from state_dict + attn_drop_rate = 0.0 # cannot be deduced from state_dict + drop_path_rate = 0.1 # cannot be deduced from state_dict + token_projection = "linear" + token_mlp = "leff" + shift_flag = True # cannot be deduced from state_dict + modulator = False + cross_modulator = False + + embed_dim = state_dict["input_proj.proj.0.weight"].shape[0] + dd_in = state_dict["input_proj.proj.0.weight"].shape[1] + in_chans = state_dict["output_proj.proj.0.weight"].shape[0] + + depths[0] = get_seq_len(state_dict, "encoderlayer_0.blocks") + depths[1] = get_seq_len(state_dict, "encoderlayer_1.blocks") + depths[2] = get_seq_len(state_dict, "encoderlayer_2.blocks") + depths[3] = get_seq_len(state_dict, "encoderlayer_3.blocks") + depths[4] = get_seq_len(state_dict, "conv.blocks") + depths[5] = get_seq_len(state_dict, "decoderlayer_0.blocks") + depths[6] = get_seq_len(state_dict, "decoderlayer_1.blocks") + depths[7] = get_seq_len(state_dict, "decoderlayer_2.blocks") + depths[8] = get_seq_len(state_dict, "decoderlayer_3.blocks") + + num_heads_suffix = "blocks.0.attn.relative_position_bias_table" + num_heads[0] = state_dict[f"encoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[1] = state_dict[f"encoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[2] = state_dict[f"encoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[3] = state_dict[f"encoderlayer_3.{num_heads_suffix}"].shape[1] + num_heads[4] = state_dict[f"conv.{num_heads_suffix}"].shape[1] + num_heads[5] = state_dict[f"decoderlayer_0.{num_heads_suffix}"].shape[1] + num_heads[6] = state_dict[f"decoderlayer_1.{num_heads_suffix}"].shape[1] + num_heads[7] = 
state_dict[f"decoderlayer_2.{num_heads_suffix}"].shape[1] + num_heads[8] = state_dict[f"decoderlayer_3.{num_heads_suffix}"].shape[1] + + if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict: + token_projection = "conv" + qkv_bias = True # cannot be deduced from state_dict + else: + token_projection = "linear" + qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict + + modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict + cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict + + # size_temp = (2 * win_size - 1) ** 2 + size_temp = state_dict[ + "encoderlayer_0.blocks.0.attn.relative_position_bias_table" + ].shape[0] + win_size = (int(math.sqrt(size_temp)) + 1) // 2 + + if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict: + token_mlp = "mlp" # or "ffn", doesn't matter + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0] + / embed_dim + ) + elif state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"].shape[1] == 1: + token_mlp = "leff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + else: + token_mlp = "fastleff" + mlp_ratio = ( + state_dict["encoderlayer_0.blocks.0.mlp.linear1.0.weight"].shape[0] + / embed_dim + ) + + model = Uformer( + img_size=img_size, + in_chans=in_chans, + dd_in=dd_in, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + win_size=win_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration", + tags=[], + supports_half=False, # Too much weirdness to support this at the moment + supports_bfloat16=True, + scale=1, + input_channels=dd_in, + 
output_channels=dd_in, + size_requirements=SizeRequirements(multiple_of=128, square=True), + ) + + +__all__ = ["UformerArch", "Uformer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py index 2572da28..73e6f702 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/SRFormer/__init__.py @@ -182,3 +182,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRFormer]: output_channels=in_chans, size_requirements=SizeRequirements(minimum=16), ) + + +__all__ = ["SRFormerArch", "SRFormer"] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py new file mode 100644 index 00000000..7e6c7cbe --- /dev/null +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/__init__.py @@ -0,0 +1,3 @@ +""" +The package containing the implementations of all supported architectures. Not necessary for most user code. 
+""" diff --git a/pyproject.toml b/pyproject.toml index fc7b8b62..00766939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ pythonpath = ["libs/spandrel", "libs/spandrel_extra_arches"] [tool.pydoctor] project-name = "spandrel" -add-package = ["libs/spandrel/spandrel"] +add-package = ["libs/spandrel/spandrel", "libs/spandrel_extra_arches/spandrel_extra_arches"] project-url = "https://github.com/chaiNNer-org/spandrel" docformat = "restructuredtext" warnings-as-errors = false @@ -51,5 +51,9 @@ theme = "readthedocs" privacy = [ "HIDDEN:spandrel.__version__", "HIDDEN:spandrel.__helpers", + "HIDDEN:spandrel.architectures.*.__arch", "PRIVATE:spandrel.canonicalize_state_dict", + "HIDDEN:spandrel_extra_arches.__version__", + "HIDDEN:spandrel_extra_arches.__helper", + "HIDDEN:spandrel_extra_arches.architectures.*.__arch", ]