-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutil.py
38 lines (33 loc) · 1.29 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torch.autograd import Variable
def get_batch(source, *targets, batch_size, seq_len=10, cuda=False, evalu=False):
    """Yield ``(X, ys)`` minibatches of length ``seq_len`` from the raw data.

    The 1-D ``source`` (and each tensor in ``targets``) is cut into
    ``batch_size`` contiguous streams arranged as columns of a
    ``(time, batch)`` matrix; the same random column permutation is applied
    to every tensor so inputs and targets stay aligned.

    Parameters
    ----------
    source : 1-D tensor of raw token/data indices.
    *targets : zero or more 1-D tensors aligned with ``source``.
    batch_size : number of parallel streams (keyword-only).
    seq_len : length of each yielded time slice.
    cuda : move the batched tensors to GPU when True.
    evalu : passed as ``volatile`` to the input Variable (legacy inference
        mode; a no-op warning on PyTorch >= 0.4).

    Yields
    ------
    X : Variable of shape ``(seq_len, batch_size)``.
    ys : list of Variables, one per entry in ``targets``, same shape as X.
    """
    n_cols = source.size(0) // batch_size
    # One shared permutation so source/target columns stay in lockstep.
    perm = torch.randperm(batch_size)

    def _batchify(seq):
        # Drop the tail that does not fill a whole column, reshape to
        # (time, batch), shuffle columns, and optionally move to GPU.
        seq = seq.narrow(0, 0, n_cols * batch_size)
        seq = seq.view(batch_size, -1).t().contiguous()[:, perm]
        return seq.cuda() if cuda else seq

    source = _batchify(source)
    batched_targets = [_batchify(t) for t in targets]

    # Walk the time axis in full seq_len strides; a trailing partial
    # window is discarded, matching the column trimming above.
    for start in range(0, (source.size(0) // seq_len) * seq_len, seq_len):
        stop = start + seq_len
        X = Variable(source[start:stop], volatile=evalu)
        ys = [Variable(t[start:stop]) for t in batched_targets]
        yield X, ys
def repackage_hidden(h):
    """Wrap hidden state in a new Variable, detached from its history.

    Recursively rewraps each tensor in (possibly nested) tuples of hidden
    states, returning new Variables that share storage with the originals
    but carry no autograd history — so truncated BPTT does not
    backpropagate through the whole preceding sequence.

    Parameters
    ----------
    h : Variable/Tensor, or an arbitrarily nested tuple of them.

    Returns
    -------
    Same structure as ``h`` with every leaf detached.
    """
    # Use isinstance, not `type(h) == Variable`: since PyTorch 0.4 a plain
    # Tensor satisfies isinstance(_, Variable) but its concrete type is
    # torch.Tensor, so the old equality check wrongly fell through to the
    # tuple branch and tried to iterate the tensor element-by-element.
    if isinstance(h, Variable):
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)