Caveat in last commit #12
Won't […]

This is my understanding: we want the weights of the location network to be trained using REINFORCE. Now the hidden state vector […] Doing the above, we've made it possible to backpropagate through […]
Both […]
Oh. I think you know what I meant.
I think we shouldn't flow the information through […]
@ipod825 I need to think about it some more. Empirically, I haven't seen a performance difference between the two. I still reach ~1.3-1.4% error in about 30 epochs of training. What's bugging me right now is that I learned about the reparametrization trick this weekend, which essentially makes it possible to backprop through a sampled variable. So right now, I'm confused as to why we even need REINFORCE to train our network. We could just use the reparametrization trick, as in VAEs, to make the whole process differentiable and directly optimize the weights of the location network. I'll give it some more thought tonight.
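A minimal plain-Python sketch of the two gradient estimators being compared here, assuming a toy differentiable objective f(z) = z^2 in place of the reward (the objective, mu, sigma, the sample count, and the helper names are all hypothetical). Both estimate d/dmu E[f(z)] for z ~ N(mu, sigma^2), which is exactly 2 * mu for this f:

```python
import random

def f(z):
    # hypothetical toy objective standing in for the reward
    return z * z

def df_dz(z):
    # derivative of the toy objective
    return 2.0 * z

def pathwise_grad(mu, sigma, n, rng):
    # reparametrization: z = mu + sigma * eps with eps ~ N(0, 1),
    # so dz/dmu = 1 and d f(z) / d mu = f'(z) for a fixed eps
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        total += df_dz(mu + sigma * eps)
    return total / n

def score_function_grad(mu, sigma, n, rng):
    # REINFORCE: f(z) * d log N(z; mu, sigma^2) / d mu = f(z) * (z - mu) / sigma^2
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu, sigma)
        total += f(z) * (z - mu) / sigma ** 2
    return total / n

mu, sigma = 0.5, 0.2
rng = random.Random(0)
pw = pathwise_grad(mu, sigma, 200_000, rng)
sf = score_function_grad(mu, sigma, 200_000, rng)
print(pw, sf)  # both approximate the exact gradient 2 * mu = 1.0
```

The pathwise (reparametrization) estimator needs f'(z). If the reward is a 0/1 indicator of a correct classification, that derivative is zero almost everywhere and carries no signal, while the score-function (REINFORCE) estimator only needs the value f(z).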
The performance issue might not be related to this formula issue at all. If you check this […]
Also, I don't think the reparametrization trick applies to this scenario.
recurrent-visual-attention/modules.py Line 350 in 99c4cbe
This line is related to this issue. You shouldn't apply tanh to l_t again. Say the output of self.fc is 100: then mu = tanh(100) = 1.0, and even after adding noise, tanh(l_t) ≈ tanh(1.0) = 0.76159.
A better idea is to use […]
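A small plain-Python check of the double-squashing effect described above, using math.tanh in place of the torch ops (the helper name squash_twice and the noise values are hypothetical):

```python
import math

def squash_twice(pre_activation, noise):
    # mirrors the pattern under discussion:
    # mu = tanh(fc(h_t)); l_t = tanh(mu + noise)
    mu = math.tanh(pre_activation)
    return math.tanh(mu + noise)

# Even a huge pre-activation only reaches tanh(1.0) when the noise is zero.
print(squash_twice(100.0, 0.0))  # about 0.7616

# Noise needed for the final location to reach 0.99 when mu is saturated at 1.0
print(math.atanh(0.99) - 1.0)  # roughly 1.65
```

With a small noise std, draws that large are rare, so the sampled locations rarely get near the image border at ±1.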
@ipod825 The normal distribution has unbounded support (its PDF is nonzero on the whole real line), so it is not guaranteed that […] I was against using […]
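To make the point concrete: a Gaussian sample l_t = mu + noise can land outside [-1, 1] even when mu itself is inside. A plain-Python sketch of that probability via the normal CDF (mu = 0.9 and std = 0.17 are hypothetical values, and the helper name is illustrative):

```python
import math

def prob_outside_unit_interval(mu, std):
    # P(|mu + noise| > 1) for noise ~ N(0, std^2), using the normal CDF
    def cdf(x):
        return 0.5 * (1.0 + math.erf((x - mu) / (std * math.sqrt(2.0))))
    return cdf(-1.0) + (1.0 - cdf(1.0))

# hypothetical values: a mean close to the border and a small noise std
p = prob_outside_unit_interval(0.9, 0.17)
print(p)  # roughly 0.28: over a quarter of the samples leave [-1, 1]
```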
```python
import torch

# tanh-based bounding
mu = torch.tanh(self.fc(h_t.detach()))
# reparametrization trick
noise = torch.zeros_like(mu)
noise.data.normal_(std=self.std)
l_t = mu + noise
# bound between [-1, 1]
l_t = torch.tanh(l_t)
```

```python
import torch

# clamp-based bounding (torch.clamp; torch.nn.functional has no clamp)
mu = torch.clamp(self.fc(h_t.detach()), -1, 1)
# reparametrization trick
noise = torch.zeros_like(mu)
noise.data.normal_(std=self.std)
l_t = mu + noise
# bound between [-1, 1]
l_t = torch.clamp(l_t, -1, 1)
```

And do not detach the […]

```python
from torch.distributions import Normal

log_pi = Normal(mu, self.std).log_prob(l_t)
```

You can check that the gradient in the location network is actually 0, as predicted by the discussion above. But if you use […]
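A plain-Python sketch of the clamp-based sampling, to compare with the double tanh: with a saturated pre-activation, l_t can now actually reach the border. The scalar clamp helper, the pre-activation of 100.0, and the std of 0.17 are hypothetical stand-ins for the torch code:

```python
import random

def clamp(x, lo=-1.0, hi=1.0):
    # scalar equivalent of torch.clamp(x, lo, hi)
    return max(lo, min(hi, x))

def sample_location(pre_activation, std, rng):
    # clamp-based variant: mu = clamp(fc(h_t)); l_t = clamp(mu + noise)
    mu = clamp(pre_activation)
    l_t = clamp(mu + rng.gauss(0.0, std))
    return mu, l_t

rng = random.Random(0)
# hypothetical saturated pre-activation (100.0) and noise std (0.17)
samples = [sample_location(100.0, 0.17, rng)[1] for _ in range(1000)]
print(min(samples), max(samples))
print(sum(s == 1.0 for s in samples))  # roughly half the draws sit exactly on the border
```

The trade-off is that a clamped l_t no longer changes with mu, so any gradient that would flow through l_t is zero for those samples.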
@ipod825 Have you tried your implementation using clamp and l_t.detach()? I tried that and got very good performance in the 6-glimpse, 8×8, 1-scale setting: around 0.58% error. The paper reported 1.12%.
I never got an error lower than 1%. If you used only a vanilla RNN (as already implemented by @kevinzakka), that would be an interesting result. If you consistently get similar results, it would be nice if you could share your code and let others figure out why it works so well.
Why should we not train the weights of the RNN with REINFORCE?
99c4cbe#diff-40d9c2c37e955447b1175a32afab171fL353
This is not an unnecessary detach.
As it is used in

```python
log_pi = Normal(mu, self.std).log_prob(l_t)
```

which is then used in

```python
loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
```

this means that when minimizing the REINFORCE loss, you are altering your location network through both mu and l_t (and yes, log_pi is differentiable w.r.t. both mu and l_t). However, l_t is just mu + noise, and we only want the gradient to flow through mu.
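Spelling out why only the mu path should carry gradient: log N(l; mu, std) depends on l and mu only through l - mu, so d log_pi / d l_t = -(d log_pi / d mu). Assuming l_t = mu + noise with no squashing in between (or a sample inside the clamp range), the two paths cancel exactly when l_t is not detached. A plain-Python finite-difference sketch (the helper names and the values of mu, eps, and std are hypothetical):

```python
import math

def log_normal(x, mu, std):
    # log N(x; mu, std^2)
    return -((x - mu) ** 2) / (2 * std ** 2) - math.log(std * math.sqrt(2 * math.pi))

def grad_mu(mu, eps, std, detach, h=1e-6):
    # central finite difference of log_pi w.r.t. mu
    # detach=True: l_t is a constant (sampled once at the current mu)
    # detach=False: l_t = mu + eps moves together with mu
    l_fixed = mu + eps
    def log_pi(m):
        l_t = l_fixed if detach else m + eps
        return log_normal(l_t, m, std)
    return (log_pi(mu + h) - log_pi(mu - h)) / (2 * h)

mu, eps, std = 0.3, 0.1, 0.17
print(grad_mu(mu, eps, std, detach=False))  # ~0: the two paths cancel
print(grad_mu(mu, eps, std, detach=True))   # ~eps / std**2, the REINFORCE score term
```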