Commit 1fa4dcd (1 parent: 2f71a10)
Showing 11 changed files with 663 additions and 0 deletions.
94 changes: 94 additions & 0 deletions
libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
@@ -0,0 +1,94 @@
from typing_extensions import override

from spandrel.util import KeyCondition

from ...__helpers.model_descriptor import (
    Architecture,
    ImageModelDescriptor,
    ModelTiling,
    SizeRequirements,
    StateDict,
)
from .arch.cidnet import CIDNet as HVICIDNet


class HVICIDNetArch(Architecture[HVICIDNet]):
    def __init__(self) -> None:
        super().__init__(
            id="HVICIDNet",
            name="HVI-CIDNet",
            detect=KeyCondition.has_all(
                "HVE_block0.1.weight",
                "HVE_block1.prelu.weight",
                "HVE_block1.down.0.weight",
                "HVE_block3.down.0.weight",
                "HVD_block3.prelu.weight",
                "HVD_block3.up_scale.0.weight",
                "HVD_block3.up.weight",
                "HVD_block1.up.weight",
                "HVD_block0.1.weight",
                "IE_block0.1.weight",
                "IE_block1.prelu.weight",
                "IE_block1.down.0.weight",
                "ID_block1.up.weight",
                "ID_block0.1.weight",
                "HV_LCA1.gdfn.project_in.weight",
                "HV_LCA1.gdfn.dwconv.weight",
                "HV_LCA1.gdfn.dwconv1.weight",
                "HV_LCA1.gdfn.dwconv2.weight",
                "HV_LCA1.gdfn.project_out.weight",
                "HV_LCA1.norm.weight",
                "HV_LCA1.ffn.temperature",
                "HV_LCA1.ffn.q.weight",
                "HV_LCA1.ffn.q_dwconv.weight",
                "HV_LCA1.ffn.project_out.weight",
                "HV_LCA2.gdfn.project_in.weight",
                "HV_LCA6.gdfn.project_in.weight",
                "I_LCA1.gdfn.project_in.weight",
                "I_LCA6.ffn.project_out.weight",
                "trans.density_k",
            ),
        )

    @override
    def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
        channels = [36, 36, 72, 144]
        heads = [1, 2, 4, 8]
        norm = False

        channels = [
            state_dict["HVE_block1.down.0.weight"].shape[1],
            state_dict["HVE_block1.down.0.weight"].shape[0],
            state_dict["HVE_block2.down.0.weight"].shape[0],
            state_dict["HVE_block3.down.0.weight"].shape[0],
        ]

        heads = [
            1,  # unused
            state_dict["HV_LCA1.ffn.temperature"].shape[0],
            state_dict["HV_LCA2.ffn.temperature"].shape[0],
            state_dict["HV_LCA3.ffn.temperature"].shape[0],
        ]

        norm = "HVE_block1.norm.weight" in state_dict

        model = HVICIDNet(
            channels=channels,
            heads=heads,
            norm=norm,
        )

        return ImageModelDescriptor(
            model,
            state_dict,
            architecture=self,
            purpose="Restoration",
            tags=[],
            supports_half=False,  # TODO: verify
            supports_bfloat16=True,
            scale=1,
            input_channels=3,  # hard-coded
            output_channels=3,  # hard-coded
            size_requirements=SizeRequirements(multiple_of=8),
            tiling=ModelTiling.DISCOURAGED,
        )
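For context on how this descriptor is consumed, here is a minimal, hypothetical usage sketch with spandrel's public ModelLoader API (the checkpoint filename and input size are made up; any state dict matching the KeyCondition above would be dispatched to HVICIDNetArch):

import torch
from spandrel import ModelLoader

# Hypothetical checkpoint path; key detection routes it to HVICIDNetArch.load().
descriptor = ModelLoader().load_from_file("HVI-CIDNet.pth")
assert descriptor.architecture.name == "HVI-CIDNet"

descriptor.eval()
with torch.no_grad():
    # 1 x 3 x H x W image in [0, 1]; H and W should be multiples of 8
    # (SizeRequirements(multiple_of=8)), and tiling is discouraged.
    image = torch.rand(1, 3, 512, 512)
    restored = descriptor(image)  # scale=1, so the output has the same shape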
135 changes: 135 additions & 0 deletions
libs/spandrel/spandrel/architectures/HVICIDNet/arch/HVI_transform.py
@@ -0,0 +1,135 @@
from math import pi

import torch
import torch.nn as nn


class RGB_HVI(nn.Module):
    def __init__(self):
        super().__init__()
        self.density_k = torch.nn.Parameter(
            torch.full([1], 0.2)
        )  # k is reciprocal to the paper mentioned
        self.gated = False
        self.gated2 = False
        self.alpha = 1.0
        self.this_k = 0

    def HVIT(self, img):
        # RGB (B, 3, H, W) in [0, 1] -> HVI representation (B, 3, H, W).
        eps = 1e-8
        device = img.device
        dtypes = img.dtype
        # Hue/saturation/value follow a standard RGB -> HSV conversion, computed
        # per pixel with masks on whichever channel holds the maximum.
        hue = (
            torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
        )
        value = img.max(1)[0].to(dtypes)
        img_min = img.min(1)[0].to(dtypes)
        hue[img[:, 2] == value] = (
            4.0
            + ((img[:, 0] - img[:, 1]) / (value - img_min + eps))[img[:, 2] == value]
        )
        hue[img[:, 1] == value] = (
            2.0
            + ((img[:, 2] - img[:, 0]) / (value - img_min + eps))[img[:, 1] == value]
        )
        hue[img[:, 0] == value] = (
            0.0
            + ((img[:, 1] - img[:, 2]) / (value - img_min + eps))[img[:, 0] == value]
        ) % 6

        hue[img.min(1)[0] == value] = 0.0
        hue = hue / 6.0

        saturation = (value - img_min) / (value + eps)
        saturation[value == 0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()

        # Chroma is scaled by an intensity-dependent factor that shrinks toward 0
        # as value -> 0, then (hue, saturation) is mapped onto a plane via cos/sin.
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        cx = (2.0 * pi * hue).cos()
        cy = (2.0 * pi * hue).sin()
        X = color_sensitive * saturation * cx
        Y = color_sensitive * saturation * cy
        Z = value
        xyz = torch.cat([X, Y, Z], dim=1)
        return xyz

    def PHVIT(self, img):
        # Inverse transform: HVI (B, 3, H, W) -> RGB, using the k recorded by HVIT.
        eps = 1e-8
        H, V, I = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]  # noqa: E741

        # clip
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        I = torch.clamp(I, 0, 1)  # noqa: E741

        # Undo the intensity-dependent chroma scaling, then recover hue and
        # saturation from the plane coordinates.
        v = I
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = (H) / (color_sensitive + eps)
        V = (V) / (color_sensitive + eps)
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        h = torch.atan2(V, H) / (2 * pi)
        h = h % 1
        s = torch.sqrt(H**2 + V**2)

        if self.gated:
            s = s * 1.3

        s = torch.clamp(s, 0, 1)
        v = torch.clamp(v, 0, 1)

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)

        # Standard HSV -> RGB reconstruction by hue sextant.
        hi = torch.floor(h * 6.0)
        f = h * 6.0 - hi
        p = v * (1.0 - s)
        q = v * (1.0 - (f * s))
        t = v * (1.0 - ((1.0 - f) * s))

        hi0 = hi == 0
        hi1 = hi == 1
        hi2 = hi == 2
        hi3 = hi == 3
        hi4 = hi == 4
        hi5 = hi == 5

        r[hi0] = v[hi0]
        g[hi0] = t[hi0]
        b[hi0] = p[hi0]

        r[hi1] = q[hi1]
        g[hi1] = v[hi1]
        b[hi1] = p[hi1]

        r[hi2] = p[hi2]
        g[hi2] = v[hi2]
        b[hi2] = t[hi2]

        r[hi3] = p[hi3]
        g[hi3] = q[hi3]
        b[hi3] = v[hi3]

        r[hi4] = t[hi4]
        g[hi4] = p[hi4]
        b[hi4] = v[hi4]

        r[hi5] = v[hi5]
        g[hi5] = p[hi5]
        b[hi5] = q[hi5]

        r = r.unsqueeze(1)
        g = g.unsqueeze(1)
        b = b.unsqueeze(1)
        rgb = torch.cat([r, g, b], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb
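As a quick sanity check of the transform pair above: HVIT maps an RGB image to the HVI planes and PHVIT maps them back, so a round trip should approximately reproduce the input. A standalone sketch (the import path follows this file's location; the input tensor is synthetic):

import torch

from spandrel.architectures.HVICIDNet.arch.HVI_transform import RGB_HVI

trans = RGB_HVI()

rgb = torch.rand(1, 3, 64, 64)  # B x 3 x H x W image in [0, 1]
hvi = trans.HVIT(rgb)           # same shape, now (H, V, I) planes
# PHVIT relies on self.this_k, which HVIT records, so it must run second.
rgb_back = trans.PHVIT(hvi)

print(hvi.shape)                            # torch.Size([1, 3, 64, 64])
print((rgb - rgb_back).abs().max().item())  # small reconstruction error

In the full CIDNet model, this module is presumably the `trans` submodule implied by the `trans.density_k` detection key in __init__.py.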
133 changes: 133 additions & 0 deletions
libs/spandrel/spandrel/architectures/HVICIDNet/arch/LCA.py
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from einops import rearrange

from .transformer_utils import LayerNorm


# Cross Attention Block
class CAB(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias
        )
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(
            dim * 2,
            dim * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 2,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        _, _, h, w = x.shape

        q = self.q_dwconv(self.q(x))
        kv = self.kv_dwconv(self.kv(y))
        k, v = kv.chunk(2, dim=1)

        q = rearrange(q, "b (head c) h w -> b head c (h w)", head=self.num_heads)
        k = rearrange(k, "b (head c) h w -> b head c (h w)", head=self.num_heads)
        v = rearrange(v, "b (head c) h w -> b head c (h w)", head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v

        out = rearrange(
            out, "b head c (h w) -> b (head c) h w", head=self.num_heads, h=h, w=w
        )

        out = self.project_out(out)
        return out


# Intensity Enhancement Layer
class IEL(nn.Module):
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )
        self.dwconv1 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=bias,
        )
        self.dwconv2 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.Tanh = nn.Tanh()

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x1 = self.Tanh(self.dwconv1(x1)) + x1
        x2 = self.Tanh(self.dwconv2(x2)) + x2
        x = x1 * x2
        x = self.project_out(x)
        return x


# Lightweight Cross Attention
class HV_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = self.gdfn(self.norm(x))
        return x


class I_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = x + self.gdfn(self.norm(x))
        return x
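To make the tensor contract of these blocks concrete, here is a small shape-check sketch (the import path follows this file's location; dim=36 and num_heads=2 are illustrative values consistent with the defaults in __init__.py, and the feature maps are synthetic):

import torch

from spandrel.architectures.HVICIDNet.arch.LCA import HV_LCA, I_LCA

dim, num_heads = 36, 2
hv_lca = HV_LCA(dim, num_heads)
i_lca = I_LCA(dim, num_heads)

x = torch.rand(1, dim, 32, 32)  # e.g. HV-branch features
y = torch.rand(1, dim, 32, 32)  # e.g. I-branch features

# Both blocks cross-attend x against y and preserve the input shape;
# dim must be divisible by num_heads for the multi-head rearrange.
print(hv_lca(x, y).shape)  # torch.Size([1, 36, 32, 32])
print(i_lca(x, y).shape)   # torch.Size([1, 36, 32, 32])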