From 5386fb120a6ba2f1c635f08f0b22413349cc6caa Mon Sep 17 00:00:00 2001
From: Lukasz Kaiser
Date: Wed, 21 Aug 2019 14:36:40 -0700
Subject: [PATCH] Small corrections.

PiperOrigin-RevId: 264694817
---
 tensor2tensor/trax/layers/base.py              |  4 +-
 tensor2tensor/trax/layers/combinators.py       |  2 +-
 .../models/research/transformer_revnet.py      | 48 ++++++++++---------
 3 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/tensor2tensor/trax/layers/base.py b/tensor2tensor/trax/layers/base.py
index 79568de0e..842ec3025 100644
--- a/tensor2tensor/trax/layers/base.py
+++ b/tensor2tensor/trax/layers/base.py
@@ -258,12 +258,12 @@ def __call__(self, x, params=(), state=(), **kwargs):
       # JAX.
       assert state is (), (  # pylint: disable=literal-comparison
-          'Custom gradients do not allow non-trivial start state.')
+          'Custom gradients require trivial start state. Got %s' % str(state))
 
       def check_end_state(output_state):
         output, state = output_state
         assert state is (), (  # pylint: disable=literal-comparison
-            'Custom gradients do not allow non-trivial end state.')
+            'Custom gradients require trivial end state. Got %s' % str(state))
         return output
 
       # See this link for how custom transformations are defined in JAX:
diff --git a/tensor2tensor/trax/layers/combinators.py b/tensor2tensor/trax/layers/combinators.py
index f51fb5ba8..800f48a7e 100644
--- a/tensor2tensor/trax/layers/combinators.py
+++ b/tensor2tensor/trax/layers/combinators.py
@@ -181,7 +181,7 @@ def call(self, xs, params=(), state=(), **kwargs):
       raise ValueError('number of params ({}) not equal to number of layers '
                        '({})'.format(len(params), n_layers))
     if n_layers != 1 and len(state) != n_layers:
-      raise ValueError('number of params ({}) not equal to number of layers '
+      raise ValueError('length of state ({}) not equal to number of layers '
                        '({})'.format(len(state), n_layers))
     for layer, p, s, rng in zip(self._sublayers, params, state, rngs):
       is_stack_just_one_item = (_count_items(stack) == 1)
diff --git a/tensor2tensor/trax/models/research/transformer_revnet.py b/tensor2tensor/trax/models/research/transformer_revnet.py
index 0c92b7e68..ff5880b0f 100644
--- a/tensor2tensor/trax/models/research/transformer_revnet.py
+++ b/tensor2tensor/trax/models/research/transformer_revnet.py
@@ -65,11 +65,12 @@ def n_outputs(self):
     """Specifies how many data tensors this layer promises as output."""
     return self._n_sections
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     rngs = _pop_rng_and_split(kwargs, len(inputs))
-    result = [self._layer(x, params=params, rng=r, **kwargs)
-              for x, r in zip(inputs, rngs)]
-    return tuple(result)
+    results = [self._layer(x, params=params, state=state, rng=r, **kwargs)
+               for x, r in zip(inputs, rngs)]
+    result_outputs, result_states = zip(*results)
+    return tuple(result_outputs), tuple(result_states)
 
   def new_parameters(self, input_shape, input_dtype, rng):
     first_shape = input_shape[0]
@@ -122,12 +123,13 @@ def __init__(self, n_sections=2, axis=-1):
     self._n_sections = n_sections
     self._axis = axis
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
-    return tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    return res, state
 
   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()
 
   def n_inputs(self):
     """Specifies how many data tensors this layer expects as input."""
@@ -167,9 +169,9 @@ def n_outputs(self):
     return self._n_sections
 
   def new_parameters(self, input_shape, input_dtype, rng):
-    return ()
+    return (), ()
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
     x1, x2 = inputs
 
@@ -177,7 +179,7 @@ def call(self, inputs, params=(), **kwargs):
     x2_split = backend.numpy.split(x2, self._n_sections, self._axis)
     res = [backend.numpy.concatenate(ys, -1)
            for ys in zip(x1_split, x2_split)]
-    return tuple(res)
+    return tuple(res), state
 
   def reverse(self, output, params=(), **kwargs):
     del params, kwargs
@@ -288,7 +290,7 @@ def __init__(self, n_heads=1, d_head=64,
   # The lack of a bias term here is consistent with the tensor2tensor
   # implementation, and shouldn't have an effect on modeling quality.
-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     res = np.dot(x, params)
@@ -300,13 +302,13 @@ def call(self, x, params, **kwargs):
     # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
     res = np.reshape(res, (-1, seqlen, self._d_head))
-    return res
+    return res, state
 
   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1], self._n_heads * self._d_head), rng)
-    return w
+    return w, ()
 
 
 class ComputeAttentionOutput(tl.Layer):
@@ -321,7 +323,7 @@ def __init__(self, n_heads=1, d_model=1024,
   # The lack of a bias term here is consistent with the tensor2tensor
   # implementation, and shouldn't have an effect on modeling quality.
-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     d_head = x.shape[2]
@@ -330,13 +332,13 @@ def call(self, x, params, **kwargs):
     x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
     x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))
-    return np.dot(x, params)
+    return np.dot(x, params), state
 
   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1] * self._n_heads, self._d_model), rng)
-    return w
+    return w, ()
 
 
 class ApplyAttentionWrapper(tl.Parallel):
@@ -374,14 +376,14 @@ def __init__(self, dropout, mode):
     self._dropout = dropout
     self._mode = mode
 
-  def call(self, inputs, params=(), rng=None, **kwargs):
+  def call(self, inputs, params=(), state=(), rng=None, **kwargs):
     del params
     q, k, v = inputs
     mask_size = q.shape[-2]
     mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
     res = tl.DotProductAttention(
         q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
-    return res
+    return res, state
 
   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
@@ -391,7 +393,7 @@ def do_call(x):
     return output, vjpfun(ct)[0]
 
   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()
 
   def n_inputs(self):
     return 3
@@ -413,9 +415,9 @@ def __init__(self, loop_stride, dropout, mode):
     else:
       self.dropout = None
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state
 
   def forward_and_vjp(self, inputs, ct, params=(), rng=None, **kwargs):
     # This is the core of the memory-efficient attention implementation, where
@@ -547,9 +549,9 @@ def __init__(self, dropout, mode, n_bins=64):
     super(DummyHashedAttention, self).__init__(dropout, mode)
     self.n_bins = n_bins
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state
 
   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     del params, kwargs
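
The shared pattern across these hunks is a calling-convention change for
trax layers: after this patch, call() accepts a `state` argument and
returns an (output, state) pair, and new_parameters() returns a
(params, state) pair, even when the state is the trivial (). A minimal
sketch of that convention follows; the layer name (Scale) is made up and
plain numpy stands in for the trax backend, so none of these names come
from the patch itself.

  import numpy as np

  class Scale(object):
    """Hypothetical layer following the new (output, state) convention."""

    def new_parameters(self, input_shape, input_dtype, rng):
      # Return params together with an empty initial state, mirroring
      # the `return w, ()` changes above.
      del input_dtype, rng
      w = np.ones(input_shape[-1])
      return w, ()

    def call(self, x, params=(), state=(), **kwargs):
      # Return the output paired with the (unchanged) state, mirroring
      # the `return res, state` changes above.
      del kwargs
      return x * params, state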
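
In the first transformer_revnet.py hunk, each element of `results` is
itself an (output, state) pair, so zip(*results) transposes the list of
pairs into a pair of tuples before returning. A small illustration with
placeholder values, not taken from the patch:

  # zip(*pairs) turns [(out0, s0), (out1, s1), ...] into
  # ((out0, out1, ...), (s0, s1, ...)).
  results = [('out0', ()), ('out1', ()), ('out2', ())]
  result_outputs, result_states = zip(*results)
  assert result_outputs == ('out0', 'out1', 'out2')
  assert result_states == ((), (), ())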