rbpn.py
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import *

from base_networks import *
from dbpns import Net as DBPNS


class Net(nn.Module):
    def __init__(self, num_channels, base_filter, feat, num_stages, n_resblock, nFrames, scale_factor):
        super(Net, self).__init__()
        # base_filter=256
        # feat=64
        self.nFrames = nFrames

        # Deconvolution hyper-parameters for the chosen upscaling factor
        if scale_factor == 2:
            kernel = 6
            stride = 2
            padding = 2
        elif scale_factor == 4:
            kernel = 8
            stride = 4
            padding = 2
        elif scale_factor == 8:
            kernel = 12
            stride = 8
            padding = 2

        # Initial feature extraction
        # feat0: features of the target frame alone
        # feat1: features of (target frame, neighbor frame, optical flow) concatenated
        #        along the channel axis (3 + 3 + 2 = 8 channels)
        self.feat0 = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='prelu', norm=None)
        self.feat1 = ConvBlock(8, base_filter, 3, 1, 1, activation='prelu', norm=None)

        ### DBPNS: single-image back-projection module for the target-frame features
        self.DBPN = DBPNS(base_filter, feat, num_stages, scale_factor)

        # Res-Block1: encodes a neighbor/flow feature map and upscales it to HR feature space
        modules_body1 = [
            ResnetBlock(base_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body1.append(DeconvBlock(base_filter, feat, kernel, stride, padding, activation='prelu', norm=None))
        self.res_feat1 = nn.Sequential(*modules_body1)

        # Res-Block2: refines the residual between the two HR feature estimates
        modules_body2 = [
            ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body2.append(ConvBlock(feat, feat, 3, 1, 1, activation='prelu', norm=None))
        self.res_feat2 = nn.Sequential(*modules_body2)

        # Res-Block3: projects the HR features back to LR feature space for the next iteration
        modules_body3 = [
            ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body3.append(ConvBlock(feat, base_filter, kernel, stride, padding, activation='prelu', norm=None))
        self.res_feat3 = nn.Sequential(*modules_body3)

        # Reconstruction: concatenates the HR features from all nFrames-1 projection steps
        self.output = ConvBlock((nFrames - 1) * feat, num_channels, 3, 1, 1, activation=None, norm=None)

        # Weight initialization
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif classname.find('ConvTranspose2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, neigbor, flow):
        # x: target LR frame; neigbor: list of nFrames-1 neighboring LR frames;
        # flow: list of nFrames-1 optical-flow maps (2 channels each)

        ### Initial feature extraction
        feat_input = self.feat0(x)
        feat_frame = []
        for j in range(len(neigbor)):
            feat_frame.append(self.feat1(torch.cat((x, neigbor[j], flow[j]), 1)))

        ### Projection: one back-projection step per neighboring frame
        Ht = []
        for j in range(len(neigbor)):
            h0 = self.DBPN(feat_input)          # HR features from the current target features
            h1 = self.res_feat1(feat_frame[j])  # HR features from the neighbor/flow pair
            e = h0 - h1                         # residual between the two HR estimates
            e = self.res_feat2(e)
            h = h0 + e                          # corrected HR features for this step
            Ht.append(h)
            feat_input = self.res_feat3(h)      # back-project to LR features for the next step

        ### Reconstruction
        out = torch.cat(Ht, 1)
        output = self.output(out)

        return output
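

# --- Illustrative usage sketch (assumption, not part of the original file) ---
# A minimal example of how this Net might be instantiated and called. It assumes
# base_networks.py and dbpns.py from this repo are importable, RGB inputs, and
# 4x upscaling; the hyper-parameter values below (base_filter=256, feat=64,
# num_stages=3, n_resblock=5, nFrames=7) are illustrative choices, not read
# from this file.
if __name__ == '__main__':
    nFrames, scale = 7, 4
    model = Net(num_channels=3, base_filter=256, feat=64, num_stages=3,
                n_resblock=5, nFrames=nFrames, scale_factor=scale)

    # One target LR frame, plus (nFrames - 1) neighbor frames and flow maps.
    x = torch.randn(1, 3, 32, 32)                                       # target LR frame
    neigbor = [torch.randn(1, 3, 32, 32) for _ in range(nFrames - 1)]   # neighboring LR frames
    flow = [torch.randn(1, 2, 32, 32) for _ in range(nFrames - 1)]      # optical flow (dx, dy)

    with torch.no_grad():
        sr = model(x, neigbor, flow)
    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128]) for 4x upscaling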