# octave convolution.py
import math

import torch
import torch.nn as nn

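# Octave Convolution (OctConv) from "Drop an Octave: Reducing Spatial
# Redundancy in Convolutional Neural Networks with Octave Convolution"
# (Chen et al., 2019). Feature maps are split into a high-frequency part
# at full resolution and a low-frequency part at half resolution (one
# octave lower); alpha_in and alpha_out give the fraction of input and
# output channels assigned to the low-frequency branch.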
class OctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5,
                 stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(OctaveConv, self).__init__()
        self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride
        self.is_dw = groups == in_channels  # depthwise conv: cross-frequency paths are disabled
        assert 0 <= alpha_in <= 1 and 0 <= alpha_out <= 1, "Alphas should be in the interval from 0 to 1."
        self.alpha_in, self.alpha_out = alpha_in, alpha_out
        # One convolution per frequency path (low->low, low->high, high->low,
        # high->high); a path is None when its input or output branch is empty.
        self.conv_l2l = None if alpha_in == 0 or alpha_out == 0 else \
            nn.Conv2d(int(alpha_in * in_channels), int(alpha_out * out_channels),
                      kernel_size, 1, padding, dilation, math.ceil(alpha_in * groups), bias)
        self.conv_l2h = None if alpha_in == 0 or alpha_out == 1 or self.is_dw else \
            nn.Conv2d(int(alpha_in * in_channels), out_channels - int(alpha_out * out_channels),
                      kernel_size, 1, padding, dilation, groups, bias)
        self.conv_h2l = None if alpha_in == 1 or alpha_out == 0 or self.is_dw else \
            nn.Conv2d(in_channels - int(alpha_in * in_channels), int(alpha_out * out_channels),
                      kernel_size, 1, padding, dilation, groups, bias)
        self.conv_h2h = None if alpha_in == 1 or alpha_out == 1 else \
            nn.Conv2d(in_channels - int(alpha_in * in_channels), out_channels - int(alpha_out * out_channels),
                      kernel_size, 1, padding, dilation, math.ceil(groups - alpha_in * groups), bias)
    def forward(self, x):
        x_h, x_l = x if isinstance(x, tuple) else (x, None)

        # Stride 2 is implemented as average pooling before the stride-1 convolutions.
        x_h = self.downsample(x_h) if self.stride == 2 else x_h
        x_h2h = self.conv_h2h(x_h)
        x_h2l = self.conv_h2l(self.downsample(x_h)) if self.alpha_out > 0 and not self.is_dw else None
        if x_l is not None:
            x_l2l = self.downsample(x_l) if self.stride == 2 else x_l
            x_l2l = self.conv_l2l(x_l2l) if self.alpha_out > 0 else None
            if self.is_dw:
                return x_h2h, x_l2l
            else:
                # Bring the low->high path back to full resolution before summing.
                x_l2h = self.conv_l2h(x_l)
                x_l2h = self.upsample(x_l2h) if self.stride == 1 else x_l2h
                x_h = x_l2h + x_h2h
                x_l = x_h2l + x_l2l if x_h2l is not None and x_l2l is not None else None
                return x_h, x_l
        else:
            return x_h2h, x_h2l

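# Conv_BN wraps OctaveConv with per-branch batch normalization: each
# frequency branch gets its own norm layer sized to that branch's channels.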
class Conv_BN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5,
                 stride=1, padding=0, dilation=1, groups=1, bias=False, norm_layer=nn.BatchNorm2d):
        super(Conv_BN, self).__init__()
        self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out,
                               stride, padding, dilation, groups, bias)
        self.bn_h = None if alpha_out == 1 else norm_layer(int(out_channels * (1 - alpha_out)))
        self.bn_l = None if alpha_out == 0 else norm_layer(int(out_channels * alpha_out))

    def forward(self, x):
        x_h, x_l = self.conv(x)
        x_h = self.bn_h(x_h) if x_h is not None else None
        x_l = self.bn_l(x_l) if x_l is not None else None
        return x_h, x_l

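# Conv_BN_ACT additionally applies an activation (ReLU by default)
# after each branch's batch norm.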
class Conv_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha_in=0.5, alpha_out=0.5,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
        super(Conv_BN_ACT, self).__init__()
        self.conv = OctaveConv(in_channels, out_channels, kernel_size, alpha_in, alpha_out,
                               stride, padding, dilation, groups, bias)
        self.bn_h = None if alpha_out == 1 else norm_layer(int(out_channels * (1 - alpha_out)))
        self.bn_l = None if alpha_out == 0 else norm_layer(int(out_channels * alpha_out))
        self.act = activation_layer(inplace=True)

    def forward(self, x):
        x_h, x_l = self.conv(x)
        x_h = self.act(self.bn_h(x_h)) if x_h is not None else None
        x_l = self.act(self.bn_l(x_l)) if x_l is not None else None
        return x_h, x_l
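

# A minimal usage sketch (added for illustration; not part of the original
# file). It chains three Conv_BN_ACT layers: the first splits a plain
# tensor into high/low branches (alpha_in=0), the last merges them back
# into a single tensor (alpha_out=0). Input size and channel counts here
# are arbitrary assumptions.
if __name__ == "__main__":
    x = torch.randn(1, 32, 64, 64)
    conv1 = Conv_BN_ACT(32, 64, kernel_size=3, alpha_in=0, alpha_out=0.5, padding=1)
    conv2 = Conv_BN_ACT(64, 64, kernel_size=3, alpha_in=0.5, alpha_out=0.5, padding=1)
    conv3 = Conv_BN_ACT(64, 64, kernel_size=3, alpha_in=0.5, alpha_out=0, padding=1)
    x_h, x_l = conv1(x)            # high: [1, 32, 64, 64], low: [1, 32, 32, 32]
    x_h, x_l = conv2((x_h, x_l))   # same shapes as above
    x_h, x_l = conv3((x_h, x_l))   # high: [1, 64, 64, 64], low: None
    print(x_h.shape, x_l)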