-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialConvolutionLocal.lua
96 lines (83 loc) · 3.98 KB
/
SpatialConvolutionLocal.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
-- Handle to the raw cuda-convnet2 C kernel bindings (populated by the ccn2
-- package elsewhere; e.g. C['localFilterActs'] below).
local C = ccn2.C
-- Locally-connected 2D "convolution" with untied weights: every output
-- location owns its own filter (see the outputSize factor in the weight
-- shape in __init).  Inputs are square and use the cuda-convnet layout
-- (nInputPlane, height, width, batch), as evidenced by the view/size calls
-- in updateOutput/updateGradInput.
local SpatialConvolutionLocal, parent = torch.class('ccn2.SpatialConvolutionLocal', 'nn.Module')
--- Construct a locally-connected layer backed by the cuda-convnet2 kernels.
-- @param nInputPlane  number of input channels; the kernels require 1, 2, 3
--                     or a multiple of 4
-- @param nOutputPlane number of output channels; must be a multiple of 16
-- @param iH           input side length (inputs are square, iH x iH)
-- @param kH           filter side length (filters are square, kH x kH)
-- @param dH           stride (default 1)
-- @param padding      zero padding on each side (default 0)
function SpatialConvolutionLocal:__init(nInputPlane, nOutputPlane, iH, kH, dH, padding)
   parent.__init(self)

   -- Optional arguments: unit stride, no padding.
   dH = dH or 1
   padding = padding or 0

   -- Enforce the channel-count restrictions of the cuda-convnet2 kernels.
   local inputPlanesOk = nInputPlane >= 1 and (nInputPlane <= 3 or math.fmod(nInputPlane, 4) == 0)
   if not inputPlanesOk then
      error('Assertion failed: [(nInputPlane >= 1 and (nInputPlane <= 3 or math.fmod(nInputPlane, 4)))]. Number of input channels has to be 1, 2, 3 or a multiple of 4')
   end
   local outputPlanesOk = math.fmod(nOutputPlane, 16) == 0
   if not outputPlanesOk then
      error('Assertion failed: [math.fmod(nOutputPlane, 16) == 0]. Number of output planes has to be a multiple of 16.')
   end

   self.nInputPlane = nInputPlane
   self.nOutputPlane = nOutputPlane
   self.kH = kH
   self.dH = dH
   self.padding = padding

   -- Output side length for a square iH x iH input.
   self.oH = math.ceil((self.padding * 2 + iH - self.kH) / self.dH + 1)

   -- Untied weights: a distinct kH*kH*nInputPlane filter per output pixel,
   -- hence the oH*oH leading factor on the weight rows and on the bias.
   local nOutputPixels = self.oH * self.oH
   local nFilterElems = self.kH * self.kH
   self.weight = torch.Tensor(nOutputPixels * nInputPlane * nFilterElems, nOutputPlane)
   self.gradWeight = torch.Tensor(nOutputPixels * nInputPlane * nFilterElems, nOutputPlane)
   self.bias = torch.Tensor(nOutputPixels * nOutputPlane)
   self.gradBias = torch.Tensor(nOutputPixels * nOutputPlane)
   self.gradInput = torch.Tensor()
   self.output = torch.Tensor()

   self:reset()
end
--- Re-initialize weights and biases uniformly in [-stdv, stdv].
-- When stdv is omitted the fan-in based default 1/sqrt(kH*kH*nInputPlane)
-- is used; a caller-supplied stdv is scaled by sqrt(3), matching the
-- convention of nn modules.
-- @param stdv optional spread for the uniform initialization
function SpatialConvolutionLocal:reset(stdv)
   local spread
   if stdv then
      spread = stdv * math.sqrt(3)
   else
      spread = 1 / math.sqrt(self.kH * self.kH * self.nInputPlane)
   end
   self.weight:uniform(-spread, spread)
   self.bias:uniform(-spread, spread)
