diff --git a/model.py b/model.py index 158f6cc..695b556 100644 --- a/model.py +++ b/model.py @@ -39,7 +39,7 @@ def __init__( self.std = std self.sensor = modules.GlimpseNetwork(h_g, h_l, g, k, s, c) - self.rnn = modules.CoreNetwork(hidden_size, hidden_size) + self.rnn = modules.CoreNetwork(h_g + h_l, hidden_size) self.locator = modules.LocationNetwork(hidden_size, 2, std) self.classifier = modules.ActionNetwork(hidden_size, num_classes) self.baseliner = modules.BaselineNetwork(hidden_size, 1) diff --git a/modules.py b/modules.py index b0f9f61..19be8ce 100644 --- a/modules.py +++ b/modules.py @@ -1,3 +1,4 @@ +import math import torch import torch.nn as nn import torch.nn.functional as F @@ -84,7 +85,7 @@ def extract_patch(self, x, l, size): end = start + size # pad with zeros - x = F.pad(x, (size // 2, size // 2, size // 2, size // 2)) + x = F.pad(x, (math.ceil(size/2), math.ceil(size/2), math.ceil(size/2), math.ceil(size/2))) # loop through mini-batch and extract patches patch = [] diff --git a/trainer.py b/trainer.py index f94b27a..fa688d3 100644 --- a/trainer.py +++ b/trainer.py @@ -106,8 +106,8 @@ def __init__(self, config, data_loader): self.num_patches, self.glimpse_scale, self.num_channels, - self.loc_hidden, self.glimpse_hidden, + self.loc_hidden, self.std, self.hidden_size, self.num_classes,