Skip to content

Commit

Permalink
Improved OmniSR detection
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Nov 21, 2023
1 parent ebaa3cf commit 443439f
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 1,029 deletions.
948 changes: 0 additions & 948 deletions dump.yml

This file was deleted.

34 changes: 2 additions & 32 deletions src/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
from __future__ import annotations

import math

from ...__helpers.model_descriptor import SRModelDescriptor, StateDict
from ..__arch_helpers.state import get_max_seq_index
from ..__arch_helpers.state import get_max_seq_index, get_scale_and_output_channels
from .arch.SRVGG import SRVGGNetCompact


def _get_scale_and_output_channels(x: int, input_channels: int) -> tuple[int, int]:
# Unfortunately, we do not have enough information to determine both the scale and
# number output channels correctly *in general*. However, we can make some
# assumptions to make it good enough.
#
# What we know:
# - x = scale * scale * output_channels
# - output_channels is likely equal to input_channels
# - output_channels and input_channels is likely 1, 3, or 4
# - scale is likely 1, 2, 4, or 8

def is_square(n: int) -> bool:
return math.sqrt(n) == int(math.sqrt(n))

# just try out a few candidates and see which ones fulfill the requirements
candidates = [input_channels, 3, 4, 1]
for c in candidates:
if x % c == 0 and is_square(x // c):
return int(math.sqrt(x // c)), c

raise AssertionError(
f"Expected output channels to be either 1, 3, or 4."
f" Could not find a pair (scale, out_nc) such that `scale**2 * out_nc = {x}`"
)


def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]:
state = state_dict

Expand All @@ -43,7 +13,7 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]:
num_conv = (highest_num - 2) // 2

pixelshuffle_shape = state[f"body.{highest_num}.bias"].shape[0]
scale, out_nc = _get_scale_and_output_channels(pixelshuffle_shape, in_nc)
scale, out_nc = get_scale_and_output_channels(pixelshuffle_shape, in_nc)

model = SRVGGNetCompact(
num_in_ch=in_nc,
Expand Down
69 changes: 30 additions & 39 deletions src/spandrel/architectures/OmniSR/__init__.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,57 @@
import math

from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict
from ..__arch_helpers.state import (
get_scale_and_output_channels,
get_seq_len,
)
from .arch.OmniSR import OmniSR


def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]:
state = state_dict

block_num = 1 # Fine to assume this for now
ffn_bias = True
num_in_ch = 3
num_out_ch = 3
num_feat = 64
block_num = 1
pe = True
window_size = 8
res_num = 1
up_scale = 4
bias = True

num_feat = state_dict["input.weight"].shape[0] or 64
num_in_ch = state_dict["input.weight"].shape[1] or 3
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
num_feat = state_dict["input.weight"].shape[0]
num_in_ch = state_dict["input.weight"].shape[1]
bias = "input.bias" in state_dict

pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
if up_scale - int(up_scale) > 0:
print(
"out_nc is probably different than in_nc, scale calculation might be wrong"
)
up_scale = int(up_scale)
res_num = 0
for key in state_dict.keys():
if "residual_layer" in key:
temp_res_num = int(key.split(".")[1])
if temp_res_num > res_num:
res_num = temp_res_num
res_num = res_num + 1 # zero-indexed
up_scale, num_out_ch = get_scale_and_output_channels(pixelshuffle_shape, num_in_ch)

res_num = res_num
res_num = get_seq_len(state_dict, "residual_layer")
block_num = get_seq_len(state_dict, "residual_layer.0.residual_layer") - 1

if (
rel_pos_bias_key = (
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
in state_dict.keys()
):
rel_pos_bias_weight = state_dict[
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
].shape[0]
)
if rel_pos_bias_key in state_dict:
pe = True
# rel_pos_bias_weight = (2 * window_size - 1) ** 2
rel_pos_bias_weight = state_dict[rel_pos_bias_key].shape[0]
window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
else:
window_size = 8
pe = False

model = OmniSR(
num_in_ch=num_in_ch,
num_out_ch=num_out_ch,
num_feat=num_feat,
block_num=block_num,
ffn_bias=ffn_bias,
pe=pe,
window_size=window_size,
res_num=res_num,
up_scale=up_scale,
bias=True,
bias=bias,
)

in_nc = num_in_ch
out_nc = num_out_ch
num_feat = num_feat
scale = up_scale

tags = [
f"{num_feat}nf",
f"w{window_size}",
Expand All @@ -69,13 +60,13 @@ def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]:

return SRModelDescriptor(
model,
state,
state_dict,
architecture="OmniSR",
tags=tags,
supports_half=True, # TODO: Test this
supports_bfloat16=True,
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
scale=up_scale,
input_channels=num_in_ch,
output_channels=num_out_ch,
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 0 additions & 2 deletions src/spandrel/architectures/OmniSR/arch/OSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,6 @@ class OSA_Block(nn.Module):
def __init__(
self,
channel_num=64,
bias=True,
ffn_bias=True,
window_size=8,
with_pe=False,
dropout=0.0,
Expand Down
3 changes: 0 additions & 3 deletions src/spandrel/architectures/OmniSR/arch/OSAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
channel_num=64,
bias=True,
block_num=4,
ffn_bias=False,
window_size=0,
pe=False,
):
Expand All @@ -42,8 +41,6 @@ def __init__(
for _ in range(block_num):
temp_res = block_class(
channel_num,
bias,
ffn_bias=ffn_bias,
window_size=window_size,
with_pe=pe,
)
Expand Down
3 changes: 0 additions & 3 deletions src/spandrel/architectures/OmniSR/arch/OmniSR.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# type: ignore
#############################################################
# File: OmniSR.py
# Created Date: Tuesday April 28th 2022
Expand All @@ -25,7 +24,6 @@ def __init__(
num_out_ch=3,
num_feat=64,
block_num=1,
ffn_bias=True,
pe=True,
window_size=8,
res_num=1,
Expand All @@ -45,7 +43,6 @@ def __init__(
channel_num=num_feat,
bias=bias,
block_num=block_num,
ffn_bias=ffn_bias,
window_size=self.window_size,
pe=pe,
)
Expand Down
60 changes: 60 additions & 0 deletions src/spandrel/architectures/__arch_helpers/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from __future__ import annotations

import math
from typing import Any


def get_max_seq_index(state: dict, key_pattern: str, start: int = 0) -> int:
"""
Returns the maximum number `i` such that `key_pattern.format(str(i))` is in `state`.
Expand All @@ -18,3 +24,57 @@ def get_max_seq_index(state: dict, key_pattern: str, start: int = 0) -> int:
if key not in state:
return i - 1
i += 1


def get_seq_len(state: dict[str, Any], seq_key: str) -> int:
"""
Returns the length of a sequence in the state dict.
The length is detected by finding the maximum index `i` such that
`{seq_key}.{i}.{suffix}` is in `state` for some suffix.
Example:
get_seq_len(state, "body") -> 5
"""
prefix = seq_key + "."

keys: set[int] = set()
for k in state.keys():
if k.startswith(prefix):
index = k[len(prefix) :].split(".", maxsplit=1)[0]
keys.add(int(index))

if len(keys) == 0:
return 0
return max(keys) + 1


def get_scale_and_output_channels(x: int, input_channels: int) -> tuple[int, int]:
"""
Returns a scale and number of output channels such that `scale**2 * out_nc = x`.
This is commonly used for pixelshuffel layers.
"""
# Unfortunately, we do not have enough information to determine both the scale and
# number output channels correctly *in general*. However, we can make some
# assumptions to make it good enough.
#
# What we know:
# - x = scale * scale * output_channels
# - output_channels is likely equal to input_channels
# - output_channels and input_channels is likely 1, 3, or 4
# - scale is likely 1, 2, 4, or 8

def is_square(n: int) -> bool:
return math.sqrt(n) == int(math.sqrt(n))

# just try out a few candidates and see which ones fulfill the requirements
candidates = [input_channels, 3, 4, 1]
for c in candidates:
if x % c == 0 and is_square(x // c):
return int(math.sqrt(x // c)), c

raise AssertionError(
f"Expected output channels to be either 1, 3, or 4."
f" Could not find a pair (scale, out_nc) such that `scale**2 * out_nc = {x}`"
)
33 changes: 31 additions & 2 deletions tests/test_OmniSR.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,36 @@
from spandrel import ModelLoader
from spandrel.architectures.OmniSR import OmniSR
from spandrel.architectures.OmniSR import OmniSR, load

from .util import ModelFile, TestImage, assert_image_inference, disallowed_props
from .util import (
ModelFile,
TestImage,
assert_image_inference,
assert_loads_correctly,
disallowed_props,
)


def test_OmniSR_load():
assert_loads_correctly(
load,
lambda: OmniSR(),
lambda: OmniSR(num_in_ch=1, num_out_ch=1),
lambda: OmniSR(num_in_ch=3, num_out_ch=3),
lambda: OmniSR(num_in_ch=4, num_out_ch=4),
lambda: OmniSR(num_in_ch=1, num_out_ch=3),
lambda: OmniSR(num_feat=32),
lambda: OmniSR(block_num=2),
lambda: OmniSR(pe=False),
lambda: OmniSR(bias=False),
lambda: OmniSR(window_size=5),
lambda: OmniSR(res_num=3),
lambda: OmniSR(up_scale=5),
condition=lambda a, b: (
a.res_num == b.res_num
and a.up_scale == b.up_scale
and a.window_size == b.window_size
),
)


def test_OmniSR_community1(snapshot):
Expand Down

0 comments on commit 443439f

Please sign in to comment.