condconv.py
import functools
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter


class _routing(nn.Module):
    """Routing network: maps a per-sample feature vector to one gating weight
    per expert kernel."""

    def __init__(self, in_channels, num_experts, dropout_rate):
        super(_routing, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(in_channels, num_experts),
        )

    def forward(self, x):
        # x is a single sample's feature tensor; flatten it to a 1-D vector.
        x = torch.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        # Independent sigmoid gate per expert (F.sigmoid is deprecated).
        return torch.sigmoid(x)


class DynamicCondConv2D(_ConvNd):
    """Conditionally parameterized convolution whose expert routing is driven
    by an external per-sample condition tensor rather than by the input."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_experts=3,
                 dropout_rate=0.2, routing_channels=512):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(DynamicCondConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
        self._routing_fn = _routing(routing_channels, num_experts, dropout_rate)
        # One kernel per expert; forward() mixes them per sample.
        self.weight = Parameter(torch.Tensor(
            num_experts, out_channels, in_channels // groups, *kernel_size))
        self.reset_parameters()
    def _conv_forward(self, input, weight):
        # Non-zero padding modes are handled by padding explicitly first.
        # Note: newer PyTorch releases rename _padding_repeated_twice to
        # _reversed_padding_repeated_twice.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    def forward(self, inputs_q):
        # inputs_q is a pair: (batch of inputs, per-sample condition tensors).
        inputs, kernel_conditions = inputs_q[0], inputs_q[1]
        res = []
        for i, input in enumerate(inputs):
            input = input.unsqueeze(0)  # treat each sample as a batch of one
            routing_weights = self._routing_fn(kernel_conditions[i])
            # Mix the expert kernels into a single per-sample kernel.
            kernels = torch.sum(
                routing_weights[:, None, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)
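
# A minimal usage sketch (the shapes below are illustrative assumptions, not
# part of the original file): DynamicCondConv2D takes a pair
# (inputs, kernel_conditions), where kernel_conditions[i] holds
# routing_channels features for sample i, e.g.:
#
#   conv = DynamicCondConv2D(16, 32, kernel_size=3, padding=1, routing_channels=512)
#   x = torch.randn(4, 16, 8, 8)
#   cond = torch.randn(4, 512)
#   y = conv((x, cond))  # -> torch.Size([4, 32, 8, 8])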


class CondConv2D(_ConvNd):
    """Conditionally parameterized convolution (CondConv): the kernel applied
    to each sample is a routed mixture of num_experts expert kernels."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(CondConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
        # Routing features come from global average pooling of the input itself.
        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)
        # One kernel per expert; forward() mixes them per sample.
        self.weight = Parameter(torch.Tensor(
            num_experts, out_channels, in_channels // groups, *kernel_size))
        self.reset_parameters()
    def _conv_forward(self, input, weight):
        # Non-zero padding modes are handled by padding explicitly first.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    def forward(self, inputs):
        res = []
        for input in inputs:
            input = input.unsqueeze(0)  # treat each sample as a batch of one
            # Global average pooling yields the per-sample routing features.
            pooled_inputs = self._avg_pooling(input)
            routing_weights = self._routing_fn(pooled_inputs)
            # Mix the expert kernels into a single per-sample kernel.
            kernels = torch.sum(
                routing_weights[:, None, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)
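

if __name__ == "__main__":
    # Quick smoke test for CondConv2D; the shapes here are illustrative
    # assumptions only, not taken from the original file.
    conv = CondConv2D(16, 32, kernel_size=3, padding=1, num_experts=3)
    x = torch.randn(4, 16, 8, 8)
    y = conv(x)
    print(y.shape)  # expected: torch.Size([4, 32, 8, 8])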