Add support for HVI-CIDNet #271

Merged · 3 commits · May 29, 2024 · Changes from all commits
1 change: 1 addition & 0 deletions README.md
@@ -148,6 +148,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
#### Low-light Enhancement

- [RetinexFormer](https://github.com/caiyuanhao1998/Retinexformer) | [Models](https://drive.google.com/drive/folders/1ynK5hfQachzc8y96ZumhkPPDXzHJwaQV?usp=drive_link)
- [HVI-CIDNet](https://github.com/Fediory/HVI-CIDNet) | [Models](https://github.com/Fediory/HVI-CIDNet/#weights-and-results-)

(All architectures marked with a `+` are only part of `spandrel_extra_arches`.)

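For reference, a minimal sketch (not part of this diff) of how a checkpoint for a listed architecture such as HVI-CIDNet is loaded through spandrel; the weight path is hypothetical:

```python
# Minimal sketch, assuming a locally downloaded HVI-CIDNet checkpoint
# (the path is hypothetical; weights are linked from the README entry above).
import torch
from spandrel import ModelLoader

descriptor = ModelLoader().load_from_file("weights/HVI-CIDNet.pth")
print(descriptor.architecture.name)  # "HVI-CIDNet"

# HVI-CIDNet requires input sizes that are a multiple of 8
# (see SizeRequirements in the architecture definition below).
image = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    output = descriptor(image)  # scale=1, so the output shape matches the input
```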
2 changes: 2 additions & 0 deletions libs/spandrel/spandrel/__helpers/main_registry.py
@@ -22,6 +22,7 @@
    DnCNN,
    DRUNet,
    FFTformer,
    HVICIDNet,
    KBNet,
    LaMa,
    MixDehazeNet,
@@ -82,4 +83,5 @@
    ArchSupport.from_architecture(ESRGAN.ESRGANArch()),
    ArchSupport.from_architecture(PLKSR.PLKSRArch()),
    ArchSupport.from_architecture(RetinexFormer.RetinexFormerArch()),
    ArchSupport.from_architecture(HVICIDNet.HVICIDNetArch()),
)
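Registering the architecture here is what enables key-based auto-detection. A small sketch (assuming spandrel's public `load_from_state_dict` API):

```python
# Sketch: with HVICIDNetArch registered, loading dispatches purely on the
# state dict's keys via the architecture's KeyCondition (defined below).
import torch
from spandrel import ModelLoader

state_dict = torch.load("HVI-CIDNet.pth", map_location="cpu")  # hypothetical path
descriptor = ModelLoader().load_from_state_dict(state_dict)
print(descriptor.architecture.id)  # "HVICIDNet"
```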
94 changes: 94 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
@@ -0,0 +1,94 @@
from typing_extensions import override

from spandrel.util import KeyCondition

from ...__helpers.model_descriptor import (
    Architecture,
    ImageModelDescriptor,
    ModelTiling,
    SizeRequirements,
    StateDict,
)
from .arch.cidnet import CIDNet as HVICIDNet


class HVICIDNetArch(Architecture[HVICIDNet]):
    def __init__(self) -> None:
        super().__init__(
            id="HVICIDNet",
            name="HVI-CIDNet",
            detect=KeyCondition.has_all(
                "HVE_block0.1.weight",
                "HVE_block1.prelu.weight",
                "HVE_block1.down.0.weight",
                "HVE_block3.down.0.weight",
                "HVD_block3.prelu.weight",
                "HVD_block3.up_scale.0.weight",
                "HVD_block3.up.weight",
                "HVD_block1.up.weight",
                "HVD_block0.1.weight",
                "IE_block0.1.weight",
                "IE_block1.prelu.weight",
                "IE_block1.down.0.weight",
                "ID_block1.up.weight",
                "ID_block0.1.weight",
                "HV_LCA1.gdfn.project_in.weight",
                "HV_LCA1.gdfn.dwconv.weight",
                "HV_LCA1.gdfn.dwconv1.weight",
                "HV_LCA1.gdfn.dwconv2.weight",
                "HV_LCA1.gdfn.project_out.weight",
                "HV_LCA1.norm.weight",
                "HV_LCA1.ffn.temperature",
                "HV_LCA1.ffn.q.weight",
                "HV_LCA1.ffn.q_dwconv.weight",
                "HV_LCA1.ffn.project_out.weight",
                "HV_LCA2.gdfn.project_in.weight",
                "HV_LCA6.gdfn.project_in.weight",
                "I_LCA1.gdfn.project_in.weight",
                "I_LCA6.ffn.project_out.weight",
                "trans.density_k",
            ),
        )

    @override
    def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
        # defaults
        channels = [36, 36, 72, 144]
        heads = [1, 2, 4, 8]
        norm = False

        # detect hyperparameters from tensor shapes in the state dict
        channels = [
            state_dict["HVE_block1.down.0.weight"].shape[1],
            state_dict["HVE_block1.down.0.weight"].shape[0],
            state_dict["HVE_block2.down.0.weight"].shape[0],
            state_dict["HVE_block3.down.0.weight"].shape[0],
        ]

        heads = [
            1,  # unused
            state_dict["HV_LCA1.ffn.temperature"].shape[0],
            state_dict["HV_LCA2.ffn.temperature"].shape[0],
            state_dict["HV_LCA3.ffn.temperature"].shape[0],
        ]

        norm = "HVE_block1.norm.weight" in state_dict

        model = HVICIDNet(
            channels=channels,
            heads=heads,
            norm=norm,
        )

        return ImageModelDescriptor(
            model,
            state_dict,
            architecture=self,
            purpose="Restoration",
            tags=[],
            supports_half=False,  # TODO: verify
            supports_bfloat16=True,
            scale=1,
            input_channels=3,  # hard-coded
            output_channels=3,  # hard-coded
            size_requirements=SizeRequirements(multiple_of=8),
            tiling=ModelTiling.DISCOURAGED,
        )
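To unpack the shape-based detection in `load` above: an `nn.Conv2d` weight has shape `(out_channels, in_channels // groups, kH, kW)`, so the encoder's downsampling convs reveal the channel progression, and `CAB` (in `LCA.py` below) stores `temperature` as a `(num_heads, 1, 1)` parameter, so its first dimension recovers the head count. A sketch with hypothetical sizes:

```python
# Sketch: why state-dict tensor shapes recover the hyperparameters.
import torch
import torch.nn as nn

# A non-grouped Conv2d stores a weight of shape (out_c, in_c, kH, kW).
down = nn.Conv2d(36, 72, kernel_size=3, stride=2, padding=1, bias=False)
w = down.state_dict()["weight"]
assert w.shape == (72, 36, 3, 3)
assert w.shape[1] == 36  # input width  -> channels[i]
assert w.shape[0] == 72  # output width -> channels[i + 1]

# CAB.temperature is nn.Parameter(torch.ones(num_heads, 1, 1)).
temperature = torch.ones(4, 1, 1)
assert temperature.shape[0] == 4  # -> heads for that stage
```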
135 changes: 135 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/arch/HVI_transform.py
@@ -0,0 +1,135 @@
from math import pi

import torch
import torch.nn as nn


class RGB_HVI(nn.Module):
    def __init__(self):
        super().__init__()
        self.density_k = torch.nn.Parameter(
            torch.full([1], 0.2)
        )  # k is the reciprocal of the k mentioned in the paper
        self.gated = False
        self.gated2 = False
        self.alpha = 1.0
        self.this_k = 0

    def HVIT(self, img):
        eps = 1e-8
        device = img.device
        dtypes = img.dtype
        hue = (
            torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
        )
        value = img.max(1)[0].to(dtypes)
        img_min = img.min(1)[0].to(dtypes)
        hue[img[:, 2] == value] = (
            4.0
            + ((img[:, 0] - img[:, 1]) / (value - img_min + eps))[img[:, 2] == value]
        )
        hue[img[:, 1] == value] = (
            2.0
            + ((img[:, 2] - img[:, 0]) / (value - img_min + eps))[img[:, 1] == value]
        )
        hue[img[:, 0] == value] = (
            0.0
            + ((img[:, 1] - img[:, 2]) / (value - img_min + eps))[img[:, 0] == value]
        ) % 6

        hue[img.min(1)[0] == value] = 0.0
        hue = hue / 6.0

        saturation = (value - img_min) / (value + eps)
        saturation[value == 0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()

        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        cx = (2.0 * pi * hue).cos()
        cy = (2.0 * pi * hue).sin()
        X = color_sensitive * saturation * cx
        Y = color_sensitive * saturation * cy
        Z = value
        xyz = torch.cat([X, Y, Z], dim=1)
        return xyz

    def PHVIT(self, img):
        eps = 1e-8
        H, V, I = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]  # noqa: E741

        # clip
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        I = torch.clamp(I, 0, 1)  # noqa: E741

        v = I
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = (H) / (color_sensitive + eps)
        V = (V) / (color_sensitive + eps)
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        h = torch.atan2(V, H) / (2 * pi)
        h = h % 1
        s = torch.sqrt(H**2 + V**2)

        if self.gated:
            s = s * 1.3

        s = torch.clamp(s, 0, 1)
        v = torch.clamp(v, 0, 1)

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)

        hi = torch.floor(h * 6.0)
        f = h * 6.0 - hi
        p = v * (1.0 - s)
        q = v * (1.0 - (f * s))
        t = v * (1.0 - ((1.0 - f) * s))

        hi0 = hi == 0
        hi1 = hi == 1
        hi2 = hi == 2
        hi3 = hi == 3
        hi4 = hi == 4
        hi5 = hi == 5

        r[hi0] = v[hi0]
        g[hi0] = t[hi0]
        b[hi0] = p[hi0]

        r[hi1] = q[hi1]
        g[hi1] = v[hi1]
        b[hi1] = p[hi1]

        r[hi2] = p[hi2]
        g[hi2] = v[hi2]
        b[hi2] = t[hi2]

        r[hi3] = p[hi3]
        g[hi3] = q[hi3]
        b[hi3] = v[hi3]

        r[hi4] = t[hi4]
        g[hi4] = p[hi4]
        b[hi4] = v[hi4]

        r[hi5] = v[hi5]
        g[hi5] = p[hi5]
        b[hi5] = q[hi5]

        r = r.unsqueeze(1)
        g = g.unsqueeze(1)
        b = b.unsqueeze(1)
        rgb = torch.cat([r, g, b], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb
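As a sanity check of the transform pair (a sketch, not part of the PR): `HVIT` caches the learned `density_k` in `this_k`, which `PHVIT` then uses for the inverse mapping, so a round trip should approximately reproduce the input up to clamping and `eps` error:

```python
# Sketch: PHVIT approximately inverts HVIT, since it reuses the `this_k`
# cached by the preceding HVIT call. With gated/gated2 left False, no
# extra scaling is applied.
import torch

trans = RGB_HVI()
rgb = torch.rand(1, 3, 32, 32)
hvi = trans.HVIT(rgb)    # channels: X, Y (chroma plane) and Z (intensity)
back = trans.PHVIT(hvi)  # back to RGB
print((rgb - back).abs().max())  # small residual error expected
```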
133 changes: 133 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/arch/LCA.py
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from einops import rearrange

from .transformer_utils import LayerNorm


# Cross Attention Block
class CAB(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias
        )
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(
            dim * 2,
            dim * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 2,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        _, _, h, w = x.shape

        q = self.q_dwconv(self.q(x))
        kv = self.kv_dwconv(self.kv(y))
        k, v = kv.chunk(2, dim=1)

        q = rearrange(q, "b (head c) h w -> b head c (h w)", head=self.num_heads)
        k = rearrange(k, "b (head c) h w -> b head c (h w)", head=self.num_heads)
        v = rearrange(v, "b (head c) h w -> b head c (h w)", head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v

        out = rearrange(
            out, "b head c (h w) -> b (head c) h w", head=self.num_heads, h=h, w=w
        )

        out = self.project_out(out)
        return out


# Intensity Enhancement Layer
class IEL(nn.Module):
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )
        self.dwconv1 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=bias,
        )
        self.dwconv2 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.Tanh = nn.Tanh()

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x1 = self.Tanh(self.dwconv1(x1)) + x1
        x2 = self.Tanh(self.dwconv2(x2)) + x2
        x = x1 * x2
        x = self.project_out(x)
        return x


# Lightweight Cross Attention
class HV_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have the same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = self.gdfn(self.norm(x))
        return x


class I_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = x + self.gdfn(self.norm(x))
        return x
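A quick shape check for these blocks (a sketch, assuming the module's `LayerNorm` import is available): both `HV_LCA` and `I_LCA` take two feature maps of the same shape, one per branch, and return a tensor of that shape, which is what lets CIDNet exchange information between its HV and intensity branches:

```python
# Sketch: both LCA variants are shape-preserving on (B, C, H, W) inputs.
import torch

dim, num_heads = 36, 2
block = I_LCA(dim, num_heads)
x = torch.rand(1, dim, 16, 16)  # intensity-branch features
y = torch.rand(1, dim, 16, 16)  # HV-branch features
out = block(x, y)
assert out.shape == x.shape  # (1, 36, 16, 16)
```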