This repository has been archived by the owner on Sep 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 124
/
model.py
89 lines (74 loc) · 3.57 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch.nn as nn
import modules
class RecurrentAttention(nn.Module):
    """A Recurrent Model of Visual Attention (RAM) [1].

    RAM is a recurrent neural network that processes inputs
    sequentially, attending to different locations within the image
    one at a time, and incrementally combining information from these
    fixations to build up a dynamic internal representation of the
    image.

    References:
      [1]: Mnih et al., https://arxiv.org/abs/1406.6247
    """

    def __init__(
        self, g, k, s, c, h_g, h_l, std, hidden_size, num_classes,
    ):
        """Constructor.

        Args:
          g: size of the square patches in the glimpses extracted by the retina.
          k: number of patches to extract per glimpse.
          s: scaling factor that controls the size of successive patches.
          c: number of channels in each image.
          h_g: hidden layer size of the fc layer for `phi`.
          h_l: hidden layer size of the fc layer for `l`.
          std: standard deviation of the Gaussian policy.
          hidden_size: hidden size of the rnn.
          num_classes: number of classes in the dataset.
        """
        super().__init__()

        self.std = std

        # Glimpse sensor: encodes retina patches around the previous location.
        self.sensor = modules.GlimpseNetwork(h_g, h_l, g, k, s, c)
        # Core recurrent network that maintains the internal state h_t.
        self.rnn = modules.CoreNetwork(hidden_size, hidden_size)
        # Location (policy) network: samples the next glimpse location.
        self.locator = modules.LocationNetwork(hidden_size, 2, std)
        # Action network: classifies the image from the final hidden state.
        self.classifier = modules.ActionNetwork(hidden_size, num_classes)
        # Baseline network: scalar variance-reduction baseline for REINFORCE.
        self.baseliner = modules.BaselineNetwork(hidden_size, 1)

    def forward(self, x, l_t_prev, h_t_prev, last=False):
        """Run RAM for one timestep on a minibatch of images.

        Args:
          x: a 4D Tensor of shape (B, H, W, C). The minibatch
            of images.
          l_t_prev: a 2D tensor of shape (B, 2). The location vector
            containing the glimpse coordinates [x, y] for the previous
            timestep `t-1`.
          h_t_prev: a 2D tensor of shape (B, hidden_size). The hidden
            state vector for the previous timestep `t-1`.
          last: a bool indicating whether this is the last timestep.
            If True, the action network additionally returns the output
            log probability vector over the classes.

        Returns:
          h_t: a 2D tensor of shape (B, hidden_size). The hidden
            state vector for the current timestep `t`.
          l_t: a 2D tensor of shape (B, 2). The location vector
            containing the glimpse coordinates [x, y] for the
            current timestep `t`.
          b_t: a vector of length (B,). The baseline for the
            current time step `t`.
          log_probas: a 2D tensor of shape (B, num_classes). The
            output log probability vector over the classes.
            Only returned when `last` is True.
          log_pi: a vector of length (B,). The log-probability of the
            sampled location under the Gaussian policy.
        """
        g_t = self.sensor(x, l_t_prev)
        h_t = self.rnn(g_t, h_t_prev)
        log_pi, l_t = self.locator(h_t)
        # Squeeze only the trailing feature dim: a bare .squeeze() would
        # collapse a batch of size 1 from (1, 1) to a 0-dim scalar instead
        # of the documented (B,) vector.
        b_t = self.baseliner(h_t).squeeze(dim=-1)

        if last:
            log_probas = self.classifier(h_t)
            return h_t, l_t, b_t, log_probas, log_pi

        return h_t, l_t, b_t, log_pi