This repository has been archived by the owner on Nov 2, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathReinforceSampler.lua
70 lines (67 loc) · 2.66 KB
/
ReinforceSampler.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
--
-- Copyright (c) 2015, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Author: Marc'Aurelio Ranzato <[email protected]>
-- Sumit Chopra <[email protected]>
-- Michael Auli <[email protected]>
-- Wojciech Zaremba <[email protected]>
--
local ReinforceSampler, parent = torch.class('nn.ReinforceSampler',
'nn.Module')
-- Module that takes a tensor storing log-probabilities (output of a LogSoftMax)
-- and samples from the corresponding multinomial distribution, returning the
-- sampled indices as a batch x 1 tensor.
-- Assumption: this receives input from a LogSoftMax and receives gradients from
-- a ReinforceCriterion (see updateGradInput, which builds a sparse gradient
-- using the sampled indices stored in self.output).
--- Construct a sampler for the named distribution.
-- @param distribution string naming the distribution to sample from;
--   only 'multinomial' is handled by updateOutput/updateGradInput.
function ReinforceSampler:__init(distribution)
   parent.__init(self)
   -- scratch buffer that will hold exp(input); reused across forward calls
   self.prob = torch.Tensor()
   self.distribution = distribution
end
--- Sample one symbol per mini-batch row from the distribution encoded in input.
-- @param input batch x vocab tensor of log-probabilities (LogSoftMax output)
-- @return self.output, a batch x 1 tensor of sampled indices, cast to a
--   float (or CUDA) tensor so it can flow through downstream modules.
-- Raises an error for any distribution other than 'multinomial'.
function ReinforceSampler:updateOutput(input)
   if self.distribution == 'multinomial' then
      -- torch.multinomial expects probabilities, so exponentiate the log-probs
      -- into a reusable scratch buffer.
      self.prob:resizeAs(input)
      self.prob:copy(input)
      self.prob:exp()
      self.output:resize(input:size(1), 1)
      -- multinomial writes indices into a Long (or CudaLong) result tensor,
      -- so temporarily cast self.output to the matching integer type.
      if torch.typename(self.output):find('torch%.Cuda.*Tensor') then
         self.output = self.output:cudaLong()
      else
         self.output = self.output:long()
      end
      self.prob.multinomial(self.output, self.prob, 1)
      -- cast back so the sampled indices can be consumed by float/CUDA modules
      if torch.typename(self.output):find('torch%.Cuda.*Tensor') then
         self.output = self.output:cuda()
      else
         self.output = self.output:float()
      end
   else
      -- BUGFIX: error() interprets a second argument as the stack *level*,
      -- so the original two-argument call dropped the distribution name from
      -- the message; concatenate it into the message instead.
      error('we did not implement sampling from ' .. tostring(self.distribution))
   end
   return self.output -- batch x 1
end
-- NOTE: in order for this to work, it has to be connected
-- to a ReinforceCriterion.
--- Build the sparse gradient w.r.t. the input log-probabilities.
-- @param input batch x vocab tensor (same shape as updateOutput's input)
-- @param gradOutput batch x 1 tensor of per-sample REINFORCE gradients,
--   assumed to come from a ReinforceCriterion
-- @return self.gradInput, a batch x vocab tensor that is zero everywhere
--   except at each row's sampled index (taken from self.output), where it
--   holds that row's gradOutput value.
-- Raises an error for any distribution other than 'multinomial'.
function ReinforceSampler:updateGradInput(input, gradOutput)
   if self.distribution == 'multinomial' then
      -- loop over mini-batches and build sparse vector of gradients
      -- such that each sample has a vector of gradients that is all 0s
      -- except for the component corresponding to the chosen word.
      -- We assume that the gradients are provided by a ReinforceCriterion.
      self.gradInput:resizeAs(input)
      self.gradInput:zero()
      for ss = 1, self.gradInput:size(1) do
         -- adding round because sometimes multinomial returns a float 1e-6 far
         -- from an integer.
         self.gradInput[ss][torch.round(self.output[ss][1])] =
            gradOutput[ss][1]
      end
      return self.gradInput
   else
      -- BUGFIX: error() interprets a second argument as the stack *level*,
      -- so the original two-argument call dropped the distribution name from
      -- the message; concatenate it into the message instead.
      error('we did not implement sampling from ' .. tostring(self.distribution))
   end
end