RNN_Model.py
import torch.nn as nn
from Attention_Module import TemporalAttention, FrequentialAttention, Dual_Attention_1

class BasicRNN(nn.Module):
    """Plain stacked LSTM baseline (no attention)."""

    def __init__(self, input_size, hidden_size, num_layers):
        super(BasicRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True -> input and output are (batch, seq_len, features)
        self.rnn = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size,
                           num_layers=self.num_layers, batch_first=True)

    def forward(self, x):
        output, hidden = self.rnn(x)
        return output, hidden

class AttentionRNN(nn.Module):
    """Stack of single-layer LSTMs followed by an attention block.

    attention_type selects the attention module:
        'TA'  -> TemporalAttention
        'FA'  -> FrequentialAttention
        'DA1' -> Dual_Attention_1
        'DA2' -> TemporalAttention and FrequentialAttention combined
    """

    def __init__(self, input_size, hidden_size, seq_len, num_layers, attention_type):
        super(AttentionRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.attention_type = attention_type

        # Build the LSTM stack one layer at a time so each layer's output can be
        # passed on explicitly in forward().
        self.AtRNN = nn.ModuleDict()
        for i in range(self.num_layers):
            if i == 0:
                self.AtRNN.update({'RNN_%d' % (i + 1): nn.LSTM(input_size=self.input_size,
                                                               hidden_size=self.hidden_size,
                                                               num_layers=1, batch_first=True)})
            else:
                self.AtRNN.update({'RNN_%d' % (i + 1): nn.LSTM(input_size=self.hidden_size,
                                                               hidden_size=self.hidden_size,
                                                               num_layers=1, batch_first=True)})

        # Instantiate the attention module(s) requested by attention_type.
        if self.attention_type == 'TA':
            self.TA = TemporalAttention()
        elif self.attention_type == 'FA':
            self.FA = FrequentialAttention(sequential_length=self.seq_len)
        elif self.attention_type == 'DA1':
            self.DA = Dual_Attention_1()
        elif self.attention_type == 'DA2':
            self.TA = TemporalAttention()
            self.FA = FrequentialAttention(sequential_length=self.seq_len)
        self.sig = nn.Sigmoid()
    def forward(self, x):
        # Run the input through the LSTM stack layer by layer.
        for i in range(self.num_layers):
            if i == 0:
                output, hidden = self.AtRNN['RNN_%d' % (i + 1)](x)
            else:
                output, hidden = self.AtRNN['RNN_%d' % (i + 1)](output)
        # Apply the selected attention as a sigmoid-gated residual on the final output.
        if self.attention_type == 'TA':
            ta = self.TA(output)
            output = output + self.sig(ta)
        elif self.attention_type == 'FA':
            fa = self.FA(output)
            output = output + self.sig(fa)
        elif self.attention_type == 'DA1':
            da = self.DA(output)
            output = output + self.sig(da)
        elif self.attention_type == 'DA2':
            ta = self.TA(output)
            fa = self.FA(output)
            output = output + self.sig(ta + fa)
        return output, hidden
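

# --- Usage sketch (illustrative, not part of the original model code) ---
# A minimal smoke test showing the expected tensor shapes. The sizes below are
# arbitrary, and the AttentionRNN part assumes the modules imported from
# Attention_Module return a tensor broadcastable with the LSTM output of shape
# (batch, seq_len, hidden_size); their implementation is not shown in this file.
if __name__ == '__main__':
    import torch

    batch_size, seq_len, input_size, hidden_size = 8, 16, 32, 64
    x = torch.randn(batch_size, seq_len, input_size)

    # Baseline stacked LSTM: output is (batch, seq_len, hidden_size).
    base = BasicRNN(input_size=input_size, hidden_size=hidden_size, num_layers=2)
    out, (h_n, c_n) = base(x)
    print(out.shape)  # expected: torch.Size([8, 16, 64])

    # Attention variant with temporal attention ('TA').
    att = AttentionRNN(input_size=input_size, hidden_size=hidden_size,
                       seq_len=seq_len, num_layers=2, attention_type='TA')
    out, (h_n, c_n) = att(x)
    print(out.shape)  # expected: torch.Size([8, 16, 64])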