Skip to content

Commit

Permalink
Define the public API by what's documented
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Jul 8, 2024
1 parent ebf11ba commit d8c2256
Show file tree
Hide file tree
Showing 92 changed files with 1,605 additions and 1 deletion.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/ATD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=8),
)


__all__ = ["ATDArch", "ATD"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/CRAFT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[CRAFT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16, multiple_of=16),
)


__all__ = ["CRAFTArch", "CRAFT"]
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DATArch", "DAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DCTLSA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DCTLSA]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["DCTLSAArch", "DCTLSA"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DITN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DITN]:
output_channels=3, # hard-coded in the architecture
size_requirements=SizeRequirements(multiple_of=patch_size),
)


__all__ = ["DITNArch", "DITN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRCT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:
output_channels=in_chans,
size_requirements=SizeRequirements(multiple_of=16),
)


__all__ = ["DRCTArch", "DRCT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DRUNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ def call(model: DRUNet, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=8),
call_fn=call,
)


__all__ = ["DRUNetArch", "DRUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/DnCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,6 @@ def call(model: DnCNN, image: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(),
call_fn=call,
)


__all__ = ["DnCNNArch", "DnCNN"]
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FBCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FBCNN]:
output_channels=out_nc,
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["FBCNNArch", "FBCNN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/FFTformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FFTformer]:
output_channels=out_channels,
size_requirements=SizeRequirements(multiple_of=32),
)


__all__ = ["FFTformerArch", "FFTformer"]
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/GRL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[GRL]:
input_channels=in_channels,
output_channels=out_channels,
)


__all__ = ["GRLArch", "GRL"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["HATArch", "HAT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
size_requirements=SizeRequirements(multiple_of=8),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["HVICIDNetArch", "HVICIDNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/IPT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,6 @@ def call(model: IPT, x: torch.Tensor):
size_requirements=SizeRequirements(minimum=patch_size),
call_fn=call,
)


__all__ = ["IPTArch", "IPT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/KBNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_KBNet]:
return self._load_l(state_dict)
else:
return self._load_s(state_dict)


__all__ = ["KBNetArch", "KBNet_s", "KBNet_l"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/LaMa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["LaMaArch", "LaMa"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MMRealSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MMRealSR]:
size_requirements=SizeRequirements(minimum=16),
call_fn=lambda model, image: model(image)[0],
)


__all__ = ["MMRealSRArch", "MMRealSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/MixDehazeNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[MixDehazeNet]:
tiling=ModelTiling.DISCOURAGED,
call_fn=lambda model, image: model(image) * 0.5 + 0.5,
)


__all__ = ["MixDehazeNetArch", "MixDehazeNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/NAFNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[NAFNet]:
input_channels=img_channel,
output_channels=img_channel,
)


__all__ = ["NAFNetArch", "NAFNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/OmniSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[OmniSR]:
output_channels=num_out_ch,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["OmniSRArch", "OmniSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
input_channels=3,
output_channels=3,
)


__all__ = ["PLKSRArch", "PLKSR", "RealPLKSR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RGT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["RGTArch", "RGT"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/RealCUGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_RealCUGAN]:
output_channels=out_channels,
size_requirements=size_requirements,
)


__all__ = ["RealCUGANArch", "UpCunet2x", "UpCunet3x", "UpCunet4x", "UpCunet2x_fast"]
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ def call(model: RestoreFormer, x: torch.Tensor) -> torch.Tensor:
size_requirements=SizeRequirements(multiple_of=32),
call_fn=call,
)


__all__ = ["RestoreFormerArch", "RestoreFormer"]
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RetinexFormer]:
tiling=ModelTiling.DISCOURAGED,
call_fn=_call_fn,
)


__all__ = ["RetinexFormerArch", "RetinexFormer"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SAFMN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SAFMN]:
output_channels=3, # hard-coded in the arch
size_requirements=SizeRequirements(multiple_of=8),
)


__all__ = ["SAFMNArch", "SAFMN"]
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SCUNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SCUNet]:
size_requirements=SizeRequirements(minimum=40),
tiling=ModelTiling.DISCOURAGED,
)


__all__ = ["SCUNetArch", "SCUNet"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SPAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SPAN]:
input_channels=num_in_ch,
output_channels=num_out_ch,
)


__all__ = ["SPANArch", "SPAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwiftSRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwiftSRGAN]:
input_channels=in_channels,
output_channels=in_channels,
)


__all__ = ["SwiftSRGANArch", "SwiftSRGAN"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
output_channels=in_chans,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["Swin2SRArch", "Swin2SR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/SwinIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
output_channels=out_nc,
size_requirements=SizeRequirements(minimum=16),
)


__all__ = ["SwinIRArch", "SwinIR"]
Empty file.
3 changes: 3 additions & 0 deletions libs/spandrel/spandrel/architectures/Uformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
output_channels=dd_in,
size_requirements=SizeRequirements(multiple_of=128, square=True),
)


__all__ = ["UformerArch", "Uformer"]
5 changes: 5 additions & 0 deletions libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
)

EXTRA_REGISTRY = ArchRegistry()
"""
The registry of all architectures in this library.
Use ``MAIN_REGISTRY.add(*EXTRA_REGISTRY)`` to add all architectures to the main registry of `spandrel`.
"""

