-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
executable file
·130 lines (110 loc) · 4.97 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
## Only Generator Network here
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import initializers
from common.sn.sn_convolution_2d import SNConvolution2D
class DenseLayer(chainer.Chain):
    """Single dense-net layer: 1x1 bottleneck conv then 3x3 conv.

    The layer output is the input concatenated (channel-wise) with the
    3x3 conv result, so channels grow by ``growth_rate`` per layer.
    Generator layers (``if_gen=True``) use plain convolutions and ReLU;
    otherwise spectrally-normalized convolutions and leaky-ReLU are used.
    """

    def __init__(self, in_channel, growth_rate, bn_size, if_gen=False):
        super(DenseLayer, self).__init__()
        w = initializers.HeNormal()
        # Same topology either way; only the conv class differs.
        conv_cls = L.Convolution2D if if_gen else SNConvolution2D
        with self.init_scope():
            self.conv1 = conv_cls(None, bn_size * growth_rate, 1, 1, 0, initialW=w)
            self.conv2 = conv_cls(None, growth_rate, 3, 1, 1, initialW=w)
        self.in_ch = in_channel  # recorded only; convs infer channels via None
        self.if_gen = if_gen

    def __call__(self, x):
        activate = F.relu if self.if_gen else F.leaky_relu
        h = self.conv2(activate(self.conv1(x)))
        return F.concat((x, h))
class DenseBlock(chainer.Chain):
    """Stack of ``n_layers`` DenseLayers applied sequentially.

    Channel count entering layer ``i`` is ``in_channel + (i-1)*growth_rate``.
    Discriminator blocks (``if_gen=False``) apply a final leaky-ReLU;
    generator blocks return the raw concatenated features.
    """

    def __init__(self, n_layers, in_channel, bn_size, growth_rate, if_gen=False):
        super(DenseBlock, self).__init__()
        with self.init_scope():
            for idx in range(1, n_layers + 1):
                ch = in_channel + (idx - 1) * growth_rate
                layer = DenseLayer(ch, growth_rate, bn_size, if_gen)
                setattr(self, 'denselayer{}'.format(idx), layer)
        self.n_layers = n_layers
        self.if_gen = if_gen

    def __call__(self, x):
        h = x
        for idx in range(1, self.n_layers + 1):
            h = getattr(self, 'denselayer{}'.format(idx))(h)
        if not self.if_gen:
            h = F.leaky_relu(h)
        return h
# class Transition(chainer.Chain):
# def __init__(self, in_channel):
# super(Transition, self).__init__()
# with self.init_scope():
# initialW = initializers.HeNormal()
# self.conv1 = SNConvolution2D(in_channel, in_channel, 1, 1, 0, initialW=initialW)
# self.conv2 = SNConvolution2D(in_channel, in_channel, 2, 2, 0, initialW=initialW)
# def __call__(self, x):
# h = F.relu(x)
# h = self.conv1(h)
# h = self.conv2(h)
# return h
class PixelShuffler(chainer.Chain):
    """Depth-to-space rearrangement (sub-pixel convolution upsampling).

    Turns a ``(B, C*r*r, H, W)`` tensor into ``(B, C, H*r, W*r)`` by moving
    channel groups into spatial positions, as used for super-resolution
    upscaling.

    Args:
        r (int): upscaling factor. The input channel count must be
            divisible by ``r * r``.

    Raises:
        ValueError: if the input channel count is not divisible by ``r * r``.
    """

    def __init__(self, r):
        super(PixelShuffler, self).__init__()
        self.r = r

    def __call__(self, x):
        batch_size, in_channel, height, width = x.shape
        r = self.r
        # divmod keeps this in integer arithmetic; the original
        # int(in_channel / (r*r)) went through float division.
        out_channel, remainder = divmod(in_channel, r * r)
        if remainder:
            # Raise explicitly instead of assert, which is stripped under -O.
            raise ValueError(
                'in_channel ({}) must be divisible by r*r ({})'.format(
                    in_channel, r * r))
        h = x.reshape((batch_size, r, r, out_channel, height, width))
        # Interleave the r x r sub-pixel grid into the spatial dimensions.
        h = h.transpose((0, 3, 4, 1, 5, 2))
        return h.reshape((batch_size, out_channel, r * height, r * width))
class UpScale(chainer.Chain):
    """Upsample by factor ``r``: 3x3 conv, pixel shuffle, then ReLU.

    The conv expands channels (default 64 -> 256) so the pixel shuffler
    can fold them into an ``r``-times larger spatial grid.
    """

    def __init__(self, r, in_channel=64, out_channel=256):
        super(UpScale, self).__init__()
        self.r = r
        self.in_channel = in_channel
        self.out_channel = out_channel
        with self.init_scope():
            self.conv = L.Convolution2D(in_channel, out_channel, 3, 1, 1)
            self.pixel_shuffler = PixelShuffler(r)

    def __call__(self, x):
        h = self.conv(x)
        h = self.pixel_shuffler(h)
        return F.relu(h)
class Generator(chainer.Chain):
    """Super-resolution generator: dense blocks with a global skip, then 2x upscale.

    Pipeline: 9x9 input conv -> four DenseBlocks (each block's output is
    concatenated with its input) -> 3x3 conv back to 64 channels -> global
    residual add of the first feature map -> pixel-shuffle 2x upscale ->
    9x9 conv down to 3 output channels.
    """

    def __init__(self, n_layers=(4, 4, 4, 4), init_features=64, bn_size=4, growth_rate=12):
        super(Generator, self).__init__()
        with self.init_scope():
            w = initializers.HeNormal()
            self.in_conv = L.Convolution2D(None, init_features, 9, 1, 4, initialW=w)
            # Track channel count through the dense blocks. After each block
            # the features are concatenated with the block input, so:
            #   new = (old + n*growth_rate) + old = 2*old + n*growth_rate
            # (Convs use None in-channels, so this is bookkeeping only.)
            ch = init_features
            self.block1 = DenseBlock(n_layers[0], ch, bn_size, growth_rate, if_gen=True)
            ch = 2 * ch + n_layers[0] * growth_rate
            self.block2 = DenseBlock(n_layers[1], ch, bn_size, growth_rate, if_gen=True)
            ch = 2 * ch + n_layers[1] * growth_rate
            self.block3 = DenseBlock(n_layers[2], ch, bn_size, growth_rate, if_gen=True)
            ch = 2 * ch + n_layers[2] * growth_rate
            self.block4 = DenseBlock(n_layers[3], ch, bn_size, growth_rate, if_gen=True)
            ch = 2 * ch + n_layers[3] * growth_rate
            self.mid_conv = L.Convolution2D(None, 64, 3, 1, 1, initialW=w)
            self.up1 = UpScale(2)
            self.out_conv = L.Convolution2D(None, 3, 9, 1, 4, initialW=w)

    def __call__(self, x):
        first = F.relu(self.in_conv(x))
        h = first
        for block in (self.block1, self.block2, self.block3, self.block4):
            h = F.concat((block(h), h))
        # Global residual connection around the dense trunk.
        h = self.mid_conv(h) + first
        h = self.up1(h)
        return self.out_conv(h)