-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathiterative_normalization.py
336 lines (303 loc) · 15.1 KB
/
iterative_normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""
Reference: Iterative Normalization: Beyond Standardization towards Efficient Whitening, CVPR 2019
- Paper:
- Code: https://github.com/huangleiBuaa/IterNorm
"""
import torch.nn
import torch.nn.functional as F
from torch.nn import Parameter
# import extension._bcnn as bcnn
# Public names exported by `from <this module> import *`; every entry must be
# defined in this module, otherwise the star-import raises AttributeError.
__all__ = ['iterative_normalization_py', 'IterNorm', 'IterNormRotation']
class iterative_normalization_py(torch.autograd.Function):
    """Group-wise whitening with Newton iterations (IterNorm) as a custom autograd Function.

    Positional arguments to ``apply``:
        X: input whose dim 1 holds the channels (e.g. N x C x H x W); C must be
            a multiple of ``nc``.
        running_mean: running-mean buffer, expected shape (C // nc, nc, 1);
            updated in place in training mode.
        running_wmat: running whitening-matrix buffer, expected shape
            (C // nc, nc, nc); updated in place in training mode.
        nc: number of channels per whitening group.
        T: number of Newton iterations used to approximate Sigma^{-1/2}.
        eps: value added to the diagonal of the covariance matrix.
        momentum: weight of the current batch in the running-statistics update.
        training: if True, whiten with mini-batch statistics and update the
            running buffers; otherwise whiten with the running buffers.

    Returns a tensor with the same shape as ``X``. Only ``X`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, *args, **kwargs):
        X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args
        ctx.is_training = training
        # change NxCxHxW to (G x D) x (NxHxW), i.e., g*d*m
        ctx.g = X.size(1) // nc
        x = X.transpose(0, 1).contiguous().view(ctx.g, nc, -1)
        _, d, m = x.size()
        saved = []
        if training:
            # centre the activations with the mini-batch mean
            mean = x.mean(-1, keepdim=True)
            xc = x - mean
            saved.append(xc)
            # covariance matrix: Sigma = eps * I + (1 / m) * xc @ xc^T
            P = [None] * (ctx.T + 1)
            P[0] = torch.eye(d).to(X).expand(ctx.g, d, d)
            Sigma = torch.baddbmm(P[0], xc, xc.transpose(1, 2), beta=eps, alpha=1. / m)
            # reciprocal of trace of Sigma: shape [g, 1, 1]
            rTr = (Sigma * P[0]).sum((1, 2), keepdim=True).reciprocal_()
            saved.append(rTr)
            # Sigma scaled by 1 / trace
            Sigma_N = Sigma * rTr
            saved.append(Sigma_N)
            # Newton iteration: P_{k+1} = 1.5 * P_k - 0.5 * P_k^3 @ Sigma_N
            for k in range(ctx.T):
                P[k + 1] = torch.baddbmm(P[k], torch.matrix_power(P[k], 3), Sigma_N, beta=1.5, alpha=-0.5)
            saved.extend(P)
            # whitening matrix ~ Sigma^{-1/2}. mul_ is in place, so the saved P[T]
            # is this scaled matrix as well (backward reads it as wm).
            wm = P[ctx.T].mul_(rTr.sqrt())
            running_mean.copy_(momentum * mean + (1. - momentum) * running_mean)
            running_wmat.copy_(momentum * wm + (1. - momentum) * running_wmat)
        else:
            xc = x - running_mean
            wm = running_wmat
            # backward needs the (constant) whitening matrix: dX = wm^T @ dXn
            saved.append(wm)
        xn = wm.matmul(xc)
        Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
        ctx.save_for_backward(*saved)
        return Xn

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad, = grad_outputs
        saved = ctx.saved_tensors
        if not ctx.is_training:
            # eval mode: Xn = wm @ (x - running_mean) with wm and running_mean fixed
            wm = saved[0]
            g_ = grad.transpose(0, 1).contiguous().view(ctx.g, wm.size(-1), -1)
            g_x = wm.transpose(-2, -1).matmul(g_)
        else:
            xc = saved[0]  # centered input
            rTr = saved[1]  # reciprocal of the trace of Sigma
            sn = saved[2].transpose(-2, -1)  # normalized Sigma (transposed)
            P = saved[3:]  # intermediate Newton matrices; P[T] is the whitening matrix
            m = xc.size(-1)
            g_ = grad.transpose(0, 1).contiguous().view_as(xc)
            g_wm = g_.matmul(xc.transpose(-2, -1))
            g_P = g_wm * rTr.sqrt()
            wm = P[ctx.T]
            g_sn = 0
            for k in range(ctx.T, 1, -1):
                # transposed *view* of P[k-1]: the saved tensor itself is not modified,
                # so a second backward pass (retain_graph=True) sees the same values
                Pt = P[k - 1].transpose(-2, -1)
                P2 = Pt.matmul(Pt)
                g_sn += P2.matmul(Pt).matmul(g_P)
                g_tmp = g_P.matmul(sn)
                g_P.baddbmm_(g_tmp, P2, beta=1.5, alpha=-0.5)
                g_P.baddbmm_(P2, g_tmp, beta=1, alpha=-0.5)
                g_P.baddbmm_(Pt.matmul(g_tmp), Pt, beta=1, alpha=-0.5)
            g_sn += g_P
            g_tr = ((-sn.matmul(g_sn) + g_wm.transpose(-2, -1).matmul(wm)) * P[0]).sum((1, 2), keepdim=True) * P[0]
            g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2. * g_tr) * (-0.5 / m * rTr)
            g_x = torch.baddbmm(wm.matmul(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc)
        grad_input = g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous()
        return grad_input, None, None, None, None, None, None, None
class IterNorm(torch.nn.Module):
    """Iterative Normalization layer: group-wise whitening of the channel dimension.

    Channels (dim 1) are split into ``num_groups`` groups of ``num_channels``
    channels. Each group is whitened by ``iterative_normalization_py`` with
    ``T`` Newton iterations. An optional per-channel affine transform
    (``weight``, ``bias``) is then applied.

    Args:
        num_features: number of input channels C.
        num_groups: requested number of groups; only used to derive
            ``num_channels`` when that is None (the final group count is
            ``num_features // num_channels``).
        num_channels: channels per group; halved until it divides ``num_features``.
        T: number of Newton iterations.
        dim: rank of the input; the affine parameters have shape [1, C, 1, ...].
        eps, momentum: forwarded to ``iterative_normalization_py``.
        affine: whether to learn per-channel ``weight`` and ``bias``.
    """

    def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True,
                 *args, **kwargs):
        super(IterNorm, self).__init__()
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim
        if num_channels is None:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        # shrink the group size until it divides the number of features
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features,
                                                                                                          num_groups)
        self.num_groups = num_groups
        self.num_channels = num_channels
        # affine parameters broadcast over every dim except the channel dim
        shape = [1] * dim
        shape[1] = self.num_features
        if self.affine:
            self.weight = Parameter(torch.Tensor(*shape))
            self.bias = Parameter(torch.Tensor(*shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1))
        # running whitening matrix, initialised to one identity per group.
        # clone() gives the expanded identity real storage: an expanded buffer
        # shares memory across groups, so in-place writes (copy_ in forward,
        # load_state_dict) would be rejected whenever num_groups > 1.
        self.register_buffer('running_wm',
                             torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone())
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the affine transform to the identity (weight=1, bias=0)."""
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def forward(self, X: torch.Tensor):
        """Whiten ``X`` group-wise, then apply the affine transform if enabled."""
        X_hat = iterative_normalization_py.apply(X, self.running_mean, self.running_wm, self.num_channels, self.T,
                                                 self.eps, self.momentum, self.training)
        if self.affine:
            return X_hat * self.weight + self.bias
        else:
            return X_hat

    def extra_repr(self):
        return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, ' \
               'momentum={momentum}, affine={affine}'.format(**self.__dict__)
class IterNormRotation(torch.nn.Module):
    """
    Concept Whitening Module

    The Whitening part is adapted from IterNorm. The core of CW module is learning
    an extra rotation matrix R that aligns target concepts with the output feature
    maps.

    Because the concept activation is calculated based on a feature map, which
    is a matrix, there are multiple ways to calculate the activation, denoted
    by activation_mode.

    Args:
        num_features: number of input channels C.
        num_groups: must be 1 (group whitening is not supported).
        num_channels: channels per group; defaults to ``num_features``.
        T: Newton iterations used by the whitening step.
        dim: must be 4 (N x C x H x W inputs).
        eps, momentum: whitening parameters; ``momentum`` is also the
            moving-average rate used when accumulating ``sum_G``.
        affine: apply the per-channel ``weight`` and ``bias`` after rotation.
        mode: index of the concept whose row of ``sum_G`` is accumulated during
            forward; a negative value disables accumulation.
        activation_mode: 'mean', 'max', 'pos_mean' or 'pool_max'; selects which
            spatial positions contribute to the accumulated gradient.
    """

    def __init__(self, num_features, num_groups=1, num_channels=None, T=10, dim=4, eps=1e-5, momentum=0.05, affine=False,
                 mode=-1, activation_mode='pool_max', *args, **kwargs):
        super(IterNormRotation, self).__init__()
        assert dim == 4, 'IterNormRotation does not support 2D'
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim
        self.mode = mode
        self.activation_mode = activation_mode

        assert num_groups == 1, 'Please keep num_groups = 1. Current version does not support group whitening.'
        if num_channels is None:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features,
                                                                                                          num_groups)
        self.num_groups = num_groups
        self.num_channels = num_channels
        shape = [1] * dim
        shape[1] = self.num_features
        # weight and bias are created unconditionally but only used when affine=True;
        # reset_parameters() initialises them in every case so they never hold
        # uninitialised memory.
        self.weight = Parameter(torch.Tensor(*shape))
        self.bias = Parameter(torch.Tensor(*shape))

        # pooling and unpooling used to locate local maxima for activation_mode='pool_max'
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=3, return_indices=True)
        self.maxunpool = torch.nn.MaxUnpool2d(kernel_size=3, stride=3)

        # running mean
        self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1))
        # running whitening matrix (clone() gives the expanded identity its own
        # storage; it is updated in place during training)
        self.register_buffer('running_wm',
                             torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone())
        # running rotation matrix
        self.register_buffer('running_rot',
                             torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone())
        # accumulated gradient, one row per concept; averaged by `counter` later
        self.register_buffer('sum_G', torch.zeros(num_groups, num_channels, num_channels))
        # number of gradient updates per concept; small positive start because it
        # is the divisor of sum_G in update_rotation_matrix
        self.register_buffer("counter", torch.ones(num_channels) * 0.001)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight to ones and bias to zeros (identity affine transform)."""
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def update_rotation_matrix(self):
        """
        Update the rotation matrix R using the accumulated gradient G.

        The update uses the Cayley transform so that R stays orthonormal. Two
        update steps are taken; for each, the step size tau is found by bisection
        until the sufficient-decrease (c1) and curvature (c2) conditions hold, or
        500 bisection steps elapse. Afterwards ``counter`` is reset to 0.001.
        Works on whatever device/dtype ``running_rot`` lives on.
        """
        size_R = self.running_rot.size()
        with torch.no_grad():
            G = self.sum_G / self.counter.reshape(-1, 1)
            R = self.running_rot.clone()
            # identity on the same device/dtype as R (was hard-coded to CUDA)
            I = torch.eye(size_R[2], device=R.device, dtype=R.dtype).expand(*size_R)
            for _ in range(2):
                tau = 1000  # learning rate in Cayley transform
                alpha = 0  # lower bound of the tau bisection
                beta = 100000000  # upper bound of the tau bisection
                c1 = 1e-4
                c2 = 0.9

                A = torch.einsum('gin,gjn->gij', G, R) - torch.einsum('gin,gjn->gij', R, G)  # GR^T - RG^T
                dF_0 = -0.5 * (A ** 2).sum()
                # binary search for appropriate learning rate
                cnt = 0
                while True:
                    Q = torch.bmm((I + 0.5 * tau * A).inverse(), I - 0.5 * tau * A)
                    Y_tau = torch.bmm(Q, R)
                    F_X = (G[:, :, :] * R[:, :, :]).sum()
                    F_Y_tau = (G[:, :, :] * Y_tau[:, :, :]).sum()
                    dF_tau = -torch.bmm(torch.einsum('gni,gnj->gij', G, (I + 0.5 * tau * A).inverse()),
                                        torch.bmm(A, 0.5 * (R + Y_tau)))[0, :, :].trace()
                    if F_Y_tau > F_X + c1 * tau * dF_0 + 1e-18:
                        beta = tau
                        tau = (beta + alpha) / 2
                    elif dF_tau + 1e-18 < c2 * dF_0:
                        alpha = tau
                        tau = (beta + alpha) / 2
                    else:
                        break
                    cnt += 1
                    if cnt > 500:
                        print("--------------------update fail------------------------")
                        print(F_Y_tau, F_X + c1 * tau * dF_0)
                        print(dF_tau, c2 * dF_0)
                        print("-------------------------------------------------------")
                        break
                print(tau, F_Y_tau)
                Q = torch.bmm((I + 0.5 * tau * A).inverse(), I - 0.5 * tau * A)
                R = torch.bmm(Q, R)

            self.running_rot = R
            self.counter.fill_(0.001)

    def _concept_gradient(self, X_hat, size_X, size_R):
        """Return the per-batch gradient row for the current concept, shape (G, C).

        The row is the negated batch mean of ``X_hat`` restricted to the spatial
        positions selected by ``activation_mode``. Returns None for an
        unrecognised mode, in which case nothing is accumulated.
        """
        activation_mode = self.activation_mode
        if activation_mode == 'mean':
            return -X_hat.mean((0, 3, 4))
        if activation_mode not in ('max', 'pos_mean', 'pool_max'):
            return None
        # rotated activations: X_test[:, g, d] = sum_c R[g, d, c] * X_hat[:, g, c]
        X_test = torch.einsum('bgchw,gdc->bgdhw', X_hat, self.running_rot)
        # NOTE(review): in every mode below the mask of rotated channel c is paired
        # elementwise with whitened channel c (not the mask of concept `mode`) --
        # confirm this is intended.
        if activation_mode == 'max':
            max_values = torch.max(torch.max(X_test, 3, keepdim=True)[0], 4, keepdim=True)[0]
            mask = (max_values == X_test).to(X_hat)
            return -((X_hat * mask).sum((3, 4)) / mask.sum((3, 4))).mean((0,))
        if activation_mode == 'pos_mean':
            mask = (X_test > 0).to(X_hat)
            return -((X_hat * mask).sum((3, 4)) / (mask.sum((3, 4)) + 0.0001)).mean((0,))
        # 'pool_max': positions that are maxima of their 3x3 pooling window
        X_test_nchw = X_test.view(size_X)
        maxpool_value, maxpool_indices = self.maxpool(X_test_nchw)
        X_test_unpool = self.maxunpool(maxpool_value, maxpool_indices, output_size=size_X).view(
            size_X[0], size_R[0], size_R[2], *size_X[2:])
        mask = (X_test == X_test_unpool).to(X_hat)
        return -((X_hat * mask).sum((3, 4)) / mask.sum((3, 4))).mean((0,))

    def forward(self, X: torch.Tensor):
        """Whiten ``X``, optionally accumulate the concept gradient, then rotate by R."""
        X_hat = iterative_normalization_py.apply(X, self.running_mean, self.running_wm, self.num_channels, self.T,
                                                 self.eps, self.momentum, self.training)
        # nchw
        size_X = X_hat.size()
        size_R = self.running_rot.size()
        # ngchw
        X_hat = X_hat.view(size_X[0], size_R[0], size_R[2], *size_X[2:])

        # Update the gradient matrix using the concept dataset. The gradient is
        # accumulated with momentum to stabilise training. When mode >= 0, row
        # `mode` of sum_G is updated; mode = -1 (e.g. while training the main
        # objective) leaves G untouched.
        with torch.no_grad():
            if self.mode >= 0:
                grad = self._concept_gradient(X_hat, size_X, size_R)
                if grad is not None:
                    self.sum_G[:, self.mode, :] = self.momentum * grad + (1. - self.momentum) * self.sum_G[:, self.mode, :]
                    self.counter[self.mode] += 1

        X_hat = torch.einsum('bgchw,gdc->bgdhw', X_hat, self.running_rot)
        X_hat = X_hat.view(*size_X)
        if self.affine:
            return X_hat * self.weight + self.bias
        else:
            return X_hat

    def extra_repr(self):
        return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, ' \
               'momentum={momentum}, affine={affine}'.format(**self.__dict__)
if __name__ == '__main__':
    # Smoke test: whiten + rotate a random batch, print the channel covariance of
    # the output, run a backward pass, then repeat the forward pass in eval mode.
    # num_groups must be 1: IterNormRotation asserts it (group whitening is unsupported).
    ItN = IterNormRotation(64, num_groups=1, T=10, momentum=1, affine=False)
    print(ItN)
    ItN.train()
    x = torch.randn(16, 64, 14, 14)
    x.requires_grad_()
    y = ItN(x)
    z = y.transpose(0, 1).contiguous().view(x.size(1), -1)
    print(z.matmul(z.t()) / z.size(1))
    y.sum().backward()
    print('x grad', x.grad.size())
    ItN.eval()
    y = ItN(x)
    z = y.transpose(0, 1).contiguous().view(x.size(1), -1)
    print(z.matmul(z.t()) / z.size(1))