Skip to content

Commit

Permalink
Optimize build_future_mask implementation
Browse files Browse the repository at this point in the history
The triangular masking can be achieved using tf.matrix_band_part. This
change results in a speedup of up to 10% during the training of
Transformer models.
  • Loading branch information
guillaumekln committed Mar 27, 2018
1 parent b2df462 commit 00069a5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
* Fix the encoder state structure when RNN encoders are combined (e.g. in `SequentialEncoder`)
* Fix `CharConvEmbedder` error on empty sequences
* Fix `Adafactor` crash on sparse updates, automatically fallback to dense updates instead
* Improve the Transformer decoder mask construction (up to 10% speedup during training)

## [1.0.1](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.0.1) (2018-03-14)

Expand Down
16 changes: 7 additions & 9 deletions opennmt/layers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,13 @@ def build_future_mask(sequence_length,
"""
if num_heads is not None:
sequence_length = tile_sequence_length(sequence_length, num_heads)
if maximum_length is None:
maximum_length = tf.reduce_max(sequence_length)
mask = tf.map_fn(
lambda x: tf.sequence_mask(
tf.minimum(tf.range(maximum_length) + 1, x),
maxlen=maximum_length,
dtype=dtype),
sequence_length,
dtype=dtype)
sequence_mask = tf.sequence_mask(sequence_length, maxlen=maximum_length, dtype=dtype)
shape = tf.shape(sequence_mask)
batch_size = shape[0]
max_time = shape[1]
mask = tf.ones([batch_size, max_time, max_time], dtype=dtype)
mask = tf.matrix_band_part(mask, -1, 0)
mask *= tf.expand_dims(sequence_mask, axis=1)
if num_heads is not None:
mask = tf.reshape(mask, [-1, num_heads, tf.shape(mask)[1], tf.shape(mask)[2]])
return mask
Expand Down

0 comments on commit 00069a5

Please sign in to comment.