Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a positive bias to the LSTM forget gate #840

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions blocks/bricks/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from blocks.bricks import Initializable, Logistic, Tanh, Linear
from blocks.bricks.base import Application, application, Brick, lazy
from blocks.initialization import NdarrayInitialization
from blocks.roles import add_role, WEIGHT, INITIAL_STATE
from blocks.roles import add_role, WEIGHT, BIAS, INITIAL_STATE
from blocks.utils import (pack, shared_floatx_nans, shared_floatx_zeros,
dict_union, dict_subset, is_shared_variable)
from blocks.bricks.parallel import Fork
Expand Down Expand Up @@ -354,6 +354,9 @@ class LSTM(BaseRecurrent, Initializable):
networks*, arXiv preprint arXiv:1308.0850 (2013).
.. [HS97] Sepp Hochreiter, and Jürgen Schmidhuber, *Long Short-Term
Memory*, Neural Computation 9(8) (1997), pp. 1735-1780.
.. [Jozefowicz15] Jozefowicz R., Zaremba W. and Sutskever I., *An
Empirical Exploration of Recurrent Network Architectures*, Journal
of Machine Learning Research 37 (2015).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sphynx doesn't like citations which were not referenced from somewhere. And actually, I see no reason of mentioning this paper, we cannot mention everyone who worked with LSTMs.

Can you whether add some documentation concerning this paper or delete the reference?


Parameters
----------
Expand Down Expand Up @@ -412,10 +415,20 @@ def _allocate(self):
self.W_state, self.W_cell_to_in, self.W_cell_to_forget,
self.W_cell_to_out, self.initial_state_, self.initial_cells]

if self.use_bias:
self.b_cell_to_forget = shared_floatx_nans((self.dim,),
name='b_cell_to_forget')
add_role(self.b_cell_to_forget, BIAS)
self.parameters.append(self.b_cell_to_forget)

def _initialize(self):
for weights in self.parameters[:4]:
self.weights_init.initialize(weights, self.rng)

if self.use_bias:
for biases in self.parameters[-1:]:
self.biases_init.initialize(biases, self.rng)

@recurrent(sequences=['inputs', 'mask'], states=['states', 'cells'],
contexts=[], outputs=['states', 'cells'])
def apply(self, inputs, states, cells, mask=None):
Expand Down Expand Up @@ -459,7 +472,8 @@ def slice_last(x, no):
in_gate = tensor.nnet.sigmoid(slice_last(activation, 0) +
cells * self.W_cell_to_in)
forget_gate = tensor.nnet.sigmoid(slice_last(activation, 1) +
cells * self.W_cell_to_forget)
cells * self.W_cell_to_forget +
self.b_cell_to_forget)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to fail if not self.use_bias.

next_cells = (forget_gate * cells +
in_gate * nonlinearity(slice_last(activation, 2)))
out_gate = tensor.nnet.sigmoid(slice_last(activation, 3) +
Expand Down
10 changes: 7 additions & 3 deletions tests/bricks/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_many_steps(self):
class TestLSTM(unittest.TestCase):
def setUp(self):
self.lstm = LSTM(dim=3, weights_init=Constant(2),
biases_init=Constant(0))
biases_init=Constant(1))
self.lstm.initialize()

def test_one_step(self):
Expand All @@ -166,6 +166,7 @@ def test_one_step(self):
W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)
b_cell_to_forget = 1 * numpy.ones((3,), dtype=theano.config.floatX)

# omitting biases because they are zero
activation = numpy.dot(h0_val, W_state_val) + x_val
Expand All @@ -174,7 +175,8 @@ def sigmoid(x):
return 1. / (1. + numpy.exp(-x))

i_t = sigmoid(activation[:, :3] + c0_val * W_cell_to_in)
f_t = sigmoid(activation[:, 3:6] + c0_val * W_cell_to_forget)
f_t = sigmoid(activation[:, 3:6] + c0_val * W_cell_to_forget +
b_cell_to_forget)
next_cells = f_t * c0_val + i_t * numpy.tanh(activation[:, 6:9])
o_t = sigmoid(activation[:, 9:12] +
next_cells * W_cell_to_out)
Expand All @@ -201,14 +203,16 @@ def test_many_steps(self):
W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)
b_cell_to_forget = 1 * numpy.ones((3,), dtype=theano.config.floatX)

def sigmoid(x):
return 1. / (1. + numpy.exp(-x))

for i in range(1, 25):
activation = numpy.dot(h_val[i-1], W_state_val) + x_val[i-1]
i_t = sigmoid(activation[:, :3] + c_val[i-1] * W_cell_to_in)
f_t = sigmoid(activation[:, 3:6] + c_val[i-1] * W_cell_to_forget)
f_t = sigmoid(activation[:, 3:6] + c_val[i-1] * W_cell_to_forget +
b_cell_to_forget)
c_val[i] = f_t * c_val[i-1] + i_t * numpy.tanh(activation[:, 6:9])
o_t = sigmoid(activation[:, 9:12] +
c_val[i] * W_cell_to_out)
Expand Down