Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
yyk-wew committed Apr 20, 2021
2 parents ff58514 + d1c2f03 commit f3b4da2
Showing 1 changed file with 60 additions and 22 deletions.
82 changes: 60 additions & 22 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,33 @@

# Filter Module
class Filter(nn.Module):
def __init__(self, size, band_start, band_end, use_learnable=True):
def __init__(self, size, band_start, band_end, use_learnable=True, norm=False):
super(Filter, self).__init__()
self.use_learnable = use_learnable

self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
if self.use_learnable:
self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
self.learnable.data.normal_(0., 0.1)
# Todo
# self.learnable = nn.Parameter(torch.rand((size, size)) * 0.2 - 0.1, requires_grad=True)

self.norm = norm
if norm:
self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)


def forward(self, x):
if self.use_learnable:
filt = self.base + norm_sigma(self.learnable)
else:
filt = self.base
return x * filt

if self.norm:
y = x * filt / self.ft_num
else:
y = x * filt
return y


# FAD Module
Expand Down Expand Up @@ -71,7 +83,7 @@ def __init__(self, size, window_size, M):
self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=2, padding=4)

# init filters
self.filters = nn.ModuleList([Filter(window_size, window_size / M * i, window_size / M * (i+1)) for i in range(M)])
self.filters = nn.ModuleList([Filter(window_size, window_size * 2. / M * i, window_size * 2. / M * (i+1), norm=True) for i in range(M)])

def forward(self, x):
# turn RGB into Gray
Expand All @@ -96,10 +108,14 @@ def forward(self, x):
# M kernels filtering
y_list = []
for i in range(self._M):
y = self.filters[i](x_dct) # [N, L, C, S, S]
y = torch.abs(y)
y = torch.sum(y, dim=[2,3,4]) # [N, L]
# y = self.filters[i](x_dct) # [N, L, C, S, S]
# y = torch.abs(y)
# y = torch.sum(y, dim=[2,3,4]) # [N, L]
# y = torch.log10(y + 1e-15)
y = torch.abs(x_dct)
y = torch.log10(y + 1e-15)
y = self.filters[i](y)
y = torch.sum(y, dim=[2,3,4])
y = y.reshape(N, size_after, size_after).unsqueeze(dim=1) # [N, 1, 149, 149]
y_list.append(y)
out = torch.cat(y_list, dim=1) # [N, M, 149, 149]
Expand Down Expand Up @@ -298,21 +314,43 @@ def fea_8_12(self, x):
return x

class MixBlock(nn.Module):
def __init__(self, c_in):
def __init__(self, c_in, width, height):
super(MixBlock, self).__init__()
c_in = c_in * 2
self.conv1 = nn.Conv2d(c_in, c_in, (1,1))
self.conv2 = nn.Conv2d(c_in, 2, (3,3), 1, 1)
self.bn = nn.BatchNorm2d(c_in)
self.FAD_query = nn.Conv2d(c_in, c_in, (1,1))
self.LFS_query = nn.Conv2d(c_in, c_in, (1,1))

self.FAD_key = nn.Conv2d(c_in, c_in, (1,1))
self.LFS_key = nn.Conv2d(c_in, c_in, (1,1))

self.softmax = nn.Softmax(dim=-1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)
x = self.conv1(x)
x = self.bn(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmoid(x)
y1, y2 = torch.split(x, 1, dim=1)
return y1, y2

self.FAD_gamma = nn.Parameter(torch.zeros(1))
self.LFS_gamma = nn.Parameter(torch.zeros(1))

self.FAD_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
self.FAD_bn = nn.BatchNorm2d(c_in)
self.LFS_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
self.LFS_bn = nn.BatchNorm2d(c_in)

def forward(self, x_FAD, x_LFS):
B, C, W, H = x_FAD.size()
assert W == H

q_FAD = self.FAD_query(x_FAD).view(-1, W, H) # [BC, W, H]
q_LFS = self.LFS_query(x_LFS).view(-1, W, H)
M_query = torch.cat([q_FAD, q_LFS], dim=2) # [BC, W, 2H]

k_FAD = self.FAD_key(x_FAD).view(-1, W, H).transpose(1, 2) # [BC, H, W]
k_LFS = self.LFS_key(x_LFS).view(-1, W, H).transpose(1, 2)
M_key = torch.cat([k_FAD, k_LFS], dim=1) # [BC, 2H, W]

energy = torch.bmm(M_query, M_key) #[BC, W, W]
attention = self.softmax(energy).view(B, C, W, W)

att_LFS = x_LFS * attention * (torch.sigmoid(self.LFS_gamma) * 2.0 - 1.0)
y_FAD = x_FAD + self.FAD_bn(self.FAD_conv(att_LFS))

att_FAD = x_FAD * attention * (torch.sigmoid(self.FAD_gamma) * 2.0 - 1.0)
y_LFS = x_LFS + self.LFS_bn(self.LFS_conv(att_FAD))
return y_FAD, y_LFS

0 comments on commit f3b4da2

Please sign in to comment.