-
Notifications
You must be signed in to change notification settings - Fork 5
/
deform_cnn.py
124 lines (96 loc) · 3.46 KB
/
deform_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import absolute_import, division
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_deform_conv.deform_conv_layer import ConvOffset2D
class ConvNet(nn.Module):
    """Plain CNN baseline: four 3x3 convs (two with stride 2), global
    average pooling, and a softmax classification head.

    Input:  (N, 1, H, W) single-channel images; H and W must be >= 4 so
            the two stride-2 convs leave a non-empty feature map.
    Output: (N, 10) class probabilities (rows sum to 1).
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # conv11
        self.conv11 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn11 = nn.BatchNorm2d(32)
        # conv12: stride 2 halves the spatial resolution
        self.conv12 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
        self.bn12 = nn.BatchNorm2d(64)
        # conv21
        self.conv21 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn21 = nn.BatchNorm2d(128)
        # conv22: stride 2 halves the spatial resolution again
        self.conv22 = nn.Conv2d(128, 128, 3, padding=1, stride=2)
        self.bn22 = nn.BatchNorm2d(128)
        # out: 128 pooled channels -> 10 classes
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        # NOTE(review): batch norm is applied AFTER the ReLU here
        # (unusual ordering, but matches the companion DeformConvNet);
        # kept as-is to preserve behavior.
        x = F.relu(self.conv11(x))
        x = self.bn11(x)
        x = F.relu(self.conv12(x))
        x = self.bn12(x)
        x = F.relu(self.conv21(x))
        x = self.bn21(x)
        x = F.relu(self.conv22(x))
        x = self.bn22(x)
        # Global average pooling over whatever spatial extent remains.
        x = F.avg_pool2d(x, kernel_size=[x.size(2), x.size(3)])
        x = self.fc(x.view(x.size()[:2]))
        # Fix: make the class axis explicit. Implicit-dim softmax is
        # deprecated and raises on modern PyTorch; dim=1 matches the old
        # implicit behavior for 2-D input.
        x = F.softmax(x, dim=1)
        return x
class DeformConvNet(nn.Module):
    """Deformable-convolution variant of ConvNet: each conv after the
    first is preceded by a learned ConvOffset2D sampling-offset layer.

    Input:  (N, 1, H, W) single-channel images (H, W >= 4).
    Output: (N, 10) class probabilities (rows sum to 1).
    """

    def __init__(self):
        super(DeformConvNet, self).__init__()
        # conv11
        self.conv11 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn11 = nn.BatchNorm2d(32)
        # conv12: offset layer then stride-2 conv
        self.offset12 = ConvOffset2D(32)
        self.conv12 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
        self.bn12 = nn.BatchNorm2d(64)
        # conv21
        self.offset21 = ConvOffset2D(64)
        self.conv21 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn21 = nn.BatchNorm2d(128)
        # conv22: second stride-2 conv
        self.offset22 = ConvOffset2D(128)
        self.conv22 = nn.Conv2d(128, 128, 3, padding=1, stride=2)
        self.bn22 = nn.BatchNorm2d(128)
        # out: 128 pooled channels -> 10 classes
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        # NOTE(review): batch norm after ReLU, mirroring ConvNet.
        x = F.relu(self.conv11(x))
        x = self.bn11(x)
        x = self.offset12(x)
        x = F.relu(self.conv12(x))
        x = self.bn12(x)
        x = self.offset21(x)
        x = F.relu(self.conv21(x))
        x = self.bn21(x)
        x = self.offset22(x)
        x = F.relu(self.conv22(x))
        x = self.bn22(x)
        # Global average pooling over the remaining spatial extent.
        x = F.avg_pool2d(x, kernel_size=[x.size(2), x.size(3)])
        x = self.fc(x.view(x.size()[:2]))
        # Fix: explicit class axis — implicit-dim softmax is deprecated
        # and raises on modern PyTorch (dim=1 matches the old behavior).
        x = F.softmax(x, dim=1)
        return x

    def freeze(self, module_classes):
        """Disable gradients for every direct child module whose type
        matches one of *module_classes* (for finetuning)."""
        for k, m in self._modules.items():
            # Fix: use isinstance for consistency with unfreeze()
            # (exact type() comparison silently skipped subclasses).
            if any(isinstance(m, mc) for mc in module_classes):
                for param in m.parameters():
                    param.requires_grad = False

    def unfreeze(self, module_classes):
        """Re-enable gradients for every direct child module whose type
        matches one of *module_classes*."""
        for k, m in self._modules.items():
            if any(isinstance(m, mc) for mc in module_classes):
                for param in m.parameters():
                    param.requires_grad = True

    def parameters(self):
        # Override so optimizers only ever see trainable parameters;
        # frozen parameters are filtered out.
        return filter(lambda p: p.requires_grad,
                      super(DeformConvNet, self).parameters())
def get_cnn():
    """Factory: build and return a freshly initialized ConvNet baseline."""
    model = ConvNet()
    return model
def get_deform_cnn(trainable=True, freeze_filter=(nn.Conv2d, nn.Linear)):
    """Factory: build a DeformConvNet, optionally freezing layers.

    Args:
        trainable: when False, parameters of modules matching
            *freeze_filter* get requires_grad=False (finetune only the
            remaining layers, e.g. the deformable offsets).
        freeze_filter: module classes to freeze. Fix: default changed
            from a mutable list to a tuple (mutable-default anti-pattern);
            it is only iterated, so callers are unaffected.

    Returns:
        The constructed DeformConvNet instance.
    """
    model = DeformConvNet()
    if not trainable:
        model.freeze(freeze_filter)
    return model