This repository has been archived by the owner on Sep 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 124
/
trainer.py
493 lines (400 loc) · 16.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
import os
import time
import shutil
import pickle
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboard_logger import configure, log_value
from model import RecurrentAttention
from utils import AverageMeter
class Trainer:
    """A Recurrent Attention Model (RAM) trainer.

    All hyperparameters are provided by the user in the config object.
    The trainer owns the RAM model, an Adam optimizer and a
    ReduceLROnPlateau scheduler, and implements the train / validate /
    test loops together with checkpoint saving and loading.
    """

    def __init__(self, config, data_loader):
        """Construct a new Trainer instance.

        Args:
            config: object containing command line arguments; every
                hyperparameter below is read from one of its attributes.
            data_loader: when ``config.is_train`` is true, a pair
                ``(train_loader, valid_loader)`` whose samplers expose an
                ``indices`` attribute; otherwise a single test loader whose
                ``dataset`` supports ``len()``.
        """
        self.config = config

        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        # number of stochastic copies averaged at validation/test time
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        # Previously hard-coded; the old values remain the defaults when the
        # config does not provide them.
        self.num_classes = getattr(config, "num_classes", 10)
        self.num_channels = getattr(config, "num_channels", 1)

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        self.plot_dir = "./plots/" + self.model_name + "/"
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join works whether or not logs_dir ends with a slash
            tensorboard_dir = os.path.join(self.logs_dir, self.model_name)
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.init_lr
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, "min", patience=self.lr_patience
        )

    def reset(self):
        """Return a fresh ``(h_t, l_t)`` pair for a batch of ``self.batch_size``.

        ``h_t`` is a zero hidden state of shape (batch_size, hidden_size);
        ``l_t`` is an initial location of shape (batch_size, 2) drawn
        uniformly from [-1, 1]. Both are created on ``self.device``.
        """
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device)
        l_t.requires_grad = True
        return h_t, l_t

    def _rollout(self, x):
        """Run the full glimpse sequence on ``x`` from a freshly reset state.

        Sets ``self.batch_size`` to ``x.shape[0]`` (``reset`` reads it).

        Returns:
            log_probas: classifier output of the final glimpse step.
            baselines: per-step baselines stacked to (batch, num_glimpses),
                assuming the model returns one baseline per sample.
            log_pi: per-step location log-probs, stacked like ``baselines``.
            locs: list of the first (up to) 9 location vectors per step,
                kept for plotting.
        """
        self.batch_size = x.shape[0]
        h_t, l_t = self.reset()

        locs, log_pi, baselines = [], [], []
        for _ in range(self.num_glimpses - 1):
            h_t, l_t, b_t, p = self.model(x, l_t, h_t)
            locs.append(l_t[0:9])
            baselines.append(b_t)
            log_pi.append(p)

        # the last step additionally returns the classifier output
        h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
        log_pi.append(p)
        baselines.append(b_t)
        locs.append(l_t[0:9])

        # stack along time, then put the batch dimension first
        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, baselines, log_pi, locs

    def _hybrid_loss(self, log_probas, baselines, log_pi, y):
        """Compute the RAM hybrid loss and batch accuracy (in percent).

        The loss is NLL on the classifier output, plus MSE between the
        baselines and the 0/1 reward, plus 0.01 times the REINFORCE term
        (summed over timesteps, averaged over the batch).
        """
        predicted = torch.max(log_probas, 1)[1]
        # reward is 1 for a correct final prediction, repeated at every step
        R = (predicted.detach() == y).float()
        R = R.unsqueeze(1).repeat(1, self.num_glimpses)

        # losses for differentiable modules
        loss_action = F.nll_loss(log_probas, y)
        loss_baseline = F.mse_loss(baselines, R)

        # REINFORCE with the baseline as a variance-reducing control variate;
        # baselines are detached so this term does not train the baseliner
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        loss = loss_action + loss_baseline + loss_reinforce * 0.01

        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))
        return loss, acc

    def _dump_glimpses(self, imgs, locs, epoch):
        """Pickle sample images and glimpse locations for ``epoch`` into plot_dir."""
        imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
        locs = [l.detach().cpu().numpy() for l in locs]
        # context managers guarantee the files are flushed and closed
        with open(os.path.join(self.plot_dir, "g_{}.p".format(epoch + 1)), "wb") as f:
            pickle.dump(imgs, f)
        with open(os.path.join(self.plot_dir, "l_{}.p".format(epoch + 1)), "wb") as f:
            pickle.dump(locs, f)

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch, and if the
        validation accuracy improved, a separate "best" checkpoint is
        created for use on the test set. Training stops early once the
        validation accuracy has failed to improve for more than
        ``train_patience`` consecutive epochs.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples".format(
                self.num_train, self.num_valid
            )
        )

        for epoch in range(self.start_epoch, self.epochs):
            print(
                "\nEpoch: {}/{} - LR: {:.6f}".format(
                    epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]
                )
            )

            train_loss, train_acc = self.train_one_epoch(epoch)
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation *accuracy* plateaus: the scheduler runs
            # in "min" mode, so it is fed the negated accuracy
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(
                    train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc
                )
            )

            # early stopping on lack of validation improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return

            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns:
            (average loss, average accuracy in percent) over the epoch.
        """
        self.model.train()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)

                # glimpses are dumped for the first batch of every
                # plot_freq-th epoch only
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                log_probas, baselines, log_pi, locs = self._rollout(x)
                loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                loss.backward()
                self.optimizer.step()

                # elapsed time since the start of the epoch (tic is not reset)
                toc = time.time()
                pbar.set_description(
                    "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item()
                    )
                )
                pbar.update(self.batch_size)

                if plot:
                    self._dump_glimpses([x[0:9]], locs, epoch)

                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.

        Each batch is replicated ``M`` times; the classifier outputs,
        baselines and location log-probs are averaged over the copies
        before computing the loss and accuracy.

        Returns:
            (average loss, average accuracy in percent) over the set.
        """
        self.model.eval()
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times (copies are stacked along the batch axis)
            x = x.repeat(self.M, 1, 1, 1)

            log_probas, baselines, log_pi, _ = self._rollout(x)

            # average over the M copies
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)
            baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)
            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            loss, acc = self._hybrid_loss(log_probas, baselines, log_pi, y)

            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model and print the accuracy.

        This function should only be called at the very end once the
        model has finished training. It loads the best checkpoint when
        ``config.best`` is set, otherwise the most recent one.
        """
        correct = 0

        self.load_checkpoint(best=self.best)
        self.model.eval()

        for x, y in self.test_loader:
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times and average the class log-probs over copies
            x = x.repeat(self.M, 1, 1, 1)
            log_probas, _, _, _ = self._rollout(x)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.max(1, keepdim=True)[1]
            correct += int(pred.eq(y.view_as(pred)).sum().item())

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print(
            "[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
                correct, self.num_test, perc, error
            )
        )

    def save_checkpoint(self, state, is_best):
        """Save a checkpoint of the model.

        If this model has reached the best validation accuracy thus far,
        a separate file with the suffix `best` is also created. The
        checkpoint directory is created if it does not exist yet.
        """
        os.makedirs(self.ckpt_dir, exist_ok=True)
        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load a saved copy of the model and optimizer.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Args:
            best: if True, load the best-validation checkpoint (use this to
                evaluate on the test data); otherwise load the most recent
                checkpoint.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt["epoch"], ckpt["best_valid_acc"]
                )
            )
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))