From 00069a543d068e24296bccfdd3f4affb2bdc19f8 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Tue, 27 Mar 2018 14:24:52 +0200 Subject: [PATCH] Optimize build_future_mask implementation The triangular masking can be achieved using tf.matrix_band_part. This change results in a speedup of up to 10% during the training of Transformer models. --- CHANGELOG.md | 1 + opennmt/layers/transformer.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf5bba094..bfc26ad03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov * Fix the encoder state structure when RNN encoders are combined (e.g. in `SequentialEncoder`) * Fix `CharConvEmbedder` error on empty sequences * Fix `Adafactor` crash on sparse updates, automatically fallback to dense updates instead +* Improve the Transformer decoder mask construction (up to 10% speedup during training) ## [1.0.1](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.0.1) (2018-03-14) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 44c8566be..c2c561380 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -63,15 +63,13 @@ def build_future_mask(sequence_length, """ if num_heads is not None: sequence_length = tile_sequence_length(sequence_length, num_heads) - if maximum_length is None: - maximum_length = tf.reduce_max(sequence_length) - mask = tf.map_fn( - lambda x: tf.sequence_mask( - tf.minimum(tf.range(maximum_length) + 1, x), - maxlen=maximum_length, - dtype=dtype), - sequence_length, - dtype=dtype) + sequence_mask = tf.sequence_mask(sequence_length, maxlen=maximum_length, dtype=dtype) + shape = tf.shape(sequence_mask) + batch_size = shape[0] + max_time = shape[1] + mask = tf.ones([batch_size, max_time, max_time], dtype=dtype) + mask = tf.matrix_band_part(mask, -1, 0) + mask *= tf.expand_dims(sequence_mask, axis=1) if num_heads is not None: mask = tf.reshape(mask, [-1, num_heads, tf.shape(mask)[1], tf.shape(mask)[2]]) return mask