
Small corrections.
PiperOrigin-RevId: 264694817
Lukasz Kaiser authored and copybara-github committed Aug 21, 2019
1 parent 67ca605 commit 5386fb1
Showing 3 changed files with 28 additions and 26 deletions.
4 changes: 2 additions & 2 deletions tensor2tensor/trax/layers/base.py
@@ -258,12 +258,12 @@ def __call__(self, x, params=(), state=(), **kwargs):
    # JAX.

    assert state is (), (  # pylint: disable=literal-comparison
-       'Custom gradients do not allow non-trivial start state.')
+       'Custom gradients require trivial start state. Got %s' % str(state))

    def check_end_state(output_state):
      output, state = output_state
      assert state is (), (  # pylint: disable=literal-comparison
-         'Custom gradients do not allow non-trivial end state.')
+         'Custom gradients require trivial end state. Got %s' % str(state))
      return output

    # See this link for how custom transformations are defined in JAX:
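For reference, a minimal standalone sketch (not part of this commit) of the corrected check, with a hypothetical non-trivial state to show why echoing the state in the message is more useful than the old fixed string:

def check_end_state(output_state):
  output, state = output_state
  assert state is (), (  # pylint: disable=literal-comparison
      'Custom gradients require trivial end state. Got %s' % str(state))
  return output

check_end_state(('output', ()))          # Passes: empty-tuple state is trivial.
# check_end_state(('output', {'k': 1}))  # AssertionError: '... Got {'k': 1}'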
2 changes: 1 addition & 1 deletion tensor2tensor/trax/layers/combinators.py
@@ -181,7 +181,7 @@ def call(self, xs, params=(), state=(), **kwargs):
      raise ValueError('number of params ({}) not equal to number of layers '
                       '({})'.format(len(params), n_layers))
    if n_layers != 1 and len(state) != n_layers:
-     raise ValueError('number of params ({}) not equal to number of layers '
+     raise ValueError('length of state ({}) not equal to number of layers '
                       '({})'.format(len(state), n_layers))
    for layer, p, s, rng in zip(self._sublayers, params, state, rngs):
      is_stack_just_one_item = (_count_items(stack) == 1)
48 changes: 25 additions & 23 deletions tensor2tensor/trax/models/research/transformer_revnet.py
@@ -65,11 +65,12 @@ def n_outputs(self):
    """Specifies how many data tensors this layer promises as output."""
    return self._n_sections

- def call(self, inputs, params=(), **kwargs):
+ def call(self, inputs, params=(), state=(), **kwargs):
    rngs = _pop_rng_and_split(kwargs, len(inputs))
-   result = [self._layer(x, params=params, rng=r, **kwargs)
-             for x, r in zip(inputs, rngs)]
-   return tuple(result)
+   results = [self._layer(x, params=params, state=state, rng=r, **kwargs)
+              for x, r in zip(inputs, rngs)]
+   result_outputs, result_states = zip(*results)
+   return tuple(result_outputs), tuple(result_states)

  def new_parameters(self, input_shape, input_dtype, rng):
    first_shape = input_shape[0]
@@ -122,12 +123,13 @@ def __init__(self, n_sections=2, axis=-1):
    self._n_sections = n_sections
    self._axis = axis

- def call(self, inputs, params=(), **kwargs):
+ def call(self, inputs, params=(), state=(), **kwargs):
    del params, kwargs
-   return tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+   res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+   return res, state

  def new_parameters(self, input_shapes, input_dtype, rng):
-   return ()
+   return (), ()

  def n_inputs(self):
    """Specifies how many data tensors this layer expects as input."""
@@ -167,17 +169,17 @@ def n_outputs(self):
    return self._n_sections

  def new_parameters(self, input_shape, input_dtype, rng):
-   return ()
+   return (), ()

- def call(self, inputs, params=(), **kwargs):
+ def call(self, inputs, params=(), state=(), **kwargs):
    del params, kwargs
    x1, x2 = inputs

    x1_split = backend.numpy.split(x1, self._n_sections, self._axis)
    x2_split = backend.numpy.split(x2, self._n_sections, self._axis)

    res = [backend.numpy.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
-   return tuple(res)
+   return tuple(res), state

  def reverse(self, output, params=(), **kwargs):
    del params, kwargs
@@ -288,7 +290,7 @@ def __init__(self, n_heads=1, d_head=64,
    # The lack of a bias term here is consistent with the tensor2tensor
    # implementation, and shouldn't have an effect on modeling quality.

- def call(self, x, params, **kwargs):
+ def call(self, x, params, state, **kwargs):
    del kwargs
    seqlen = x.shape[1]
    res = np.dot(x, params)
@@ -300,13 +302,13 @@ def call(self, x, params, **kwargs):
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    res = np.reshape(res, (-1, seqlen, self._d_head))

-   return res
+   return res, state

  def new_parameters(self, input_shape, input_dtype, rng):
    del input_dtype
    w = self._kernel_initializer(
        (input_shape[-1], self._n_heads * self._d_head), rng)
-   return w
+   return w, ()


class ComputeAttentionOutput(tl.Layer):
@@ -321,7 +323,7 @@ def __init__(self, n_heads=1, d_model=1024,
    # The lack of a bias term here is consistent with the tensor2tensor
    # implementation, and shouldn't have an effect on modeling quality.

- def call(self, x, params, **kwargs):
+ def call(self, x, params, state, **kwargs):
    del kwargs
    seqlen = x.shape[1]
    d_head = x.shape[2]
@@ -330,13 +332,13 @@ def call(self, x, params, **kwargs):
    x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

-   return np.dot(x, params)
+   return np.dot(x, params), state

  def new_parameters(self, input_shape, input_dtype, rng):
    del input_dtype
    w = self._kernel_initializer(
        (input_shape[-1] * self._n_heads, self._d_model), rng)
-   return w
+   return w, ()


class ApplyAttentionWrapper(tl.Parallel):
@@ -374,14 +376,14 @@ def __init__(self, dropout, mode):
    self._dropout = dropout
    self._mode = mode

- def call(self, inputs, params=(), rng=None, **kwargs):
+ def call(self, inputs, params=(), state=(), rng=None, **kwargs):
    del params
    q, k, v = inputs
    mask_size = q.shape[-2]
    mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    res = tl.DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
-   return res
+   return res, state

  def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
    # Simultaneous forward pass and backprop through the attention mechanism.
@@ -391,7 +393,7 @@ def do_call(x):
    return output, vjpfun(ct)[0]

  def new_parameters(self, input_shapes, input_dtype, rng):
-   return ()
+   return (), ()

  def n_inputs(self):
    return 3
@@ -413,9 +415,9 @@ def __init__(self, loop_stride, dropout, mode):
    else:
      self.dropout = None

- def call(self, inputs, params=(), **kwargs):
+ def call(self, inputs, params=(), state=(), **kwargs):
    output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-   return output
+   return output, state

  def forward_and_vjp(self, inputs, ct, params=(), rng=None, **kwargs):
    # This is the core of the memory-efficient attention implementation, where
@@ -547,9 +549,9 @@ def __init__(self, dropout, mode, n_bins=64):
    super(DummyHashedAttention, self).__init__(dropout, mode)
    self.n_bins = n_bins

- def call(self, inputs, params=(), **kwargs):
+ def call(self, inputs, params=(), state=(), **kwargs):
    output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-   return output
+   return output, state

  def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
    del params, kwargs
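All of the edits in this file apply one convention: call now takes a state argument and returns an (output, state) pair, and new_parameters returns (params, state) rather than bare params. A minimal sketch of that contract, assuming tl is tensor2tensor.trax.layers as used elsewhere in this file (the class name below is hypothetical, for illustration only):

from tensor2tensor.trax import layers as tl  # Assumed import path.

class PassThrough(tl.Layer):
  """A no-op layer written against the (output, state) convention."""

  def call(self, inputs, params=(), state=(), **kwargs):
    del params, kwargs
    return inputs, state  # Return outputs together with the unchanged state.

  def new_parameters(self, input_shape, input_dtype, rng):
    del input_shape, input_dtype, rng
    return (), ()  # No trainable parameters; trivial (empty) state.

Stateless layers such as Split above follow exactly this shape; a stateful layer would instead return an updated state from call.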
