-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathfilter.py
80 lines (57 loc) · 2.34 KB
/
filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
def diff_x(input, r):
    """Turn cumulative sums along dim 2 into box-window sums of radius ``r``.

    ``input`` is expected to already hold a running sum (``cumsum``) along
    dimension 2.  Entry ``i`` of the result along that dimension is
    ``input[min(i + r, size - 1)] - input[i - r - 1]``, where the subtracted
    term is dropped when ``i - r - 1 < 0``.  That is the sum of the underlying
    values over the window ``[i - r, i + r]`` clipped to the tensor bounds.

    Args:
        input: 4-D tensor of cumulative sums along dim 2.
        r: Non-negative integer window radius.

    Returns:
        A tensor with the same shape as ``input``.

    Raises:
        AssertionError: If ``input`` is not 4-D.
        ValueError: If ``r`` is negative.
    """
    assert input.dim() == 4
    if r < 0:
        raise ValueError("r must be non-negative, got {}".format(r))

    size = input.size(2)
    if size >= 2 * r + 1:
        # Fast path: the three regions are disjoint and cover the dimension.
        # Leading rows have no lower bound to subtract.
        left = input[:, :, r:2 * r + 1]
        # Interior rows: upper cumulative sum minus the one just below the window.
        middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
        # Trailing rows: the window is cut off at the last row.
        right = input[:, :, -1:] - input[:, :, -2 * r - 1:-r - 1]
        return torch.cat([left, middle, right], dim=2)

    # The window is wider than the dimension, so the slices above would
    # overlap and produce an output of the wrong length.  Gather the clipped
    # window bounds explicitly instead.
    idx = torch.arange(size, device=input.device)
    upper = input.index_select(2, (idx + r).clamp(max=size - 1))
    lower = input.index_select(2, (idx - r - 1).clamp(min=0))
    # Rows whose window starts at or before index 0 subtract nothing.
    no_lower = (idx <= r).view(1, 1, -1, 1)
    return upper - lower.masked_fill(no_lower, 0)
def diff_y(input, r):
    """Turn cumulative sums along dim 3 into box-window sums of radius ``r``.

    ``input`` is expected to already hold a running sum (``cumsum``) along
    dimension 3.  Entry ``j`` of the result along that dimension is
    ``input[min(j + r, size - 1)] - input[j - r - 1]``, where the subtracted
    term is dropped when ``j - r - 1 < 0``.  That is the sum of the underlying
    values over the window ``[j - r, j + r]`` clipped to the tensor bounds.

    Args:
        input: 4-D tensor of cumulative sums along dim 3.
        r: Non-negative integer window radius.

    Returns:
        A tensor with the same shape as ``input``.

    Raises:
        AssertionError: If ``input`` is not 4-D.
        ValueError: If ``r`` is negative.
    """
    assert input.dim() == 4
    if r < 0:
        raise ValueError("r must be non-negative, got {}".format(r))

    size = input.size(3)
    if size >= 2 * r + 1:
        # Fast path: the three regions are disjoint and cover the dimension.
        # Leading columns have no lower bound to subtract.
        left = input[:, :, :, r:2 * r + 1]
        # Interior columns: upper cumulative sum minus the one just before the window.
        middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
        # Trailing columns: the window is cut off at the last column.
        right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1:-r - 1]
        return torch.cat([left, middle, right], dim=3)

    # The window is wider than the dimension, so the slices above would
    # overlap and produce an output of the wrong length.  Gather the clipped
    # window bounds explicitly instead.
    idx = torch.arange(size, device=input.device)
    upper = input.index_select(3, (idx + r).clamp(max=size - 1))
    lower = input.index_select(3, (idx - r - 1).clamp(min=0))
    # Columns whose window starts at or before index 0 subtract nothing.
    no_lower = (idx <= r).view(1, 1, 1, -1)
    return upper - lower.masked_fill(no_lower, 0)
class BoxFilter(nn.Module):
    """Unnormalized 2-D box sum of radius ``r`` built from cumulative sums.

    The filter is applied separably: a running sum plus ``diff_x`` along
    dim 2, then a running sum plus ``diff_y`` along dim 3.  For inputs whose
    last two dimensions are each at least ``2 * r + 1`` long, every output
    element is the sum of the input over the ``(2r + 1) x (2r + 1)`` window
    centred on it, clipped at the borders.  The result is a sum, not a mean.
    """

    def __init__(self, r):
        """Store the window radius ``r``."""
        super(BoxFilter, self).__init__()
        self.r = r

    def forward(self, x):
        """Return the box-summed version of the 4-D tensor ``x``."""
        assert x.dim() == 4
        radius = self.r
        # Pass over dim 2: cumulative sum, then window differences.
        summed_rows = diff_x(x.cumsum(dim=2), radius)
        # Pass over dim 3 on the result of the first pass.
        summed_both = diff_y(summed_rows.cumsum(dim=3), radius)
        return summed_both
class FastGuidedFilter(nn.Module):
    """Fit a local linear model on one image pair and apply it to a larger input.

    ``forward`` estimates per-pixel coefficients ``A`` and ``b`` such that
    ``lr_y ~= A * lr_x + b`` within each box window of radius ``r``, resizes
    the coefficients bilinearly to the spatial size of ``hr_x`` and returns
    ``A * hr_x + b``.

    NOTE(review): the argument names suggest ``lr_x``/``lr_y`` are a
    low-resolution guide/target pair and ``hr_x`` the high-resolution guide;
    confirm against callers.
    """

    def __init__(self, r, eps=1e-8):
        """Create the filter.

        Args:
            r: Box window radius, passed to ``BoxFilter``.
            eps: Value added to the local variance before dividing by it.
        """
        super(FastGuidedFilter, self).__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        """Apply the filter.

        Args:
            lr_x: 4-D tensor ``(n, c, h, w)`` with ``h`` and ``w`` both greater
                than ``2 * r + 1``.
            lr_y: 4-D tensor with the same batch and spatial size as ``lr_x``;
                its channel count must equal that of ``lr_x`` unless ``lr_x``
                has a single channel.
            hr_x: 4-D tensor with the same batch and channel count as ``lr_x``;
                its spatial size determines the output size.

        Returns:
            ``mean_A * hr_x + mean_b`` where ``mean_A``/``mean_b`` are the
            coefficients resized to the spatial size of ``hr_x``.

        Raises:
            AssertionError: If any of the shape constraints above is violated.
        """
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1

        # N: number of pixels in each (border-clipped) window.  new_ones keeps
        # the dtype and device of lr_x and replaces the deprecated
        # Variable(x.data.new().resize_(...).fill_(1.0)) construction.
        N = self.boxfilter(lr_x.new_ones((1, 1, h_lrx, w_lrx)))

        # Local means of x and y.
        mean_x = self.boxfilter(lr_x) / N
        mean_y = self.boxfilter(lr_y) / N
        # Local covariance of x and y, and local variance of x.
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        # Linear coefficients; eps is added to the variance in the denominator.
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x

        # Resize the coefficients to the spatial size of hr_x.
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A * hr_x + mean_b