Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for RealPLKSR DySample #293

Merged
merged 7 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions libs/spandrel/spandrel/architectures/PLKSR/__arch/RealPLKSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import nn
from torch.nn.init import trunc_normal_

from spandrel.architectures.__arch_helpers.dysample import DySample
from spandrel.util import store_hyperparameters


Expand Down Expand Up @@ -112,14 +113,23 @@ def __init__(
use_ea: bool = True,
norm_groups: int = 4,
dropout: float = 0,
dysample: bool = False,
):
super().__init__()

# Perhaps some day in the future we can make these user-customizable,
# but for now I just want to leave them hardcoded and focus on dysample detection
in_ch: int = 3
out_ch: int = 3

self.dysample = dysample
self.upscaling_factor = upscaling_factor

if not self.training:
dropout = 0

self.feats = nn.Sequential(
*[nn.Conv2d(3, dim, 3, 1, 1)]
*[nn.Conv2d(in_ch, dim, 3, 1, 1)]
+ [
PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea)
for _ in range(n_blocks)
Expand All @@ -134,8 +144,20 @@ def __init__(
torch.repeat_interleave, repeats=upscaling_factor**2, dim=1
)

self.to_img = nn.PixelShuffle(upscaling_factor)
if dysample and upscaling_factor != 1:
groups = out_ch if 3 * upscaling_factor**2 < 4 else 4
self.to_img = DySample(
in_ch * upscaling_factor**2,
out_ch,
upscaling_factor,
groups=groups,
end_convolution=upscaling_factor != 1,
)
RunDevelopment marked this conversation as resolved.
Show resolved Hide resolved
else:
self.to_img = nn.PixelShuffle(upscaling_factor)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.feats(x) + self.repeat_op(x)
return self.to_img(x)
if not self.dysample or (self.dysample and self.upscaling_factor != 1):
x = self.to_img(x)
return x
5 changes: 5 additions & 0 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim

use_dysample = "to_img.init_pos" in state_dict
if use_dysample:
more_tags.append("DySample")

model = RealPLKSR(
dim=dim,
upscaling_factor=scale,
Expand All @@ -127,6 +131,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
split_ratio=split_ratio,
use_ea=use_ea,
norm_groups=4, # un-detectable
dysample=use_dysample,
)
else:
raise ValueError("Unknown model type")
Expand Down
93 changes: 93 additions & 0 deletions libs/spandrel/spandrel/architectures/__arch_helpers/dysample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DySample(nn.Module):
    """Upsampler that resamples its input at learned, content-dependent positions.

    Adapted from 'Learning to Upsample by Learning to Sample':
    https://arxiv.org/abs/2308.15085
    https://github.com/tiny-smart/dysample

    Args:
        in_channels: Number of channels of the input feature map. Must be a
            positive multiple of ``groups``.
        out_ch: Number of channels produced by the optional final 1x1
            convolution.
        scale: Integer upsampling factor applied to both spatial dimensions.
        groups: Number of channel groups; each group is sampled with its own
            set of offsets.
        end_convolution: When true, a 1x1 convolution maps the resampled
            features from ``in_channels`` to ``out_ch`` channels. When false,
            the resampled features are returned as they are.

    Raises:
        ValueError: If ``groups`` is not positive or does not evenly divide
            ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_ch: int,
        scale: int = 2,
        groups: int = 4,
        end_convolution: bool = True,
    ):
        super().__init__()

        # Validate explicitly instead of with `assert`: assertions are removed
        # when Python runs with -O, which would silently disable this check.
        if groups <= 0 or in_channels < groups or in_channels % groups != 0:
            msg = "Incorrect in_channels and groups values."
            raise ValueError(msg)

        # One (x, y) offset per group for each of the scale**2 sub-pixel positions.
        out_channels = 2 * groups * scale**2
        self.scale = scale
        self.groups = groups
        self.end_convolution = end_convolution
        if end_convolution:
            self.end_conv = nn.Conv2d(in_channels, out_ch, kernel_size=1)

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if self.training:
            nn.init.trunc_normal_(self.offset.weight, std=0.02)
            nn.init.constant_(self.scope.weight, val=0)

        # Registered as a buffer, so it appears in the state dict as "init_pos".
        self.register_buffer("init_pos", self._init_pos())

    def _init_pos(self):
        """Build the static sub-pixel offsets, shaped (1, 2 * groups * scale**2, 1, 1)."""
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return (
            torch.stack(torch.meshgrid([h, h], indexing="ij"))
            .transpose(1, 2)
            .repeat(1, self.groups, 1)
            .reshape(1, -1, 1, 1)
        )

    def forward(self, x):
        """Upsample a 4-D feature map ``x`` of shape (B, C, H, W).

        Returns a tensor of shape (B, out_ch, scale * H, scale * W) when the
        end convolution is enabled, and (B, C, scale * H, scale * W) otherwise.
        """
        # Learned offsets, gated by a sigmoid "scope", added to the static grid.
        offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)

        # Pixel-centre coordinates of the input grid, shaped (1, 2, 1, H, W)
        # with the x (width) coordinate first.
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = (
            torch.stack(torch.meshgrid([coords_w, coords_h], indexing="ij"))
            .transpose(1, 2)
            .unsqueeze(1)
            .unsqueeze(0)
            .type(x.dtype)
            .to(x.device, non_blocking=True)
        )
        normalizer = torch.tensor(
            [W, H],
            dtype=x.dtype,
            device=x.device,
            pin_memory=False,  # pin_memory was originally True
        ).view(1, 2, 1, 1, 1)
        # Map absolute positions into the [-1, 1] range expected by grid_sample.
        coords = 2 * (coords + offset) / normalizer - 1

        # Rearrange the scale**2 sub-pixel positions into the upsampled spatial
        # layout, producing a (B * groups, scale * H, scale * W, 2) sampling grid.
        coords = (
            F.pixel_shuffle(coords.reshape(B, -1, H, W), self.scale)
            .view(B, 2, -1, self.scale * H, self.scale * W)
            .permute(0, 2, 3, 4, 1)
            .contiguous()
            .flatten(0, 1)
        )
        # Each channel group is sampled with its own grid by folding the groups
        # into the batch dimension.
        output = F.grid_sample(
            x.reshape(B * self.groups, -1, H, W),
            coords,
            mode="bilinear",
            align_corners=False,
            padding_mode="border",
        ).view(B, -1, self.scale * H, self.scale * W)

        if self.end_convolution:
            output = self.end_conv(output)

        return output
23 changes: 23 additions & 0 deletions tests/__snapshots__/test_PLKSR.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,26 @@
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_RealPLKSR_DySample
ImageModelDescriptor(
architecture=PLKSRArch(
id='PLKSR',
name='PLKSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'64dim',
'28nb',
'17ks',
'Real',
'DySample',
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions tests/test_PLKSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_load():
lambda: RealPLKSR(split_ratio=0.5),
lambda: RealPLKSR(split_ratio=0.75),
lambda: RealPLKSR(use_ea=False),
lambda: RealPLKSR(dysample=True),
)


Expand Down Expand Up @@ -156,3 +157,18 @@ def test_RealPLKSR_2x(snapshot):
model,
[TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
)


def test_RealPLKSR_DySample(snapshot):
    """Load a released 4x RealPLKSR DySample checkpoint and check it end to end.

    The loaded descriptor is compared against the stored snapshot, the wrapped
    network must be a ``RealPLKSR`` instance, and inference is run on the
    standard SR test images.
    """
    checkpoint_name = "4xHFA2k_ludvae_realplksr_dysample.pth"
    checkpoint_url = "https://github.com/Phhofm/models/releases/download/4xHFA2k_ludvae_realplksr_dysample/4xHFA2k_ludvae_realplksr_dysample.pth"
    model_file = ModelFile.from_url(checkpoint_url, name=checkpoint_name)

    descriptor = model_file.load_model()
    assert descriptor == snapshot(exclude=disallowed_props)
    assert isinstance(descriptor.model, RealPLKSR)

    test_images = [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]
    assert_image_inference(model_file, descriptor, test_images)
Loading