EXTRA_REGISTRY.add(
ArchSupport.from_architecture(SRFormer.SRFormerArch()),
Expand Down
6 changes: 6 additions & 0 deletions libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
Spandrel extra arches contains more architectures for `spandrel`.
All architectures in this library are registered in `EXTRA_REGISTRY`, an `ArchRegistry` instance.
"""

from .__helper import EXTRA_REGISTRY

__version__ = "0.1.1"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import math

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
ImageModelDescriptor,
SizeRequirements,
StateDict,
)
from .__arch.Uformer import Uformer


class UformerArch(Architecture[Uformer]):
    """Architecture entry that recognizes Uformer state dicts and loads them."""

    def __init__(self) -> None:
        # A state dict is treated as Uformer only if every one of these keys exists.
        required_keys = (
            "input_proj.proj.0.weight",
            "output_proj.proj.0.weight",
            "encoderlayer_0.blocks.0.norm1.weight",
            "encoderlayer_2.blocks.0.norm1.weight",
            "conv.blocks.0.norm1.weight",
            "decoderlayer_0.blocks.0.norm1.weight",
            "decoderlayer_2.blocks.0.norm1.weight",
        )
        super().__init__(
            id="Uformer",
            detect=KeyCondition.has_all(*required_keys),
        )

    @override
    def load(self, state_dict: StateDict) -> ImageModelDescriptor[Uformer]:
        """Rebuild a ``Uformer`` model from the tensors found in ``state_dict``.

        Most hyperparameters are read from tensor shapes or from the
        presence of specific keys; the remaining ones are fixed constants.
        A ``KeyError`` propagates if an expected key is missing.
        """
        # Fixed values: these cannot be deduced from state_dict.
        img_size = 256
        drop_rate = 0.0
        attn_drop_rate = 0.0
        drop_path_rate = 0.1
        shift_flag = True

        # Channel counts come from the input/output projection weights.
        input_proj_weight = state_dict["input_proj.proj.0.weight"]
        embed_dim = input_proj_weight.shape[0]
        dd_in = input_proj_weight.shape[1]
        in_chans = state_dict["output_proj.proj.0.weight"].shape[0]

        # The nine stages, in the order the ``depths``/``num_heads`` lists use.
        stages = (
            "encoderlayer_0",
            "encoderlayer_1",
            "encoderlayer_2",
            "encoderlayer_3",
            "conv",
            "decoderlayer_0",
            "decoderlayer_1",
            "decoderlayer_2",
            "decoderlayer_3",
        )
        depths = [get_seq_len(state_dict, f"{stage}.blocks") for stage in stages]

        num_heads_suffix = "blocks.0.attn.relative_position_bias_table"
        num_heads = [
            state_dict[f"{stage}.{num_heads_suffix}"].shape[1] for stage in stages
        ]

        # A depthwise weight under ``to_q`` selects the "conv" projection.
        if "encoderlayer_0.blocks.0.attn.qkv.to_q.depthwise.weight" in state_dict:
            token_projection = "conv"
            qkv_bias = True  # cannot be deduced from state_dict
        else:
            token_projection = "linear"
            qkv_bias = "encoderlayer_0.blocks.0.attn.qkv.to_q.bias" in state_dict

        modulator = "decoderlayer_0.blocks.0.modulator.weight" in state_dict
        cross_modulator = "decoderlayer_0.blocks.0.cross_modulator.weight" in state_dict

        # The bias table has (2 * win_size - 1) ** 2 rows; invert that here.
        bias_table_rows = state_dict[
            "encoderlayer_0.blocks.0.attn.relative_position_bias_table"
        ].shape[0]
        win_size = (int(math.sqrt(bias_table_rows)) + 1) // 2

        # Pick the token MLP variant and the width of its hidden layer.
        if "encoderlayer_0.blocks.0.mlp.fc1.weight" in state_dict:
            token_mlp = "mlp"  # or "ffn", doesn't matter
            hidden_dim = state_dict["encoderlayer_0.blocks.0.mlp.fc1.weight"].shape[0]
        else:
            dwconv_weight = state_dict["encoderlayer_0.blocks.0.mlp.dwconv.0.weight"]
            token_mlp = "leff" if dwconv_weight.shape[1] == 1 else "fastleff"
            hidden_dim = state_dict[
                "encoderlayer_0.blocks.0.mlp.linear1.0.weight"
            ].shape[0]
        mlp_ratio = hidden_dim / embed_dim

        model = Uformer(
            img_size=img_size,
            in_chans=in_chans,
            dd_in=dd_in,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            win_size=win_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            token_projection=token_projection,
            token_mlp=token_mlp,
            shift_flag=shift_flag,
            modulator=modulator,
            cross_modulator=cross_modulator,
        )

        return ImageModelDescriptor(
            model,
            state_dict,
            architecture=self,
            purpose="Restoration",
            tags=[],
            supports_half=False,  # Too much weirdness to support this at the moment
            supports_bfloat16=True,
            scale=1,
            input_channels=dd_in,
            output_channels=dd_in,
            size_requirements=SizeRequirements(multiple_of=128, square=True),
        )


# Public names of this module: the architecture class defined above and the
# model class it loads (imported from ``.__arch.Uformer``).
__all__ = ["UformerArch", "Uformer"]
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[AdaCode]:
output_channels=in_channel,
size_requirements=SizeRequirements(multiple_of=multiple_of),
)


__all__ = ["AdaCodeArch", "AdaCode"]
Loading

0 comments on commit d8c2256

Please sign in to comment.