Skip to content

Commit

Permalink
Add SPAN norm parameter (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Feb 2, 2024
1 parent 5c0354a commit 7a1058d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
9 changes: 9 additions & 0 deletions src/spandrel/architectures/SPAN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch

from ...__helpers.model_descriptor import ImageModelDescriptor, StateDict
from ..__arch_helpers.state import get_scale_and_output_channels
from .arch.span import SPAN
Expand All @@ -9,6 +11,7 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[SPAN]:
feature_channels: int = 48
upscale: int = 4
bias = True # unused internally
norm = True
img_range = 255.0 # cannot be deduced from state_dict
rgb_mean = (0.4488, 0.4371, 0.4040) # cannot be deduced from state_dict

Expand All @@ -21,12 +24,18 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[SPAN]:
num_in_ch,
)

# norm
if "no_norm" in state_dict:
norm = False
state_dict["no_norm"] = torch.zeros(1)

model = SPAN(
num_in_ch=num_in_ch,
num_out_ch=num_out_ch,
feature_channels=feature_channels,
upscale=upscale,
bias=bias,
norm=norm,
img_range=img_range,
rgb_mean=rgb_mean,
)
Expand Down
18 changes: 16 additions & 2 deletions src/spandrel/architectures/SPAN/arch/span.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Literal

Expand Down Expand Up @@ -239,6 +241,7 @@ def __init__(
feature_channels=48,
upscale=4,
bias=True,
norm=True,
img_range=255.0,
rgb_mean=(0.4488, 0.4371, 0.4040),
):
Expand All @@ -249,6 +252,12 @@ def __init__(
self.img_range = img_range
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

self.no_norm: torch.Tensor | None
if not norm:
self.register_buffer("no_norm", torch.zeros(1))
else:
self.no_norm = None

self.conv_1 = Conv3XC(self.in_channels, feature_channels, gain1=2, s=1)
self.block_1 = SPAB(feature_channels, bias=bias)
self.block_2 = SPAB(feature_channels, bias=bias)
Expand All @@ -266,9 +275,14 @@ def __init__(
feature_channels, self.out_channels, upscale_factor=upscale
)

@property
def is_norm(self):
return self.no_norm is None

def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.is_norm:
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range

out_feature = self.conv_1(x)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_SPAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ def test_SPAN_load():
lambda: SPAN(num_in_ch=3, num_out_ch=3, upscale=2),
lambda: SPAN(num_in_ch=3, num_out_ch=3, upscale=4),
lambda: SPAN(num_in_ch=3, num_out_ch=3, upscale=8),
lambda: SPAN(num_in_ch=3, num_out_ch=3, norm=False),
condition=lambda a, b: (
a.in_channels == b.in_channels
and a.out_channels == b.out_channels
and a.img_range == b.img_range
and a.is_norm == b.is_norm
),
)

Expand Down

0 comments on commit 7a1058d

Please sign in to comment.