end
--- Forward pass: locally-connected filtering plus per-location bias.
-- input layout is (nInputPlane, iH, iH, nBatch); the result is stored in
-- self.output with layout (nOutputPlane, oH, oH, nBatch) and also returned.
function SpatialConvolutionLocal:updateOutput(input)
-- Validate tensor type and layout expected by the ccn2 kernels.
ccn2.typecheck(input)
ccn2.inputcheck(input)
local nBatch = input:size(4)
-- Recompute the output side from the actual input height, so inputs whose
-- size differs from the iH given at construction still produce a consistent
-- view below (same formula as self.oH in __init).
local oH = math.ceil((self.padding * 2 + input:size(2) - self.kH) / self.dH + 1)
-- Flatten (planes, H, W) into one leading dimension; the C kernel takes a
-- 2D (pixels, batch) matrix.
local inputC = input:view(input:size(1) * input:size(2) * input:size(3),
input:size(4))
-- do convolution
-- NOTE(review): padding is passed negated — presumably the cuda-convnet2
-- convention for zero padding; confirm against the kernel signature.
C['localFilterActs'](cutorch.getState(), inputC:cdata(), self.weight:cdata(), self.output:cdata(),
input:size(2), oH, oH, -self.padding, self.dH, self.nInputPlane, 1);
-- add bias
-- View as (nOutputPlane, oH*oH*nBatch) so addBias can broadcast the
-- per-location bias, then restore the 4D output layout.
self.output = self.output:view(self.nOutputPlane, oH*oH*nBatch)
C['addBias'](cutorch.getState(), self.output:cdata(), self.bias:cdata());
self.output = self.output:view(self.nOutputPlane, oH, oH, nBatch)
return self.output
end
--- Backward pass w.r.t. the input.
-- gradOutput layout is (nOutputPlane, oH, oH, nBatch); the result is stored
-- in self.gradInput with layout (nInputPlane, iH, iH, nBatch) and returned.
function SpatialConvolutionLocal:updateGradInput(input, gradOutput)
ccn2.typecheck(input); ccn2.typecheck(gradOutput);
ccn2.inputcheck(input); ccn2.inputcheck(gradOutput);
-- Spatial sides are read off the tensors themselves (square assumption).
local oH = gradOutput:size(2)
local iH = input:size(2)
local nBatch = input:size(4)
-- The C kernel writes into a 2D (pixels, batch) matrix.
self.gradInput:resize(self.nInputPlane*iH*iH, nBatch);
-- Flatten gradOutput's (planes, H, W) dims to match the kernel's 2D input.
local gradOutputC = gradOutput:view(
gradOutput:size(1) * gradOutput:size(2) * gradOutput:size(3), gradOutput:size(4)
)
-- NOTE(review): padding negated as in updateOutput — presumably the
-- cuda-convnet2 convention; confirm against the kernel signature.
C['localImgActs'](cutorch.getState(), gradOutputC:cdata(), self.weight:cdata(), self.gradInput:cdata(),
iH, iH, oH, -self.padding, self.dH, self.nInputPlane, 1);
-- Restore the 4D input layout before handing the gradient back.
self.gradInput = self.gradInput:view(self.nInputPlane, iH, iH, nBatch)
return self.gradInput
end
--- Accumulate parameter gradients (weights and biases), scaled by `scale`.
-- Results accumulate into self.gradWeight and self.gradBias.
function SpatialConvolutionLocal:accGradParameters(input, gradOutput, scale)
scale = scale or 1
ccn2.typecheck(input); ccn2.typecheck(gradOutput);
ccn2.inputcheck(input); ccn2.inputcheck(gradOutput);
local oH = gradOutput:size(2);
local iH = input:size(2)
local nBatch = input:size(4)
-- Flatten both tensors' (planes, H, W) dims: the C kernels operate on 2D
-- (pixels, batch) matrices.
local inputC = input:view(input:size(1) * input:size(2) * input:size(3), input:size(4))
local gradOutputC = gradOutput:view(gradOutput:size(1) * gradOutput:size(2) * gradOutput:size(3), gradOutput:size(4))
-- NOTE(review): the trailing `1, 0, scale` arguments are kernel-specific
-- (the last is the accumulation scale); confirm the `1` and `0` against the
-- localWeightActsSt signature before changing anything here.
C['localWeightActsSt'](cutorch.getState(), inputC:cdata(), gradOutputC:cdata(), self.gradWeight:cdata(),
iH, oH, oH, self.kH,
-self.padding, self.dH, self.nInputPlane, 1, 0, scale);
-- Bias gradient: sum gradOutput over the batch per (plane, location), using
-- a (nOutputPlane, oH*oH*nBatch) view to match gradBias's flat layout.
gradOutputC = gradOutput:view(self.nOutputPlane, oH * oH * nBatch)
C['gradBias'](cutorch.getState(), gradOutputC:cdata(), self.gradBias:cdata(), scale);
end