Skip to content

Commit

Permalink
Unify logic for detecting pixelshuffle (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored May 6, 2024
1 parent fc4bdc0 commit fe9ecc5
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 79 deletions.
12 changes: 2 additions & 10 deletions libs/spandrel/spandrel/architectures/ATD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -114,15 +114,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
upscale = 4
elif "conv_before_upsample.0.weight" in state_dict:
upsampler = "pixelshuffle"
upscale = 1
for i in range(0, 10, 2):
if f"upsample.{i}.weight" not in state_dict:
break
num_feat = state_dict[f"upsample.{i}.weight"].shape[1]

upscale *= math.isqrt(
state_dict[f"upsample.{i}.weight"].shape[0] // num_feat
)
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
elif "conv_last.weight" in state_dict:
upsampler = ""
upscale = 1
Expand Down
8 changes: 2 additions & 6 deletions libs/spandrel/spandrel/architectures/DAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -107,11 +107,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
resi_connection = "1conv" if "conv_after_body.weight" in state_dict else "3conv"

if upsampler == "pixelshuffle":
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
num_feat = state_dict[f"upsample.{i}.weight"].shape[1]
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample")
elif upsampler == "pixelshuffledirect":
num_feat = state_dict["upsample.0.weight"].shape[1]
upscale = int(
Expand Down
21 changes: 2 additions & 19 deletions libs/spandrel/spandrel/architectures/DRCT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand All @@ -13,23 +13,6 @@
from .arch.drct_arch import DRCT


def _get_upscale_pixelshuffle(
state_dict: StateDict, key_prefix: str = "upsample"
) -> int:
upscale = 1

for i in range(0, 10, 2):
key = f"{key_prefix}.{i}.weight"
if key not in state_dict:
break

shape = state_dict[key].shape
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)

return upscale


class DRCTArch(Architecture[DRCT]):
def __init__(self) -> None:
super().__init__(
Expand Down Expand Up @@ -105,7 +88,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:

if "conv_last.weight" in state_dict:
upsampler = "pixelshuffle"
upscale = _get_upscale_pixelshuffle(state_dict, "upsample")
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
else:
upsampler = ""
upscale = 1
Expand Down
13 changes: 7 additions & 6 deletions libs/spandrel/spandrel/architectures/GRL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import torch
from typing_extensions import override

from spandrel.util import KeyCondition, get_scale_and_output_channels, get_seq_len
from spandrel.util import (
KeyCondition,
get_pixelshuffle_params,
get_scale_and_output_channels,
get_seq_len,
)

from ...__helpers.canonicalize import remove_common_prefix
from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
Expand Down Expand Up @@ -50,18 +55,14 @@ def _get_output_params(state_dict: StateDict, in_channels: int):
upsampler: str
upscale: int

num_out_feats = 64 # hard-coded
if (
"conv_before_upsample.0.weight" in state_dict
and "upsample.up.0.weight" in state_dict
):
upsampler = "pixelshuffle"
out_channels = state_dict["conv_last.weight"].shape[0]

upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample.up"), 2):
shape = state_dict[f"upsample.up.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_out_feats))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample.up")
elif "upsample.up.0.weight" in state_dict:
upsampler = "pixelshuffledirect"
upscale, out_channels = get_scale_and_output_channels(
Expand Down
7 changes: 2 additions & 5 deletions libs/spandrel/spandrel/architectures/HAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -109,10 +109,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
embed_dim = state_dict["conv_first.weight"].shape[0]

num_feat = state_dict["conv_last.weight"].shape[1]
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample", num_feat)

window_size = int(math.sqrt(state_dict["relative_position_index_SA"].shape[0]))
overlap_ratio = _get_overlap_ratio(
Expand Down
10 changes: 2 additions & 8 deletions libs/spandrel/spandrel/architectures/RGT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -133,13 +133,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
)
break

upscale = 1
for i in range(0, 10, 2):
key = f"upsample.{i}.weight"
if key in state_dict:
shape = state_dict[key].shape
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")

split_size = _get_split_size(state_dict)

Expand Down
8 changes: 2 additions & 6 deletions libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -102,11 +102,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans)
)
else:
num_feat = 64 # hard-coded constant
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")

window_size = int(
math.sqrt(
Expand Down
12 changes: 2 additions & 10 deletions libs/spandrel/spandrel/architectures/SwinIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import nn
from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -84,15 +84,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
for _upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle":
upsample_keys = [
x
for x in state_dict
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = state_dict[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample")
elif upsampler == "pixelshuffledirect":
upscale = int(
math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch)
Expand Down
34 changes: 32 additions & 2 deletions libs/spandrel/spandrel/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ def is_square(n: int) -> bool:
)


def get_pixelshuffle_params(
state_dict: Mapping[str, object],
upsample_key: str = "upsample",
default_nf: int = 64,
) -> tuple[int, int]:
"""
This will detect the upscale factor and number of features of a pixelshuffle module in the state dict.
A pixelshuffle module is a sequence of alternating up convolutions and pixelshuffle.
The class of this module is commonyl called `Upsample`.
Examples of such modules can be found in most SISR architectures, such as SwinIR, HAT, RGT, and many more.
"""
upscale = 1
num_feat = default_nf

for i in range(0, 10, 2):
key = f"{upsample_key}.{i}.weight"
if key not in state_dict:
break

tensor = state_dict[key]
# we'll assume that the state dict contains tensors
shape: tuple[int, ...] = tensor.shape # type: ignore
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)

return upscale, num_feat


def store_hyperparameters(*, extra_parameters: Mapping[str, object] = {}):
"""
Stores the hyperparameters of a class in a `hyperparameters` attribute.
Expand Down Expand Up @@ -170,9 +199,10 @@ def new_init(self: C, **kwargs):


__all__ = [
"KeyCondition",
"get_first_seq_index",
"get_seq_len",
"get_pixelshuffle_params",
"get_scale_and_output_channels",
"get_seq_len",
"KeyCondition",
"store_hyperparameters",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
SizeRequirements,
StateDict,
)
from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from .arch.SRFormer import SRFormer

Expand Down Expand Up @@ -76,12 +76,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRFormer]:
upscale = 4 # only supported scale
elif "conv_before_upsample.0.weight" in state_dict:
upsampler = "pixelshuffle"

num_feat = 64 # hard-coded constant
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
elif "upsample.0.weight" in state_dict:
upsampler = "pixelshuffledirect"
upscale = int(
Expand Down

0 comments on commit fe9ecc5

Please sign in to comment.