From b5fd4685e2f5399ec9ac01feb3827278d85b27ff Mon Sep 17 00:00:00 2001 From: Jordi Mas Date: Thu, 22 Aug 2024 17:49:55 +0200 Subject: [PATCH] Remove tensorflow.addons dependency and bring needed classes --- .github/workflows/ci.yml | 1 + docs/generate-apidoc.py | 1 + opennmt/decoders/rnn_decoder.py | 3 +- opennmt/encoders/rnn_encoder.py | 3 +- opennmt/models/catalog.py | 3 +- opennmt/models/sequence_tagger.py | 3 +- opennmt/models/sequence_to_sequence.py | 3 +- opennmt/optimizers/utils.py | 7 +- opennmt/tests/optimizer_test.py | 6 +- opennmt/tfa/__init__.py | 2 + opennmt/tfa/conftest.py | 20 + opennmt/tfa/optimizers/__init__.py | 6 + opennmt/tfa/optimizers/lazy_adam.py | 156 ++ .../tfa/optimizers/tests/lazy_adam_test.py | 262 +++ .../tests/weight_decay_optimizers_test.py | 464 +++++ opennmt/tfa/optimizers/utils.py | 41 + .../tfa/optimizers/weight_decay_optimizers.py | 554 ++++++ opennmt/tfa/rnn/__init__.py | 2 + opennmt/tfa/rnn/abstract_rnn_cell.py | 133 ++ opennmt/tfa/rnn/layer_norm_lstm_cell.py | 210 ++ .../rnn/tests/layer_norm_lstm_cell_test.py | 213 ++ opennmt/tfa/seq2seq/__init__.py | 30 + opennmt/tfa/seq2seq/attention_wrapper.py | 1739 +++++++++++++++++ opennmt/tfa/seq2seq/beam_search_decoder.py | 72 + opennmt/tfa/seq2seq/decoder.py | 583 ++++++ opennmt/tfa/seq2seq/loss.py | 211 ++ opennmt/tfa/seq2seq/sampler.py | 96 + .../seq2seq/tests/attention_wrapper_test.py | 365 ++++ opennmt/tfa/seq2seq/tests/loss_test.py | 314 +++ opennmt/tfa/text/__init__.py | 1 + opennmt/tfa/text/crf.py | 567 ++++++ opennmt/tfa/text/tests/crf_test.py | 404 ++++ opennmt/tfa/utils/__init__.py | 0 opennmt/tfa/utils/keras_utils.py | 66 + opennmt/tfa/utils/test_utils.py | 275 +++ opennmt/tfa/utils/types.py | 82 + opennmt/utils/decoding.py | 3 +- setup.py | 2 +- 38 files changed, 6888 insertions(+), 15 deletions(-) create mode 100644 opennmt/tfa/__init__.py create mode 100644 opennmt/tfa/conftest.py create mode 100644 opennmt/tfa/optimizers/__init__.py create mode 100644 opennmt/tfa/optimizers/lazy_adam.py create mode 100644 opennmt/tfa/optimizers/tests/lazy_adam_test.py create mode 100644 opennmt/tfa/optimizers/tests/weight_decay_optimizers_test.py create mode 100644 opennmt/tfa/optimizers/utils.py create mode 100644 opennmt/tfa/optimizers/weight_decay_optimizers.py create mode 100644 opennmt/tfa/rnn/__init__.py create mode 100644 opennmt/tfa/rnn/abstract_rnn_cell.py create mode 100644 opennmt/tfa/rnn/layer_norm_lstm_cell.py create mode 100644 opennmt/tfa/rnn/tests/layer_norm_lstm_cell_test.py create mode 100644 opennmt/tfa/seq2seq/__init__.py create mode 100644 opennmt/tfa/seq2seq/attention_wrapper.py create mode 100644 opennmt/tfa/seq2seq/beam_search_decoder.py create mode 100644 opennmt/tfa/seq2seq/decoder.py create mode 100644 opennmt/tfa/seq2seq/loss.py create mode 100644 opennmt/tfa/seq2seq/sampler.py create mode 100644 opennmt/tfa/seq2seq/tests/attention_wrapper_test.py create mode 100644 opennmt/tfa/seq2seq/tests/loss_test.py create mode 100644 opennmt/tfa/text/__init__.py create mode 100644 opennmt/tfa/text/crf.py create mode 100644 opennmt/tfa/text/tests/crf_test.py create mode 100644 opennmt/tfa/utils/__init__.py create mode 100644 opennmt/tfa/utils/keras_utils.py create mode 100644 opennmt/tfa/utils/test_utils.py create mode 100644 opennmt/tfa/utils/types.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 049edab31..d0b66ccee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,7 @@ jobs: - name: Run tests run: | pytest --cov=opennmt --cov-report xml 
opennmt/tests + pytest opennmt/tfa/ - name: Upload coverage report if: matrix.tensorflow == '2.13' diff --git a/docs/generate-apidoc.py b/docs/generate-apidoc.py index 0f9eba4c6..d30b65a2c 100644 --- a/docs/generate-apidoc.py +++ b/docs/generate-apidoc.py @@ -35,6 +35,7 @@ def document_function(output_dir, function_path): def module_is_public(module): return ( module.__name__.startswith("opennmt") + and not module.__name__.startswith("opennmt.tfa") and hasattr(module, "__file__") and module.__file__.endswith("__init__.py") ) diff --git a/opennmt/decoders/rnn_decoder.py b/opennmt/decoders/rnn_decoder.py index 1ef9caf1e..b46f1dee8 100644 --- a/opennmt/decoders/rnn_decoder.py +++ b/opennmt/decoders/rnn_decoder.py @@ -1,7 +1,8 @@ """Define RNN-based decoders.""" import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt.decoders import decoder from opennmt.layers import bridge, common, rnn, transformer diff --git a/opennmt/encoders/rnn_encoder.py b/opennmt/encoders/rnn_encoder.py index 979dbd362..251993fd9 100644 --- a/opennmt/encoders/rnn_encoder.py +++ b/opennmt/encoders/rnn_encoder.py @@ -1,7 +1,8 @@ """Define RNN-based encoders.""" import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt.encoders.encoder import Encoder, SequentialEncoder from opennmt.layers import common, rnn diff --git a/opennmt/models/catalog.py b/opennmt/models/catalog.py index b013d0b24..bee663006 100644 --- a/opennmt/models/catalog.py +++ b/opennmt/models/catalog.py @@ -1,7 +1,8 @@ """Catalog of predefined models.""" import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt import config as config_util from opennmt import decoders, encoders, inputters, layers diff --git a/opennmt/models/sequence_tagger.py b/opennmt/models/sequence_tagger.py index 2b2b7780d..53e5f32f4 100644 --- a/opennmt/models/sequence_tagger.py +++ b/opennmt/models/sequence_tagger.py @@ -2,7 +2,8 @@ import numpy as np import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt import inputters from opennmt.models.model import Model diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py index a7f9296b8..e37999f5f 100644 --- a/opennmt/models/sequence_to_sequence.py +++ b/opennmt/models/sequence_to_sequence.py @@ -1,7 +1,8 @@ """Standard sequence-to-sequence model.""" import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt import config as config_util from opennmt import constants, inputters diff --git a/opennmt/optimizers/utils.py b/opennmt/optimizers/utils.py index 4dd613aef..431fc463c 100644 --- a/opennmt/optimizers/utils.py +++ b/opennmt/optimizers/utils.py @@ -3,13 +3,12 @@ import inspect import tensorflow as tf -import tensorflow_addons as tfa from packaging.version import Version -from tensorflow_addons.optimizers.weight_decay_optimizers import ( - DecoupledWeightDecayExtension, -) +import opennmt.tfa as tfa + +from opennmt.tfa.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension from opennmt.utils import misc if Version(tf.__version__) >= Version("2.11.0"): diff --git a/opennmt/tests/optimizer_test.py b/opennmt/tests/optimizer_test.py index 2808a557a..092fb6cf6 100644 --- a/opennmt/tests/optimizer_test.py +++ b/opennmt/tests/optimizer_test.py @@ -1,12 +1,10 @@ import tensorflow as tf -import tensorflow_addons as tfa -from tensorflow_addons.optimizers.weight_decay_optimizers import ( - 
DecoupledWeightDecayExtension, -) +import opennmt.tfa as tfa from opennmt.optimizers import utils from opennmt.tests import test_util +from opennmt.tfa.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension class OptimizerTest(tf.test.TestCase): diff --git a/opennmt/tfa/__init__.py b/opennmt/tfa/__init__.py new file mode 100644 index 000000000..da573406b --- /dev/null +++ b/opennmt/tfa/__init__.py @@ -0,0 +1,2 @@ +from opennmt.tfa import optimizers, rnn, seq2seq, text +from opennmt.tfa.utils import types diff --git a/opennmt/tfa/conftest.py b/opennmt/tfa/conftest.py new file mode 100644 index 000000000..3e7f2950c --- /dev/null +++ b/opennmt/tfa/conftest.py @@ -0,0 +1,20 @@ +import numpy as np +import pytest +import tensorflow as tf + +import opennmt.tfa as tfa + +from opennmt.tfa.utils.test_utils import ( # noqa: F401 + data_format, + device, + maybe_run_functions_eagerly, + only_run_functions_eagerly, + pytest_addoption, + pytest_collection_modifyitems, + pytest_configure, + pytest_generate_tests, + pytest_make_parametrize_id, + run_with_mixed_precision_policy, + set_global_variables, + set_seeds, +) diff --git a/opennmt/tfa/optimizers/__init__.py b/opennmt/tfa/optimizers/__init__.py new file mode 100644 index 000000000..211abe1e6 --- /dev/null +++ b/opennmt/tfa/optimizers/__init__.py @@ -0,0 +1,6 @@ +from opennmt.tfa.optimizers.lazy_adam import LazyAdam +from opennmt.tfa.optimizers.weight_decay_optimizers import ( + AdamW, + DecoupledWeightDecayExtension, + extend_with_decoupled_weight_decay, +) diff --git a/opennmt/tfa/optimizers/lazy_adam.py b/opennmt/tfa/optimizers/lazy_adam.py new file mode 100644 index 000000000..4cc18d1c6 --- /dev/null +++ b/opennmt/tfa/optimizers/lazy_adam.py @@ -0,0 +1,156 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Variant of the Adam optimizer that handles sparse updates more efficiently. + +Compared with the original Adam optimizer, the one in this file can +provide a large improvement in model training throughput for some +applications. However, it provides slightly different semantics than the +original Adam algorithm, and may lead to different empirical results. +""" + +import importlib + +from typing import Callable, Union + +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.utils.types import FloatTensorLike + +if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None: + adam_optimizer_class = tf.keras.optimizers.legacy.Adam +else: + adam_optimizer_class = tf.keras.optimizers.Adam + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class LazyAdam(adam_optimizer_class): + """Variant of the Adam optimizer that handles sparse updates more + efficiently. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. 
+ This class provides lazier handling of gradient updates for sparse + variables. It only updates moving-average accumulators for sparse variable + indices that appear in the current batch, rather than updating the + accumulators for all indices. Compared with the original Adam optimizer, + it can provide large improvements in model training throughput for some + applications. However, it provides slightly different semantics than the + original Adam algorithm, and may lead to different empirical results. + + Note, amsgrad is currently not supported and the argument can only be + False. + """ + + @typechecked + def __init__( + self, + learning_rate: Union[FloatTensorLike, Callable] = 0.001, + beta_1: FloatTensorLike = 0.9, + beta_2: FloatTensorLike = 0.999, + epsilon: FloatTensorLike = 1e-7, + amsgrad: bool = False, + name: str = "LazyAdam", + **kwargs, + ): + """Constructs a new LazyAdam optimizer. + + Args: + learning_rate: A `Tensor` or a floating point value. or a schedule + that is a `tf.keras.optimizers.schedules.LearningRateSchedule` + The learning rate. + beta_1: A `float` value or a constant `float` tensor. + The exponential decay rate for the 1st moment estimates. + beta_2: A `float` value or a constant `float` tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. + This epsilon is "epsilon hat" in + [Adam: A Method for Stochastic Optimization. Kingma et al., 2014] + (http://arxiv.org/abs/1412.6980) (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the paper. + amsgrad: `boolean`. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and beyond". + Note that this argument is currently not supported and the + argument can only be `False`. + name: Optional name for the operations created when applying + gradients. Defaults to "LazyAdam". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` + is clip gradients by value, `decay` is included for backward + compatibility to allow time inverse decay of learning rate. `lr` + is included for backward compatibility, recommended to use + `learning_rate` instead. 
+ """ + super().__init__( + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs, + ) + + def _resource_apply_sparse(self, grad, var, indices): + var_dtype = var.dtype.base_dtype + lr_t = self._decayed_lr(var_dtype) + beta_1_t = self._get_hyper("beta_1", var_dtype) + beta_2_t = self._get_hyper("beta_2", var_dtype) + local_step = tf.cast(self.iterations + 1, var_dtype) + beta_1_power = tf.math.pow(beta_1_t, local_step) + beta_2_power = tf.math.pow(beta_2_t, local_step) + epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype) + lr = lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad + m_update_op = self._resource_scatter_update(m, indices, m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square( + grad + ) + v_update_op = self._resource_scatter_update(v, indices, v_t_slice) + + # \\(variable += -learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t) + var_update_op = self._resource_scatter_sub(var, indices, var_slice) + + return tf.group(*[var_update_op, m_update_op, v_update_op]) + + def _resource_scatter_update(self, resource, indices, update): + return self._resource_scatter_operate( + resource, indices, update, tf.raw_ops.ResourceScatterUpdate + ) + + def _resource_scatter_sub(self, resource, indices, update): + return self._resource_scatter_operate( + resource, indices, update, tf.raw_ops.ResourceScatterSub + ) + + def _resource_scatter_operate(self, resource, indices, update, resource_scatter_op): + resource_update_kwargs = { + "resource": resource.handle, + "indices": indices, + "updates": update, + } + + return resource_scatter_op(**resource_update_kwargs) + + def get_config(self): + return super().get_config() diff --git a/opennmt/tfa/optimizers/tests/lazy_adam_test.py b/opennmt/tfa/optimizers/tests/lazy_adam_test.py new file mode 100644 index 000000000..35dd80867 --- /dev/null +++ b/opennmt/tfa/optimizers/tests/lazy_adam_test.py @@ -0,0 +1,262 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for LazyAdam.""" + + +import numpy as np +import pytest +import tensorflow as tf + +from opennmt.tfa.optimizers import lazy_adam +from opennmt.tfa.utils import test_utils + + +def adam_update_numpy( + param, g_t, t, m, v, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7 +): + lr_t = lr * np.sqrt(1 - beta2 ** (t + 1)) / (1 - beta1 ** (t + 1)) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +def get_beta_accumulators(opt, dtype): + local_step = tf.cast(opt.iterations + 1, dtype) + beta_1_t = tf.cast(opt._get_hyper("beta_1"), dtype) + beta_1_power = tf.math.pow(beta_1_t, local_step) + beta_2_t = tf.cast(opt._get_hyper("beta_2"), dtype) + beta_2_power = tf.math.pow(beta_2_t, local_step) + return (beta_1_power, beta_2_power) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_sparse(dtype): + # TODO: remove the with tf.device when the execution on cpu is enforced + # See #1682 to track it. + with tf.device("CPU:0"): + _test_sparse(dtype) + + +def _test_sparse(dtype): + # Initialize tf for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0_np_indices = np.array([0, 2], dtype=np.int32) + grads0 = tf.IndexedSlices( + tf.constant(grads0_np[grads0_np_indices]), + tf.constant(grads0_np_indices), + tf.constant([3]), + ) + grads1_np_indices = np.array([0, 2], dtype=np.int32) + grads1 = tf.IndexedSlices( + tf.constant(grads1_np[grads1_np_indices]), + tf.constant(grads1_np_indices), + tf.constant([3]), + ) + opt = lazy_adam.LazyAdam() + + # Fetch params to validate initial values + np.testing.assert_allclose([1.0, 1.0, 2.0], var0.numpy(), 1e-6, 1e-6) + np.testing.assert_allclose([3.0, 3.0, 4.0], var1.numpy(), 1e-6, 1e-6) + + # Run 3 steps of Adam + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + test_utils.assert_allclose_according_to_type(0.9 ** (t + 1), beta_1_power) + test_utils.assert_allclose_according_to_type(0.999 ** (t + 1), beta_2_power) + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + test_utils.assert_allclose_according_to_type(var0_np, var0.numpy()) + test_utils.assert_allclose_according_to_type(var1_np, var1.numpy()) + + +@pytest.mark.parametrize("dtype", [tf.int32, tf.int64]) +@pytest.mark.with_device(["cpu", "gpu"]) +def test_sparse_device_placement(dtype): + + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). 
+ var = tf.Variable([[1.0], [2.0]]) + indices = tf.constant([0, 1], dtype=dtype) + + def g_sum(): + return tf.math.reduce_sum(tf.gather(var, indices)) + + optimizer = lazy_adam.LazyAdam(3.0) + optimizer.minimize(g_sum, var_list=[var]) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_sparse_repeated_indices(dtype): + # todo: remove the with tf.device once the placement on cpu is enforced. + with tf.device("CPU:0"): + repeated_index_update_var = tf.Variable([[1], [2]], dtype=dtype) + aggregated_update_var = tf.Variable([[1], [2]], dtype=dtype) + grad_repeated_index = tf.IndexedSlices( + tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + tf.constant([1, 1]), + tf.constant([2, 1]), + ) + grad_aggregated = tf.IndexedSlices( + tf.constant([0.2], shape=[1, 1], dtype=dtype), + tf.constant([1]), + tf.constant([2, 1]), + ) + repeated_update_opt = lazy_adam.LazyAdam() + aggregated_update_opt = lazy_adam.LazyAdam() + for _ in range(3): + repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)] + ) + aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)] + ) + np.testing.assert_allclose( + aggregated_update_var.numpy(), repeated_index_update_var.numpy() + ) + + +@pytest.mark.parametrize("use_callable_params", [True, False]) +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_basic(use_callable_params, dtype): + # Initialize tf for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + + def learning_rate(): + return 0.001 + + if not use_callable_params: + learning_rate = learning_rate() + + opt = lazy_adam.LazyAdam(learning_rate=learning_rate) + + # Run 3 steps of Adam + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + test_utils.assert_allclose_according_to_type(0.9 ** (t + 1), beta_1_power) + test_utils.assert_allclose_according_to_type(0.999 ** (t + 1), beta_2_power) + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + test_utils.assert_allclose_according_to_type(var0_np, var0.numpy()) + test_utils.assert_allclose_according_to_type(var1_np, var1.numpy()) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_tensor_learning_rate(dtype): + # Initialize tf for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = lazy_adam.LazyAdam(tf.constant(0.001)) + + # Run 3 steps of Adam + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + test_utils.assert_allclose_according_to_type(0.9 ** (t + 1), beta_1_power) + test_utils.assert_allclose_according_to_type(0.999 ** (t + 1), beta_2_power) + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + test_utils.assert_allclose_according_to_type(var0_np, var0.numpy()) + test_utils.assert_allclose_according_to_type(var1_np, var1.numpy()) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_sharing(dtype): + # Initialize tf for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = lazy_adam.LazyAdam() + + # Fetch params to validate initial values + np.testing.assert_allclose([1.0, 2.0], var0.numpy()) + np.testing.assert_allclose([3.0, 4.0], var1.numpy()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(3): + beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype) + test_utils.assert_allclose_according_to_type(0.9 ** (t + 1), beta_1_power) + test_utils.assert_allclose_according_to_type(0.999 ** (t + 1), beta_2_power) + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + test_utils.assert_allclose_according_to_type(var0_np, var0.numpy()) + test_utils.assert_allclose_according_to_type(var1_np, var1.numpy()) + + +def test_slots_unique_eager(): + v1 = tf.Variable(1.0) + v2 = tf.Variable(1.0) + opt = lazy_adam.LazyAdam(1.0) + opt.minimize(lambda: v1 + v2, var_list=[v1, v2]) + # There should be iteration, and two unique slot variables for v1 and v2. + assert 5 == len(opt.variables()) + assert opt.variables()[0] == opt.iterations + + +def test_serialization(): + optimizer = lazy_adam.LazyAdam() + config = tf.keras.optimizers.serialize(optimizer) + new_optimizer = tf.keras.optimizers.deserialize(config) + assert new_optimizer.get_config() == optimizer.get_config() diff --git a/opennmt/tfa/optimizers/tests/weight_decay_optimizers_test.py b/opennmt/tfa/optimizers/tests/weight_decay_optimizers_test.py new file mode 100644 index 000000000..be9d5de85 --- /dev/null +++ b/opennmt/tfa/optimizers/tests/weight_decay_optimizers_test.py @@ -0,0 +1,464 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +import importlib + +import numpy as np +import pytest +import tensorflow as tf + +from opennmt.tfa.optimizers import weight_decay_optimizers +from opennmt.tfa.utils import test_utils + +WEIGHT_DECAY = 0.01 + + +def do_test( + dtype, + optimizer, + update_fn, + do_sparse=False, + do_decay_var_list=False, + **optimizer_kwargs, +): + """The major test function. + + Args: + optimizer: The tensorflow optimizer class to be tested. + update_fn: The numpy update function of the optimizer, the function + signature must be + update_fn(var: np.array, + grad_t: np.array, + slot_vars: dict, + **kwargs) -> (updated_var, updated_slot_vars) + Note that slot_vars will be initialized to an empty dictionary + for each variable, initial values should be handled in the + update_fn. + do_sparse: If True, test sparse update. Defaults to False, i.e., + dense update. + do_decay_var_list: If True, test by passing a list of vars to + ensure hashing is handled correctly + **optimizer_kwargs:The parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. + """ + # TODO: Fix #347 issue + if do_sparse and test_utils.is_gpu_available(): + pytest.skip("Wait #347 to be fixed") + + # Initialize variables for numpy implementation. + np_slot_vars0, np_slot_vars1 = {}, {} + var0_np = np.array([1.0, 2.0], dtype=dtype[0].as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype[0].as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype[0].as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype[0].as_numpy_dtype) + # Create Tensorflow variables. + var0 = tf.Variable(var0_np, name="var0_%d" % dtype[1]) + var1 = tf.Variable(var1_np, name="var1_%d" % dtype[1]) + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = tf.IndexedSlices( + tf.constant(grads0_np), tf.constant(grads0_np_indices), tf.constant([2]) + ) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = tf.IndexedSlices( + tf.constant(grads1_np), tf.constant(grads1_np_indices), tf.constant([2]) + ) + else: + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = optimizer(**optimizer_kwargs) + # Create the update op. 
+ # Run 3 steps of the optimizer + optimizer_kwargs.pop("exclude_from_weight_decay", None) + for _ in range(3): + if do_decay_var_list: + opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), decay_var_list=[var0, var1] + ) + else: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, np_slot_vars0 = update_fn( + var0_np, grads0_np, np_slot_vars0, **optimizer_kwargs + ) + var1_np, np_slot_vars1 = update_fn( + var1_np, grads1_np, np_slot_vars1, **optimizer_kwargs + ) + # Validate updated params + test_utils.assert_allclose_according_to_type(var0_np, var0.numpy()) + test_utils.assert_allclose_according_to_type(var1_np, var1.numpy()) + + +def do_test_sparse_repeated_indices(dtype, optimizer, **optimizer_kwargs): + """Test for repeated indices in sparse updates. + + This test verifies that an update with repeated indices is the same as + an update with two times the gradient. + + Args: + optimizer: The tensorflow optimizer class to be tested. + **optimizer_kwargs: The parameters to pass to the construcor of the + optimizer. Either a constant or a callable. This also passed to + the optimizer_params in the update_fn. + """ + # TODO: Fix #347 issue + if test_utils.is_gpu_available(): + pytest.skip("Wait #347 to be fixed") + + repeated_index_update_var = tf.Variable([[1.0], [2.0]], dtype=dtype) + aggregated_update_var = tf.Variable([[1.0], [2.0]], dtype=dtype) + grad_repeated_index = tf.IndexedSlices( + tf.constant([0.1, 0.1], shape=[2, 1], dtype=dtype), + tf.constant([1, 1]), + tf.constant([2, 1]), + ) + grad_aggregated = tf.IndexedSlices( + tf.constant([0.2], shape=[1, 1], dtype=dtype), + tf.constant([1]), + tf.constant([2, 1]), + ) + opt_repeated = optimizer(**optimizer_kwargs) + _ = opt_repeated.apply_gradients([(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated = optimizer(**optimizer_kwargs) + _ = opt_aggregated.apply_gradients([(grad_aggregated, aggregated_update_var)]) + np.testing.assert_allclose( + aggregated_update_var.numpy(), repeated_index_update_var.numpy() + ) + for _ in range(3): + opt_repeated.apply_gradients([(grad_repeated_index, repeated_index_update_var)]) + opt_aggregated.apply_gradients([(grad_aggregated, aggregated_update_var)]) + np.testing.assert_allclose( + aggregated_update_var.numpy(), repeated_index_update_var.numpy() + ) + + +def adamw_update_numpy( + param, grad_t, slot_vars, learning_rate, beta_1, beta_2, epsilon, weight_decay +): + """Numpy update function for AdamW.""" + lr, beta1, beta2, eps, wd = ( + v() if callable(v) else v + for v in (learning_rate, beta_1, beta_2, epsilon, weight_decay) + ) + t = slot_vars.get("t", 0) + 1 + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + slot_vars["m"] = beta1 * slot_vars.get("m", 0) + (1 - beta1) * grad_t + slot_vars["v"] = beta2 * slot_vars.get("v", 0) + (1 - beta2) * grad_t**2 + param_t = param * (1 - wd) - lr_t * slot_vars["m"] / (np.sqrt(slot_vars["v"]) + eps) + slot_vars["t"] = t + return param_t, slot_vars + + +def sgdw_update_numpy(param, grad_t, slot_vars, learning_rate, momentum, weight_decay): + """Numpy update function for SGDW.""" + m = slot_vars.get("m", 0) + lr, momentum, wd = ( + v() if callable(v) else v for v in (learning_rate, momentum, weight_decay) + ) + slot_vars["m"] = momentum * m + grad_t + param_t = param * (1 - wd) - lr * slot_vars["m"] + return param_t, slot_vars + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_sparse_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + 
adamw_update_numpy, + do_sparse=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_sparse_repeated_indices_adamw(dtype): + do_test_sparse_repeated_indices( + dtype, + weight_decay_optimizers.AdamW, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + adamw_update_numpy, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_callable_params_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + adamw_update_numpy, + learning_rate=lambda: 0.001, + beta_1=lambda: 0.9, + beta_2=lambda: 0.999, + epsilon=1e-8, + weight_decay=lambda: WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_decay_var_list_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + adamw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + ) + + +def test_exclude_weight_decay_adamw(): + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) + var0 = tf.Variable([], name="var0") + var1 = tf.Variable([], name="var1") + var1_weight = tf.Variable([], name="var1_weight") + + optimizer._set_decay_var_list([var0, var1, var1_weight]) + assert optimizer._do_use_weight_decay(var0) + assert not optimizer._do_use_weight_decay(var1) + assert not optimizer._do_use_weight_decay(var1_weight) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_var_list_with_exclude_list_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + adamw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + exclude_from_weight_decay=["var0_*", "var1_*"], + ) + + +def test_keras_fit(): + """Check if calling model.fit works.""" + model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)]) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4) + model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"]) + x, y = np.random.uniform(size=(2, 4, 1)) + model.fit(x, y, epochs=1) + + +def test_keras_fit_with_schedule(): + """Check if calling model.fit works with wd schedule.""" + model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)]) + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + wd_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + WEIGHT_DECAY, decay_steps=10, decay_rate=0.9 + ) + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=wd_schedule + ) + model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"]) + x, y = np.random.uniform(size=(2, 4, 1)) + model.fit(x, y, epochs=1) + + +@pytest.mark.with_device(["cpu", "gpu"]) +def test_weight_decay_with_piecewise_constant_decay_schedule(): + model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)]) + loss = 
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + wd_schedule = tf.optimizers.schedules.PiecewiseConstantDecay([2], [1e-4, 1e-5]) + optimizer = weight_decay_optimizers.SGDW( + learning_rate=1e-2, weight_decay=wd_schedule + ) + model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"]) + x, y = np.random.uniform(size=(2, 4, 1)) + model.fit(x, y, batch_size=1, epochs=1) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_sparse_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + do_sparse=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_sparse_repeated_indices_sgdw(dtype): + do_test_sparse_repeated_indices( + dtype, + weight_decay_optimizers.SGDW, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_callable_params_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + learning_rate=lambda: 0.001, + momentum=lambda: 0.9, + weight_decay=lambda: WEIGHT_DECAY, + ) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_basic_decay_var_list_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + ) + + +def test_exclude_weight_decay_sgdw(): + optimizer = weight_decay_optimizers.SGDW( + learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) + var0 = tf.Variable([], name="var0") + var1 = tf.Variable([], name="var1") + var1_weight = tf.Variable([], name="var1_weight") + + optimizer._set_decay_var_list([var0, var1, var1_weight]) + assert optimizer._do_use_weight_decay(var0) + assert not optimizer._do_use_weight_decay(var1) + assert not optimizer._do_use_weight_decay(var1_weight) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_var_list_with_exclude_list_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + exclude_from_weight_decay=["var0_*", "var1_*"], + ) + + +if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None: + optimizer_class = tf.keras.optimizers.legacy.SGD +else: + optimizer_class = tf.keras.optimizers.SGD + + +@pytest.mark.parametrize( + "optimizer", + [ + weight_decay_optimizers.SGDW, + weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class), + ], +) +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_optimizer_basic(dtype, optimizer): + do_test( + dtype, + optimizer, + sgdw_update_numpy, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + ) + + +@pytest.mark.parametrize( + "optimizer", + [ + weight_decay_optimizers.SGDW, + 
weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class), + ], +) +@pytest.mark.parametrize("dtype", [tf.half, tf.float32, tf.float64]) +def test_optimizer_sparse(dtype, optimizer): + do_test_sparse_repeated_indices( + dtype, optimizer, learning_rate=0.001, momentum=0.9, weight_decay=WEIGHT_DECAY + ) + + +def test_serialization(): + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) + config = tf.keras.optimizers.serialize(optimizer) + new_optimizer = tf.keras.optimizers.deserialize(config) + assert new_optimizer.get_config() == optimizer.get_config() + + +def test_serialization_with_wd_schedule(): + wd_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + WEIGHT_DECAY, decay_steps=10, decay_rate=0.9 + ) + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=wd_schedule + ) + config = tf.keras.optimizers.serialize(optimizer) + new_optimizer = tf.keras.optimizers.deserialize(config) + assert new_optimizer.get_config() == optimizer.get_config() diff --git a/opennmt/tfa/optimizers/utils.py b/opennmt/tfa/optimizers/utils.py new file mode 100644 index 000000000..af1b43a27 --- /dev/null +++ b/opennmt/tfa/optimizers/utils.py @@ -0,0 +1,41 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional Utilities used for tfa.optimizers.""" + +import re + +from typing import List + +import tensorflow as tf + + +def get_variable_name(variable) -> str: + """Get the variable name from the variable tensor.""" + param_name = variable.name + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name + + +def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool: + """Whether variable is matched in regexes list by its name.""" + if regexes: + # var_name = get_variable_name(variable) + var_name = variable.name + for r in regexes: + if re.search(r, var_name): + return True + return False diff --git a/opennmt/tfa/optimizers/weight_decay_optimizers.py b/opennmt/tfa/optimizers/weight_decay_optimizers.py new file mode 100644 index 000000000..047ac5329 --- /dev/null +++ b/opennmt/tfa/optimizers/weight_decay_optimizers.py @@ -0,0 +1,554 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Base class to make optimizers weight decay ready.""" + +import importlib + +from typing import Callable, List, Optional, Type, Union + +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.optimizers.utils import is_variable_matched_by_regexes +from opennmt.tfa.utils.types import FloatTensorLike + + +class DecoupledWeightDecayExtension: + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by [Loshchilov & Hutter] + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two + examples used in the above paper (SGDW and AdamW), but in general this can + extend any OptimizerX class by using + `ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX)`. + Weight decay can then be set when instantiating the optimizer: + `optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamW, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note: this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + """ + + @typechecked + def __init__( + self, + weight_decay: Union[FloatTensorLike, Callable], + exclude_from_weight_decay: Optional[List[str]] = None, + **kwargs, + ): + """Extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor`, a floating point value, or a schedule + that is a `tf.keras.optimizers.schedules.LearningRateSchedule` + to decay the variable by, in the update step. + exclude_from_weight_decay: List of regex patterns of + variables excluded from weight decay. Variables whose name + contain a substring matching the pattern will be excluded. + Note `decay_var_list` in `minimize` or `apply_gradients` takes + priority over `exclude_from_weight_decay` if specified. + **kwargs: Optional list or tuple or set of `Variable` objects to + decay. 
+ """ + wd = kwargs.pop("weight_decay", weight_decay) + super().__init__(**kwargs) + self._decay_var_list = None # is set in minimize or apply_gradients + self._set_hyper("weight_decay", wd) + self.exclude_from_weight_decay = exclude_from_weight_decay + + def get_config(self): + config = super().get_config() + config.update( + { + "weight_decay": self._serialize_hyperparameter("weight_decay"), + "exclude_from_weight_decay": self.exclude_from_weight_decay, + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + # LR handling copied from optimizer_v2.OptimizerV2 + if "learning_rate" in config: + if isinstance(config["learning_rate"], dict): + config["learning_rate"] = tf.keras.optimizers.schedules.deserialize( + config["learning_rate"], custom_objects=custom_objects + ) + + if "weight_decay" in config: + if isinstance(config["weight_decay"], dict): + config["weight_decay"] = tf.keras.optimizers.schedules.deserialize( + config["weight_decay"], custom_objects=custom_objects + ) + + return cls(**config) + + def minimize( + self, + loss, + var_list, + grad_loss=None, + name=None, + decay_var_list=None, + tape=None, + ): + """Minimize `loss` by updating `var_list`. + + This method simply computes gradient using `tf.GradientTape` and calls + `apply_gradients()`. If you want to process the gradient before + applying then call `tf.GradientTape` and `apply_gradients()` explicitly + instead of using this function. + + Args: + loss: `Tensor` or callable. If a callable, `loss` should take no + arguments and return the value to minimize. If a `Tensor`, the + `tape` argument must be passed. + var_list: list or tuple of `Variable` objects to update to + minimize `loss`, or a callable returning the list or tuple of + `Variable` objects. Use callable when the variable list would + otherwise be incomplete before `minimize` since the variables + are created at the first time `loss` is called. + grad_loss: Optional. A `Tensor` holding the gradient computed for + `loss`. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. Note `decay_var_list` takes + priority over `exclude_from_weight_decay` if specified. + name: Optional name for the returned operation. + tape: (Optional) `tf.GradientTape`. If `loss` is provided as a + `Tensor`, the tape that computed the `loss` must be provided. + Returns: + An Operation that updates the variables in `var_list`. + Raises: + ValueError: If some of the variables are not `Variable` objects. + """ + self._set_decay_var_list(var_list, decay_var_list) + return super().minimize( + loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape + ) + + def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs): + """Apply gradients to variables. + + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. Note `decay_var_list` takes + priority over `exclude_from_weight_decay` if specified. + **kwargs: Additional arguments to pass to the base optimizer's + apply_gradient method, e.g., TF2.2 added an argument + `experimental_aggregate_gradients`. + Returns: + An `Operation` that applies the specified gradients. 
+ Raises: + TypeError: If `grads_and_vars` is malformed. + ValueError: If none of the variables have gradients. + """ + grads_and_vars = list(grads_and_vars) + self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list) + return super().apply_gradients(grads_and_vars, name=name, **kwargs) + + def _decay_weights_op(self, var, apply_state=None): + if self._do_use_weight_decay(var): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = (apply_state or {}).get( + (var_device, var_dtype) + ) or self._fallback_apply_state(var_device, var_dtype) + + return var.assign_sub(coefficients["wd_t"] * var, self._use_locking) + return tf.no_op() + + def _decay_weights_sparse_op(self, var, indices, apply_state=None): + if self._do_use_weight_decay(var): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = (apply_state or {}).get( + (var_device, var_dtype) + ) or self._fallback_apply_state(var_device, var_dtype) + + update = -coefficients["wd_t"] * tf.gather(var, indices) + return self._resource_scatter_add(var, indices, update) + return tf.no_op() + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(DecoupledWeightDecayExtension, self)._prepare_local( + var_device, var_dtype, apply_state + ) + + if "weight_decay" in self._hyper: + wd_t = tf.identity(self._decayed_wd(var_dtype)) + apply_state[(var_device, var_dtype)]["wd_t"] = wd_t + + def _decayed_wd(self, var_dtype): + wd_t = self._get_hyper("weight_decay", var_dtype) + + if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule): + wd_t = tf.cast(wd_t(self.iterations), var_dtype) + + return wd_t + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + + def _resource_apply_dense(self, grad, var, apply_state=None): + with tf.control_dependencies( + [self._decay_weights_op(var, apply_state=apply_state)] + ): + return super()._resource_apply_dense(grad, var, apply_state=apply_state) + + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + decay_op = self._decay_weights_sparse_op(var, indices, apply_state=apply_state) + with tf.control_dependencies([decay_op]): + return super()._resource_apply_sparse( + grad, var, indices, apply_state=apply_state + ) + + def _set_decay_var_list(self, var_list, decay_var_list=None): + if decay_var_list: + self._decay_var_list = set(v.ref() for v in decay_var_list) + elif self.exclude_from_weight_decay: + self._decay_var_list = set( + v.ref() + for v in var_list + if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay) + ) + else: + self._decay_var_list = None + + def _do_use_weight_decay(self, var): + """Whether to use L2 weight decay for `var`.""" + if self._decay_var_list is None: + return True + return var.ref() in self._decay_var_list + + +if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None: + keras_legacy_optimizer = Union[ + tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer + ] +else: + keras_legacy_optimizer = tf.keras.optimizers.Optimizer + + +@typechecked +def extend_with_decoupled_weight_decay( + base_optimizer: Type[keras_legacy_optimizer], +) -> Type[keras_legacy_optimizer]: + """Factory function returning an optimizer class with decoupled weight + decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. 
+ E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is + equivalent to `tfa.optimizers.AdamW`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - Optional keyword argument `exclude_from_weight_decay` accepts list of + regex patterns of variables excluded from weight decay. Variables whose + name contain a substring matching the pattern will be excluded. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + Note this takes priority over `exclude_from_weight_decay` if specified. + If both `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + # update var1, var2 but only decay var1 + optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1]) + + Note: this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of 'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + + Note: you might want to register your own custom optimizer using + `tf.keras.utils.get_custom_objects()`. + + Args: + base_optimizer: An optimizer class that inherits from + tf.optimizers.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + + class OptimizerWithDecoupledWeightDecay( + DecoupledWeightDecayExtension, base_optimizer + ): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being + decoupled from the optimization steps w.r.t. to the loss + function, as described by [Loshchilov & Hutter] + (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this + simplifies hyperparameter search since it decouples the settings + of weight decay and learning rate. For adaptive gradient + algorithms, it regularizes variables with large gradients more + than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. 
+ """ + + @typechecked + def __init__( + self, + weight_decay: Union[FloatTensorLike, Callable], + *args, + **kwargs, + ): + # super delegation is necessary here + super().__init__(weight_decay, *args, **kwargs) + + return OptimizerWithDecoupledWeightDecay + + +if hasattr(tf.keras.optimizers, "legacy"): + ADAM_CLASS = tf.keras.optimizers.legacy.Adam + SGD_CLASS = tf.keras.optimizers.legacy.SGD +else: + ADAM_CLASS = tf.keras.optimizers.Adam + SGD_CLASS = tf.keras.optimizers.SGD + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class SGDW(DecoupledWeightDecayExtension, SGD_CLASS): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Decoupled + Weight Decay Regularization" by [Loshchilov & Hutter] + (https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `tf.keras.optimizers.SGD` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the SGD Optimizer. + + This optimizer can also be instantiated as + ```python + extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.SGDW( + learning_rate=lr, weight_decay=wd, momentum=0.9) + ``` + """ + + @typechecked + def __init__( + self, + weight_decay: Union[FloatTensorLike, Callable], + learning_rate: Union[FloatTensorLike, Callable] = 0.001, + momentum: Union[FloatTensorLike, Callable] = 0.0, + nesterov: bool = False, + name: str = "SGDW", + **kwargs, + ): + """Construct a new SGDW optimizer. + + For further information see the documentation of the SGD Optimizer. + + Args: + learning_rate: float hyperparameter >= 0. Learning rate. + momentum: float hyperparameter >= 0 that accelerates SGD in the + relevant direction and dampens oscillations. + nesterov: boolean. Whether to apply Nesterov momentum. + name: Optional name prefix for the operations created when applying + gradients. Defaults to 'SGD'. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip + gradients by norm; `clipvalue` is clip gradients by value. + `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + `exclude_from_weight_decay` accepts list of regex patterns of + variables excluded from weight decay. + """ + super().__init__( + weight_decay, + learning_rate=learning_rate, + momentum=momentum, + nesterov=nesterov, + name=name, + **kwargs, + ) + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS): + """Optimizer that implements the Adam algorithm with weight decay. 
+ + This is an implementation of the AdamW optimizer described in "Decoupled + Weight Decay Regularization" by [Loshchilov & Hutter] + (https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `tf.keras.optimizers.Adam` and additionally + decays the variable. Note that this is different from adding L2 + regularization on the variables to the loss: it regularizes variables with + large gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + This optimizer can also be instantiated as + ```python + extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + """ + + @typechecked + def __init__( + self, + weight_decay: Union[FloatTensorLike, Callable], + learning_rate: Union[FloatTensorLike, Callable] = 0.001, + beta_1: Union[FloatTensorLike, Callable] = 0.9, + beta_2: Union[FloatTensorLike, Callable] = 0.999, + epsilon: FloatTensorLike = 1e-07, + amsgrad: bool = False, + name: str = "AdamW", + **kwargs, + ): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A Tensor or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning + rate. + beta_1: A float value or a constant float tensor. The exponential + decay rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential + decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the + paper. + amsgrad: boolean. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + beyond". + name: Optional name for the operations created when applying + gradients. Defaults to "AdamW". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip + gradients by norm; `clipvalue` is clip gradients by value. + `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + `exclude_from_weight_decay` accepts list of regex patterns of + variables excluded from weight decay. 
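+
+ For example, a short sketch of the `decay_var_list` behaviour described
+ above (the variables and loss below are illustrative only):
+
+ ```python
+ var1 = tf.Variable([0.5, 0.5])
+ var2 = tf.Variable([0.5, 0.5])
+ optimizer = AdamW(weight_decay=1e-4, learning_rate=1e-3)
+ loss = lambda: tf.reduce_sum(var1 ** 2 + var2 ** 2)
+ # Both variables receive an Adam update, but only `var1` is decayed.
+ optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
+ ```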
+ """ + super().__init__( + weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs, + ) diff --git a/opennmt/tfa/rnn/__init__.py b/opennmt/tfa/rnn/__init__.py new file mode 100644 index 000000000..0bde4b0d6 --- /dev/null +++ b/opennmt/tfa/rnn/__init__.py @@ -0,0 +1,2 @@ +from opennmt.tfa.rnn.abstract_rnn_cell import AbstractRNNCell +from opennmt.tfa.rnn.layer_norm_lstm_cell import LayerNormLSTMCell diff --git a/opennmt/tfa/rnn/abstract_rnn_cell.py b/opennmt/tfa/rnn/abstract_rnn_cell.py new file mode 100644 index 000000000..de5225bf7 --- /dev/null +++ b/opennmt/tfa/rnn/abstract_rnn_cell.py @@ -0,0 +1,133 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for RNN cells. + +Adapted from legacy github.com/keras-team/tf-keras. +""" + +import tensorflow as tf + + +def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): + if inputs is not None: + batch_size = tf.shape(inputs)[0] + dtype = inputs.dtype + return _generate_zero_filled_state(batch_size, cell.state_size, dtype) + + +def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): + """Generate a zero filled tensor with shape [batch_size, state_size].""" + if batch_size_tensor is None or dtype is None: + raise ValueError( + "batch_size and dtype cannot be None while constructing initial state: " + "batch_size={}, dtype={}".format(batch_size_tensor, dtype) + ) + + def create_zeros(unnested_state_size): + flat_dims = tf.TensorShape(unnested_state_size).as_list() + init_state_size = [batch_size_tensor] + flat_dims + return tf.zeros(init_state_size, dtype=dtype) + + if tf.nest.is_nested(state_size): + return tf.nest.map_structure(create_zeros, state_size) + else: + return create_zeros(state_size) + + +class AbstractRNNCell(tf.keras.layers.Layer): + """Abstract object representing an RNN cell. + + This is a base class for implementing RNN cells with custom behavior. + + Every `RNNCell` must have the properties below and implement `call` with + the signature `(output, next_state) = call(input, state)`. 
+ + Examples: + + ```python + class MinimalRNNCell(AbstractRNNCell): + + def __init__(self, units, **kwargs): + self.units = units + super(MinimalRNNCell, self).__init__(**kwargs) + + @property + def state_size(self): + return self.units + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = backend.dot(inputs, self.kernel) + output = h + backend.dot(prev_output, self.recurrent_kernel) + return output, output + ``` + + This definition of cell differs from the definition used in the literature. + In the literature, 'cell' refers to an object with a single scalar output. + This definition refers to a horizontal array of such units. + + An RNN cell, in the most abstract setting, is anything that has + a state and performs some operation that takes a matrix of inputs. + This operation results in an output matrix with `self.output_size` columns. + If `self.state_size` is an integer, this operation also results in a new + state matrix with `self.state_size` columns. If `self.state_size` is a + (possibly nested tuple of) TensorShape object(s), then it should return a + matching structure of Tensors having shape `[batch_size].concatenate(s)` + for each `s` in `self.batch_size`. + """ + + def call(self, inputs, states): + """The function that contains the logic for one RNN step calculation. + + Args: + inputs: the input tensor, which is a slide from the overall RNN input by + the time dimension (usually the second dimension). + states: the state tensor from previous step, which has the same shape + as `(batch, state_size)`. In the case of timestep 0, it will be the + initial state user specified, or zero filled tensor otherwise. + + Returns: + A tuple of two tensors: + 1. output tensor for the current timestep, with size `output_size`. + 2. state tensor for next step, which has the shape of `state_size`. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + + It can be represented by an Integer, a TensorShape or a tuple of Integers + or TensorShapes. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) diff --git a/opennmt/tfa/rnn/layer_norm_lstm_cell.py b/opennmt/tfa/rnn/layer_norm_lstm_cell.py new file mode 100644 index 000000000..5b876f554 --- /dev/null +++ b/opennmt/tfa/rnn/layer_norm_lstm_cell.py @@ -0,0 +1,210 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements LayerNormLSTM Cell.""" + +import tensorflow as tf +import tensorflow.keras as keras + +from typeguard import typechecked + +from opennmt.tfa.utils.types import ( + Activation, + Constraint, + FloatTensorLike, + Initializer, + Regularizer, + TensorLike, +) + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class LayerNormLSTMCell(keras.layers.LSTMCell): + """LSTM cell with layer normalization and recurrent dropout. + + This class adds layer normalization and recurrent dropout to a LSTM unit. + Layer normalization implementation is based on: + + https://arxiv.org/abs/1607.06450. + + "Layer Normalization" Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + and is applied before the internal nonlinearities. + Recurrent dropout is based on: + + https://arxiv.org/abs/1603.05118 + + "Recurrent Dropout without Memory Loss" + Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. + + Example: + + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> lnLSTMCell = tfa.rnn.LayerNormLSTMCell(4) + >>> rnn = tf.keras.layers.RNN(lnLSTMCell, return_sequences=True, return_state=True) + >>> outputs, memory_state, carry_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + >>> carry_state.shape + TensorShape([30, 4]) + """ + + @typechecked + def __init__( + self, + units: TensorLike, + activation: Activation = "tanh", + recurrent_activation: Activation = "sigmoid", + use_bias: bool = True, + kernel_initializer: Initializer = "glorot_uniform", + recurrent_initializer: Initializer = "orthogonal", + bias_initializer: Initializer = "zeros", + unit_forget_bias: bool = True, + kernel_regularizer: Regularizer = None, + recurrent_regularizer: Regularizer = None, + bias_regularizer: Regularizer = None, + kernel_constraint: Constraint = None, + recurrent_constraint: Constraint = None, + bias_constraint: Constraint = None, + dropout: FloatTensorLike = 0.0, + recurrent_dropout: FloatTensorLike = 0.0, + norm_gamma_initializer: Initializer = "ones", + norm_beta_initializer: Initializer = "zeros", + norm_epsilon: FloatTensorLike = 1e-3, + **kwargs, + ): + """Initializes the LSTM cell. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. Default: hyperbolic tangent + (`tanh`). If you pass `None`, no activation is applied (ie. + "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use for the recurrent + step. Default: sigmoid (`sigmoid`). If you pass `None`, no + activation is applied (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, used + for the linear transformation of the inputs. + recurrent_initializer: Initializer for the `recurrent_kernel` weights + matrix, used for the linear transformation of the recurrent state. + bias_initializer: Initializer for the bias vector. + unit_forget_bias: Boolean. If True, add 1 to the bias of the forget + gate at initialization. Setting it to true will also force + `bias_initializer="zeros"`. This is recommended in [Jozefowicz et + al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` + weights matrix. 
+ recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + kernel_constraint: Constraint function applied to the `kernel` + weights matrix. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. Fraction of the units to + drop for the linear transformation of the recurrent state. + norm_gamma_initializer: Initializer for the layer normalization gain + initial value. + norm_beta_initializer: Initializer for the layer normalization shift + initial value. + norm_epsilon: Float, the epsilon value for normalization layers. + **kwargs: Dict, the other keyword arguments for layer creation. + """ + super().__init__( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + unit_forget_bias=unit_forget_bias, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + **kwargs, + ) + self.norm_gamma_initializer = keras.initializers.get(norm_gamma_initializer) + self.norm_beta_initializer = keras.initializers.get(norm_beta_initializer) + self.norm_epsilon = norm_epsilon + self.kernel_norm = self._create_norm_layer("kernel_norm") + self.recurrent_norm = self._create_norm_layer("recurrent_norm") + self.state_norm = self._create_norm_layer("state_norm") + + def build(self, input_shape): + super().build(input_shape) + + def maybe_build_sublayer(sublayer, build_shape): + if not sublayer.built: + with tf.keras.backend.name_scope(sublayer.name): + sublayer.build(build_shape) + sublayer.built = True + + maybe_build_sublayer(self.kernel_norm, [input_shape[0], self.units * 4]) + maybe_build_sublayer(self.recurrent_norm, [input_shape[0], self.units * 4]) + maybe_build_sublayer(self.state_norm, [input_shape[0], self.units]) + + def call(self, inputs, states, training=None): + h_tm1 = states[0] # previous memory state + c_tm1 = states[1] # previous carry state + + dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(h_tm1, training, count=4) + if 0.0 < self.dropout < 1.0: + inputs *= dp_mask[0] + z = self.kernel_norm(keras.backend.dot(inputs, self.kernel)) + + if 0.0 < self.recurrent_dropout < 1.0: + h_tm1 *= rec_dp_mask[0] + z += self.recurrent_norm(keras.backend.dot(h_tm1, self.recurrent_kernel)) + if self.use_bias: + z = keras.backend.bias_add(z, self.bias) + + z = tf.split(z, num_or_size_splits=4, axis=1) + c, o = self._compute_carry_and_output_fused(z, c_tm1) + c = self.state_norm(c) + h = o * self.activation(c) + return h, [h, c] + + def get_config(self): + config = { + "norm_gamma_initializer": keras.initializers.serialize( + self.norm_gamma_initializer + ), + "norm_beta_initializer": keras.initializers.serialize( + self.norm_beta_initializer + ), + "norm_epsilon": self.norm_epsilon, + } + base_config = super().get_config() + return {**base_config, **config} + + 
def _create_norm_layer(self, name): + return keras.layers.LayerNormalization( + beta_initializer=self.norm_beta_initializer, + gamma_initializer=self.norm_gamma_initializer, + epsilon=self.norm_epsilon, + name=name, + ) diff --git a/opennmt/tfa/rnn/tests/layer_norm_lstm_cell_test.py b/opennmt/tfa/rnn/tests/layer_norm_lstm_cell_test.py new file mode 100644 index 000000000..4fe7dde2c --- /dev/null +++ b/opennmt/tfa/rnn/tests/layer_norm_lstm_cell_test.py @@ -0,0 +1,213 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LayerNormLSTM Cell.""" + +import numpy as np +import pytest +import tensorflow as tf +import tensorflow.keras as keras + +from packaging.version import Version + +from opennmt.tfa.rnn import LayerNormLSTMCell + + +def test_cell_output(): + x = tf.ones([1, 2], dtype=tf.float32) + c0 = tf.constant(0.1 * np.asarray([[0, 1]]), dtype=tf.float32) + h0 = tf.constant(0.1 * np.asarray([[2, 3]]), dtype=tf.float32) + state0 = [h0, c0] + c1 = tf.constant(0.1 * np.asarray([[4, 5]]), dtype=tf.float32) + h1 = tf.constant(0.1 * np.asarray([[6, 7]]), dtype=tf.float32) + state1 = [h1, c1] + state = (state0, state1) + const_initializer = tf.constant_initializer(0.5) + + def single_cell(): + return LayerNormLSTMCell( + units=2, + kernel_initializer=const_initializer, + recurrent_initializer=const_initializer, + bias_initializer=const_initializer, + norm_epsilon=1e-12, + ) + + cell = keras.layers.StackedRNNCells([single_cell() for _ in range(2)]) + output_v, output_states_v = cell(x, state) + + expected_output = np.array([[-0.47406167, 0.47406143]]) + expected_state0_c = np.array([[-1.0, 1.0]]) + expected_state0_h = np.array([[-0.47406167, 0.47406143]]) + expected_state1_c = np.array([[-1.0, 1.0]]) + expected_state1_h = np.array([[-0.47406167, 0.47406143]]) + + actual_state0_h = output_states_v[0][0] + actual_state0_c = output_states_v[0][1] + actual_state1_h = output_states_v[1][0] + actual_state1_c = output_states_v[1][1] + + np.testing.assert_allclose(output_v, expected_output, 1e-5) + np.testing.assert_allclose(expected_state0_c, actual_state0_c, 1e-5) + np.testing.assert_allclose(expected_state0_h, actual_state0_h, 1e-5) + np.testing.assert_allclose(expected_state1_c, actual_state1_c, 1e-5) + np.testing.assert_allclose(expected_state1_h, actual_state1_h, 1e-5) + + # Test BasicLSTMCell with input_size != num_units. 
+ x = tf.ones([1, 3], dtype=tf.float32) + c = tf.constant(0.1 * np.asarray([[0, 1]]), dtype=tf.float32) + h = tf.constant(0.1 * np.asarray([[2, 3]]), dtype=tf.float32) + state = [h, c] + cell = LayerNormLSTMCell( + units=2, + kernel_initializer=const_initializer, + recurrent_initializer=const_initializer, + bias_initializer=const_initializer, + norm_epsilon=1e-12, + ) + output_v, output_states_v = cell(x, state) + expected_h = np.array([[-0.47406167, 0.47406143]]) + expected_c = np.array([[-1.0, 1.0]]) + np.testing.assert_allclose(output_v, expected_h, 1e-5) + np.testing.assert_allclose(output_states_v[0], expected_h, 1e-5) + np.testing.assert_allclose(output_states_v[1], expected_c, 1e-5) + + +@pytest.mark.skipif( + Version(tf.__version__) < Version("2.13"), + reason="TF2.13 Serialization method doesn't support legacy method on parent class", +) +def test_config_layer_norm_legacy(): + cell = LayerNormLSTMCell(10, name="layer_norm_lstm_cell_3") + + expected_config = { + "dtype": "float32", + "name": "layer_norm_lstm_cell_3", + "trainable": True, + "units": 10, + "activation": "tanh", + "recurrent_activation": "sigmoid", + "use_bias": True, + "kernel_initializer": { + "class_name": "GlorotUniform", + "config": {"seed": None}, + "module": "keras.initializers", + "registered_name": None, + }, + "recurrent_initializer": { + "class_name": "Orthogonal", + "config": {"seed": None, "gain": 1.0}, + "module": "keras.initializers", + "registered_name": None, + }, + "bias_initializer": { + "class_name": "Zeros", + "config": {}, + "module": "keras.initializers", + "registered_name": None, + }, + "unit_forget_bias": True, + "kernel_regularizer": None, + "recurrent_regularizer": None, + "bias_regularizer": None, + "kernel_constraint": None, + "recurrent_constraint": None, + "bias_constraint": None, + "dropout": 0.0, + "recurrent_dropout": 0.0, + "implementation": 2, + "norm_gamma_initializer": { + "class_name": "Ones", + "config": {}, + "module": "keras.initializers", + "registered_name": None, + }, + "norm_beta_initializer": { + "class_name": "Zeros", + "config": {}, + "module": "keras.initializers", + "registered_name": None, + }, + "norm_epsilon": 1e-3, + } + config = cell.get_config() + assert config == expected_config + + restored_cell = LayerNormLSTMCell.from_config(config) + restored_config = restored_cell.get_config() + assert config == restored_config + + +@pytest.mark.skipif( + Version(tf.__version__) >= Version("2.13"), + reason="TF2.13 Serialization method doesn't support legacy method on parent class", +) +def test_config_layer_norm(): + cell = LayerNormLSTMCell(10, name="layer_norm_lstm_cell_3") + + expected_config = { + "dtype": "float32", + "name": "layer_norm_lstm_cell_3", + "trainable": True, + "units": 10, + "activation": "tanh", + "recurrent_activation": "sigmoid", + "use_bias": True, + "kernel_initializer": { + "class_name": "GlorotUniform", + "config": {"seed": None}, + }, + "recurrent_initializer": { + "class_name": "Orthogonal", + "config": {"seed": None, "gain": 1.0}, + }, + "bias_initializer": {"class_name": "Zeros", "config": {}}, + "unit_forget_bias": True, + "kernel_regularizer": None, + "recurrent_regularizer": None, + "bias_regularizer": None, + "kernel_constraint": None, + "recurrent_constraint": None, + "bias_constraint": None, + "dropout": 0.0, + "recurrent_dropout": 0.0, + "implementation": 2, + "norm_gamma_initializer": {"class_name": "Ones", "config": {}}, + "norm_beta_initializer": {"class_name": "Zeros", "config": {}}, + "norm_epsilon": 1e-3, + } + config = 
cell.get_config() + assert config == expected_config + + restored_cell = LayerNormLSTMCell.from_config(config) + restored_config = restored_cell.get_config() + assert config == restored_config + + +def test_build(): + cell = LayerNormLSTMCell(10, name="layer_norm_lstm_cell") + cell( + inputs=tf.ones((12, 20)), + states=cell.get_initial_state(batch_size=12, dtype=tf.float32), + ) + assert len(cell.weights) == 9 + assert cell.weights[0].name == "layer_norm_lstm_cell/kernel:0" + assert cell.weights[1].name == "layer_norm_lstm_cell/recurrent_kernel:0" + assert cell.weights[2].name == "layer_norm_lstm_cell/bias:0" + assert cell.weights[3].name == "layer_norm_lstm_cell/kernel_norm/gamma:0" + assert cell.weights[4].name == "layer_norm_lstm_cell/kernel_norm/beta:0" + assert cell.weights[5].name == "layer_norm_lstm_cell/recurrent_norm/gamma:0" + assert cell.weights[6].name == "layer_norm_lstm_cell/recurrent_norm/beta:0" + assert cell.weights[7].name == "layer_norm_lstm_cell/state_norm/gamma:0" + assert cell.weights[8].name == "layer_norm_lstm_cell/state_norm/beta:0" diff --git a/opennmt/tfa/seq2seq/__init__.py b/opennmt/tfa/seq2seq/__init__.py new file mode 100644 index 000000000..6a46fbbfd --- /dev/null +++ b/opennmt/tfa/seq2seq/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional layers for sequence to sequence models.""" + +from opennmt.tfa.seq2seq.attention_wrapper import ( + AttentionMechanism, + AttentionWrapper, + AttentionWrapperState, + LuongAttention, + LuongMonotonicAttention, + hardmax, + monotonic_attention, + safe_cumprod, +) +from opennmt.tfa.seq2seq.beam_search_decoder import tile_batch +from opennmt.tfa.seq2seq.decoder import BaseDecoder, Decoder, dynamic_decode +from opennmt.tfa.seq2seq.loss import SequenceLoss, sequence_loss +from opennmt.tfa.seq2seq.sampler import Sampler diff --git a/opennmt/tfa/seq2seq/attention_wrapper.py b/opennmt/tfa/seq2seq/attention_wrapper.py new file mode 100644 index 000000000..8c344a434 --- /dev/null +++ b/opennmt/tfa/seq2seq/attention_wrapper.py @@ -0,0 +1,1739 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A dynamic attention wrapper for RNN cells.""" + +import collections +import functools +import math + +from typing import Callable, List, Optional, Union + +import numpy as np +import tensorflow as tf + +from packaging.version import Version +from typeguard import typechecked + +from opennmt.tfa.rnn.abstract_rnn_cell import AbstractRNNCell +from opennmt.tfa.utils import keras_utils +from opennmt.tfa.utils.types import ( + AcceptableDTypes, + FloatTensorLike, + Initializer, + Number, + TensorLike, +) + +if Version(tf.__version__) < Version("2.13"): + SERIALIZATION_ARGS = {} +else: + SERIALIZATION_ARGS = {"use_legacy_format": True} + + +class AttentionMechanism(tf.keras.layers.Layer): + """Base class for attention mechanisms. + + Common functionality includes: + 1. Storing the query and memory layers. + 2. Preprocessing and storing the memory. + + Note that this layer takes memory as its init parameter, which is an + anti-pattern of Keras API, we have to keep the memory as init parameter for + performance and dependency reason. Under the hood, during `__init__()`, it + will invoke `base_layer.__call__(memory, setup_memory=True)`. This will let + keras to keep track of the memory tensor as the input of this layer. Once + the `__init__()` is done, then user can query the attention by + `score = att_obj([query, state])`, and use it as a normal keras layer. + + Special attention is needed when adding using this class as the base layer + for new attention: + 1. Build() could be invoked at least twice. So please make sure weights + are not duplicated. + 2. Layer.get_weights() might return different set of weights if the + instance has `query_layer`. The query_layer weights is not initialized + until the memory is configured. + + Also note that this layer does not work with Keras model when + `model.compile(run_eagerly=True)` due to the fact that this layer is + stateful. The support for that will be added in a future version. + """ + + @typechecked + def __init__( + self, + memory: Union[TensorLike, None], + probability_fn: callable, + query_layer: Optional[tf.keras.layers.Layer] = None, + memory_layer: Optional[tf.keras.layers.Layer] = None, + memory_sequence_length: Optional[TensorLike] = None, + **kwargs, + ): + """Construct base AttentionMechanism class. + + Args: + memory: The memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, ...]`. + probability_fn: A `callable`. Converts the score and previous + alignments to probabilities. Its signature should be: + `probabilities = probability_fn(score, state)`. + query_layer: Optional `tf.keras.layers.Layer` instance. The layer's + depth must match the depth of `memory_layer`. If `query_layer` is + not provided, the shape of `query` must match that of + `memory_layer`. + memory_layer: Optional `tf.keras.layers.Layer` instance. The layer's + depth must match the depth of `query_layer`. + If `memory_layer` is not provided, the shape of `memory` must match + that of `query_layer`. + memory_sequence_length: (optional) Sequence lengths for the batch + entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + **kwargs: Dictionary that contains other common arguments for layer + creation. 
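+
+ For example, the setup-then-query flow described above (memory passed
+ at construction, then `score = attention([query, state])`) looks roughly
+ like this with the `LuongAttention` subclass defined later in this file;
+ shapes and values are illustrative only:
+
+ ```python
+ memory = tf.random.normal([2, 7, 16])
+ attention = LuongAttention(units=16, memory=memory,
+ memory_sequence_length=[5, 7])
+ query = tf.random.normal([2, 16])
+ state = attention.initial_state(batch_size=2, dtype=tf.float32)
+ # Returns alignments of shape [2, 7] and the next attention state.
+ alignments, next_state = attention([query, state])
+ ```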
+ """ + self.query_layer = query_layer + self.memory_layer = memory_layer + super().__init__(**kwargs) + self.default_probability_fn = probability_fn + self.probability_fn = probability_fn + + self.keys = None + self.values = None + self.batch_size = None + self._memory_initialized = False + self._check_inner_dims_defined = True + self.supports_masking = True + + if memory is not None: + # Setup the memory by self.__call__() with memory and + # memory_seq_length. This will make the attention follow the keras + # convention which takes all the tensor inputs via __call__(). + if memory_sequence_length is None: + inputs = memory + else: + inputs = [memory, memory_sequence_length] + + self.values = super().__call__(inputs, setup_memory=True) + + @property + def memory_initialized(self): + """Returns `True` if this attention mechanism has been initialized with + a memory.""" + return self._memory_initialized + + def build(self, input_shape): + if not self._memory_initialized: + # This is for setting up the memory, which contains memory and + # optional memory_sequence_length. Build the memory_layer with + # memory shape. + if self.memory_layer is not None and not self.memory_layer.built: + if isinstance(input_shape, list): + self.memory_layer.build(input_shape[0]) + else: + self.memory_layer.build(input_shape) + else: + # The input_shape should be query.shape and state.shape. Use the + # query to init the query layer. + if self.query_layer is not None and not self.query_layer.built: + self.query_layer.build(input_shape[0]) + + def __call__(self, inputs, **kwargs): + """Preprocess the inputs before calling `base_layer.__call__()`. + + Note that there are situation here, one for setup memory, and one with + actual query and state. + 1. When the memory has not been configured, we just pass all the param + to `base_layer.__call__()`, which will then invoke `self.call()` with + proper inputs, which allows this class to setup memory. + 2. When the memory has already been setup, the input should contain + query and state, and optionally processed memory. If the processed + memory is not included in the input, we will have to append it to + the inputs and give it to the `base_layer.__call__()`. The processed + memory is the output of first invocation of `self.__call__()`. If we + don't add it here, then from keras perspective, the graph is + disconnected since the output from previous call is never used. + + Args: + inputs: the inputs tensors. + **kwargs: dict, other keyeword arguments for the `__call__()` + """ + # Allow manual memory reset + if kwargs.get("setup_memory", False): + self._memory_initialized = False + + if self._memory_initialized: + if len(inputs) not in (2, 3): + raise ValueError( + "Expect the inputs to have 2 or 3 tensors, got %d" % len(inputs) + ) + if len(inputs) == 2: + # We append the calculated memory here so that the graph will be + # connected. + inputs.append(self.values) + + return super().__call__(inputs, **kwargs) + + def call(self, inputs, mask=None, setup_memory=False, **kwargs): + """Setup the memory or query the attention. + + There are two case here, one for setup memory, and the second is query + the attention score. `setup_memory` is the flag to indicate which mode + it is. The input list will be treated differently based on that flag. + + Args: + inputs: a list of tensor that could either be `query` and `state`, or + `memory` and `memory_sequence_length`. + `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. 
+ `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN + encoder. The tensor should be shaped `[batch_size, max_time, ...]`. + `memory_sequence_length` (optional) is the sequence lengths for the + batch entries in memory. If provided, the memory tensor rows are + masked with zeros for values past the respective sequence lengths. + mask: optional bool tensor with shape `[batch, max_time]` for the + mask of memory. If it is not None, the corresponding item of the + memory should be filtered out during calculation. + setup_memory: boolean, whether the input is for setting up memory, or + query attention. + **kwargs: Dict, other keyword arguments for the call method. + Returns: + Either processed memory or attention score, based on `setup_memory`. + """ + if setup_memory: + if isinstance(inputs, list): + if len(inputs) not in (1, 2): + raise ValueError( + "Expect inputs to have 1 or 2 tensors, got %d" % len(inputs) + ) + memory = inputs[0] + memory_sequence_length = inputs[1] if len(inputs) == 2 else None + memory_mask = mask + else: + memory, memory_sequence_length = inputs, None + memory_mask = mask + self.setup_memory(memory, memory_sequence_length, memory_mask) + # We force the self.built to false here since only memory is, + # initialized but the real query/state has not been call() yet. The + # layer should be build and call again. + self.built = False + # Return the processed memory in order to create the Keras + # connectivity data for it. + return self.values + else: + if not self._memory_initialized: + raise ValueError( + "Cannot query the attention before the setup of memory" + ) + if len(inputs) not in (2, 3): + raise ValueError( + "Expect the inputs to have query, state, and optional " + "processed memory, got %d items" % len(inputs) + ) + # Ignore the rest of the inputs and only care about the query and + # state + query, state = inputs[0], inputs[1] + return self._calculate_attention(query, state) + + def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None): + """Pre-process the memory before actually query the memory. + + This should only be called once at the first invocation of `call()`. + + Args: + memory: The memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length (optional): Sequence lengths for the batch + entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + memory_mask: (Optional) The boolean tensor with shape `[batch_size, + max_time]`. For any value equal to False, the corresponding value + in memory should be ignored. + """ + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError( + "memory_sequence_length and memory_mask cannot be " + "used at same time for attention." + ) + with tf.name_scope(self.name or "BaseAttentionMechanismInit"): + self.values = _prepare_memory( + memory, + memory_sequence_length=memory_sequence_length, + memory_mask=memory_mask, + check_inner_dims_defined=self._check_inner_dims_defined, + ) + # Mark the value as check since the memory and memory mask might not + # passed from __call__(), which does not have proper keras metadata. + # TODO(omalleyt12): Remove this hack once the mask the has proper + # keras history. 
+ + def _mark_checked(tensor): + tensor._keras_history_checked = True # pylint: disable=protected-access + + tf.nest.map_structure(_mark_checked, self.values) + if self.memory_layer is not None: + self.keys = self.memory_layer(self.values) + else: + self.keys = self.values + self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0] + self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1] + if memory_mask is not None or memory_sequence_length is not None: + unwrapped_probability_fn = self.default_probability_fn + + def _mask_probability_fn(score, prev): + return unwrapped_probability_fn( + _maybe_mask_score( + score, + memory_mask=memory_mask, + memory_sequence_length=memory_sequence_length, + score_mask_value=score.dtype.min, + ), + prev, + ) + + self.probability_fn = _mask_probability_fn + self._memory_initialized = True + + def _calculate_attention(self, query, state): + raise NotImplementedError( + "_calculate_attention need to be implemented by subclasses." + ) + + def compute_mask(self, inputs, mask=None): + # There real input of the attention is query and state, and the memory + # layer mask shouldn't be pass down. Returning None for all output mask + # here. + return None, None + + def get_config(self): + config = {} + # Since the probability_fn is likely to be a wrapped function, the child + # class should preserve the original function and how its wrapped. + + if self.query_layer is not None: + config["query_layer"] = { + "class_name": self.query_layer.__class__.__name__, + "config": self.query_layer.get_config(), + } + if self.memory_layer is not None: + config["memory_layer"] = { + "class_name": self.memory_layer.__class__.__name__, + "config": self.memory_layer.get_config(), + } + # memory is a required init parameter and its a tensor. It cannot be + # serialized to config, so we put a placeholder for it. + config["memory"] = None + base_config = super().get_config() + return {**base_config, **config} + + def _process_probability_fn(self, func_name): + """Helper method to retrieve the probably function by string input.""" + valid_probability_fns = { + "softmax": tf.nn.softmax, + "hardmax": hardmax, + } + if func_name not in valid_probability_fns.keys(): + raise ValueError( + "Invalid probability function: %s, options are %s" + % (func_name, valid_probability_fns.keys()) + ) + return valid_probability_fns[func_name] + + @classmethod + def deserialize_inner_layer_from_config(cls, config, custom_objects): + """Helper method that reconstruct the query and memory from the config. + + In the get_config() method, the query and memory layer configs are + serialized into dict for persistence, this method perform the reverse + action to reconstruct the layer from the config. + + Args: + config: dict, the configs that will be used to reconstruct the + object. + custom_objects: dict mapping class names (or function names) of + custom (non-Keras) objects to class/functions. + Returns: + config: dict, the config with layer instance created, which is ready + to be used as init parameters. + """ + # Reconstruct the query and memory layer for parent class. + # Instead of updating the input, create a copy and use that. 
+ config = config.copy() + query_layer_config = config.pop("query_layer", None) + if query_layer_config: + query_layer = tf.keras.layers.deserialize( + query_layer_config, + custom_objects=custom_objects, + **SERIALIZATION_ARGS, + ) + config["query_layer"] = query_layer + memory_layer_config = config.pop("memory_layer", None) + if memory_layer_config: + memory_layer = tf.keras.layers.deserialize( + memory_layer_config, + custom_objects=custom_objects, + **SERIALIZATION_ARGS, + ) + config["memory_layer"] = memory_layer + return config + + @property + def alignments_size(self): + if isinstance(self._alignments_size, int): + return self._alignments_size + else: + return tf.TensorShape([None]) + + @property + def state_size(self): + return self.alignments_size + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the `tfa.seq2seq.AttentionWrapper` + class. + + This is important for attention mechanisms that use the previous + alignment to calculate the alignment at the next time step + (e.g. monotonic attention). + + The default behavior is to return a tensor of all zeros. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + return tf.zeros([batch_size, self._alignments_size], dtype=dtype) + + def initial_state(self, batch_size, dtype): + """Creates the initial state values for the `tfa.seq2seq.AttentionWrapper` class. + + This is important for attention mechanisms that use the previous + alignment to calculate the alignment at the next time step + (e.g. monotonic attention). + + The default behavior is to return the same output as + `initial_alignments`. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A structure of all-zero tensors with shapes as described by + `state_size`. + """ + return self.initial_alignments(batch_size, dtype) + + +def _luong_score(query, keys, scale): + """Implements Luong-style (multiplicative) scoring function. + + This attention has two forms. The first is standard Luong attention, + as described in: + + Minh-Thang Luong, Hieu Pham, Christopher D. Manning. + "Effective Approaches to Attention-based Neural Machine Translation." + EMNLP 2015. https://arxiv.org/abs/1508.04025 + + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. + + To enable the second form, call this function with `scale=True`. + + Args: + query: Tensor, shape `[batch_size, num_units]` to compare to keys. + keys: Processed memory, shape `[batch_size, max_time, num_units]`. + scale: the optional tensor to scale the attention score. + + Returns: + A `[batch_size, max_time]` tensor of unnormalized score values. + + Raises: + ValueError: If `key` and `query` depths do not match. + """ + depth = query.shape[-1] + key_units = keys.shape[-1] + if depth != key_units: + raise ValueError( + "Incompatible or unknown inner dimensions between query and keys. " + "Query (%s) has units: %s. Keys (%s) have units: %s. " + "Perhaps you need to set num_units to the keys' dimension (%s)?" + % (query, depth, keys, key_units, key_units) + ) + + # Reshape from [batch_size, depth] to [batch_size, 1, depth] + # for matmul. + query = tf.expand_dims(query, 1) + + # Inner product along the query units dimension. + # matmul shapes: query is [batch_size, 1, depth] and + # keys is [batch_size, max_time, depth]. 
+ # the inner product is asked to **transpose keys' inner shape** to get a + # batched matmul on: + # [batch_size, 1, depth] . [batch_size, depth, max_time] + # resulting in an output shape of: + # [batch_size, 1, max_time]. + # we then squeeze out the center singleton dimension. + score = tf.matmul(query, keys, transpose_b=True) + score = tf.squeeze(score, [1]) + + if scale is not None: + score = scale * score + return score + + +class LuongAttention(AttentionMechanism): + """Implements Luong-style (multiplicative) attention scoring. + + This attention has two forms. The first is standard Luong attention, + as described in: + + Minh-Thang Luong, Hieu Pham, Christopher D. Manning. + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) + + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. + + To enable the second form, construct the object with parameter + `scale=True`. + """ + + @typechecked + def __init__( + self, + units: TensorLike, + memory: Optional[TensorLike] = None, + memory_sequence_length: Optional[TensorLike] = None, + scale: bool = False, + probability_fn: str = "softmax", + dtype: AcceptableDTypes = None, + name: str = "LuongAttention", + **kwargs, + ): + """Construct the AttentionMechanism mechanism. + + Args: + units: The depth of the attention mechanism. + memory: The memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch + entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + scale: Python boolean. Whether to scale the energy term. + probability_fn: (optional) string, the name of function to convert + the attention score to probabilities. The default is `softmax` + which is `tf.nn.softmax`. Other options is `hardmax`, which is + hardmax() within this module. Any other value will result + intovalidation error. Default to use `softmax`. + dtype: The data type for the memory layer of the attention mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # For LuongAttention, we only transform the memory layer; thus + # num_units **must** match expected the query depth. + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + + def wrapped_probability_fn(score, _): + return probability_fn(score) + + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = tf.keras.layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype + ) + self.units = units + self.scale = scale + self.scale_weight = None + super().__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs, + ) + + def build(self, input_shape): + super().build(input_shape) + if self.scale and self.scale_weight is None: + self.scale_weight = self.add_weight( + "attention_g", initializer=tf.ones_initializer, shape=() + ) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. 
+ state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as the alignments. + """ + score = _luong_score(query, self.keys, self.scale_weight) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "probability_fn": self.probability_fn_name, + } + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + config = AttentionMechanism.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects + ) + return cls(**config) + + +def _bahdanau_score( + processed_query, keys, attention_v, attention_g=None, attention_b=None +): + """Implements Bahdanau-style (additive) scoring function. + + This attention has two forms. The first is Bahdanau attention, + as described in: + + Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. + "Neural Machine Translation by Jointly Learning to Align and Translate." + ICLR 2015. https://arxiv.org/abs/1409.0473 + + The second is the normalized form. This form is inspired by the + weight normalization article: + + Tim Salimans, Diederik P. Kingma. + "Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks." + https://arxiv.org/abs/1602.07868 + + To enable the second form, set please pass in attention_g and attention_b. + + Args: + processed_query: Tensor, shape `[batch_size, num_units]` to compare to + keys. + keys: Processed memory, shape `[batch_size, max_time, num_units]`. + attention_v: Tensor, shape `[num_units]`. + attention_g: Optional scalar tensor for normalization. + attention_b: Optional tensor with shape `[num_units]` for normalization. + + Returns: + A `[batch_size, max_time]` tensor of unnormalized score values. + """ + # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. + processed_query = tf.expand_dims(processed_query, 1) + if attention_g is not None and attention_b is not None: + normed_v = ( + attention_g + * attention_v + * tf.math.rsqrt(tf.reduce_sum(tf.square(attention_v))) + ) + return tf.reduce_sum( + normed_v * tf.tanh(keys + processed_query + attention_b), [2] + ) + else: + return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query), [2]) + + +def safe_cumprod(x: TensorLike, *args, **kwargs) -> tf.Tensor: + """Computes cumprod of x in logspace using cumsum to avoid underflow. + + The cumprod function and its gradient can result in numerical instabilities + when its argument has very small and/or zero values. As long as the + argument is all positive, we can instead compute the cumulative product as + exp(cumsum(log(x))). This function can be called identically to + tf.cumprod. + + Args: + x: Tensor to take the cumulative product of. + *args: Passed on to cumsum; these are identical to those in cumprod. + **kwargs: Passed on to cumsum; these are identical to those in cumprod. + Returns: + Cumulative product of x. 
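+
+ For example (illustrative values; for inputs in (0, 1] this matches
+ `tf.math.cumprod` while staying numerically stable for very small entries):
+
+ ```python
+ x = tf.constant([0.9, 1e-30, 0.5])
+ # Computed as exp(cumsum(log(x))), avoiding underflow in the gradient.
+ safe_cumprod(x)
+ ```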
+ """ + with tf.name_scope("SafeCumprod"): + x = tf.convert_to_tensor(x, name="x") + tiny = np.finfo(x.dtype.as_numpy_dtype).tiny + return tf.exp( + tf.cumsum(tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs) + ) + + +def monotonic_attention( + p_choose_i: FloatTensorLike, previous_attention: FloatTensorLike, mode: str +) -> tf.Tensor: + """Computes monotonic attention distribution from choosing probabilities. + + Monotonic attention implies that the input sequence is processed in an + explicitly left-to-right manner when generating the output sequence. In + addition, once an input sequence element is attended to at a given output + timestep, elements occurring before it cannot be attended to at subsequent + output timesteps. This function generates attention distributions + according to these assumptions. For more information, see `Online and + Linear-Time Attention by Enforcing Monotonic Alignments`. + + Args: + p_choose_i: Probability of choosing input sequence/memory element i. + Should be of shape (batch_size, input_sequence_length), and should all + be in the range [0, 1]. + previous_attention: The attention distribution from the previous output + timestep. Should be of shape (batch_size, input_sequence_length). For + the first output timestep, preevious_attention[n] should be + [1, 0, 0, ..., 0] for all n in [0, ... batch_size - 1]. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. + * 'recursive' uses tf.scan to recursively compute the distribution. + This is slowest but is exact, general, and does not suffer from + numerical instabilities. + * 'parallel' uses parallelized cumulative-sum and cumulative-product + operations to compute a closed-form solution to the recurrence + relation defining the attention distribution. This makes it more + efficient than 'recursive', but it requires numerical checks which + make the distribution non-exact. This can be a problem in + particular when input_sequence_length is long and/or p_choose_i has + entries very close to 0 or 1. + * 'hard' requires that the probabilities in p_choose_i are all either + 0 or 1, and subsequently uses a more efficient and exact solution. + + Returns: + A tensor of shape (batch_size, input_sequence_length) representing the + attention distributions for each sequence in the batch. + + Raises: + ValueError: mode is not one of 'recursive', 'parallel', 'hard'. 
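+
+ For example, a single 'parallel'-mode step with made-up probabilities
+ (batch of one, four memory entries):
+
+ ```python
+ p_choose = tf.constant([[0.1, 0.9, 0.8, 0.2]])
+ # First output step: dirac distribution over the first memory entry.
+ previous = tf.constant([[1.0, 0.0, 0.0, 0.0]])
+ attention = monotonic_attention(p_choose, previous, mode="parallel")
+ ```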
+ """
+ # Force things to be tensors
+ p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i")
+ previous_attention = tf.convert_to_tensor(
+ previous_attention, name="previous_attention"
+ )
+ if mode == "recursive":
+ # Use .shape[0] when it's not None, or fall back on symbolic shape
+ batch_size = p_choose_i.shape[0] or tf.shape(p_choose_i)[0]
+ # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_
+ # i[-2]]
+ shifted_1mp_choose_i = tf.concat(
+ [tf.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1
+ )
+ # Compute attention distribution recursively as
+ # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
+ # attention[i] = p_choose_i[i]*q[i]
+ attention = p_choose_i * tf.transpose(
+ tf.scan(
+ # Need to use reshape to remind TF of the shape between loop
+ # iterations
+ lambda x, yz: tf.reshape(yz[0] * x + yz[1], (batch_size,)),
+ # Loop variables yz[0] and yz[1]
+ [tf.transpose(shifted_1mp_choose_i), tf.transpose(previous_attention)],
+ # Initial value of x is just zeros
+ tf.zeros((batch_size,)),
+ )
+ )
+ elif mode == "parallel":
+ # safe_cumprod computes cumprod in logspace with numeric checks
+ cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
+ # Compute recurrence relation solution
+ attention = (
+ p_choose_i
+ * cumprod_1mp_choose_i
+ * tf.cumsum(
+ previous_attention
+ # Clip cumprod_1mp to avoid divide-by-zero
+ / tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.0),
+ axis=1,
+ )
+ )
+ elif mode == "hard":
+ # Remove any probabilities before the index chosen last time step
+ p_choose_i *= tf.cumsum(previous_attention, axis=1)
+ # Now, use exclusive cumprod to remove probabilities after the first
+ # chosen index, like so:
+ # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1]
+ # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
+ # Product of above: [0, 0, 0, 1, 0, 0, 0, 0]
+ attention = p_choose_i * tf.math.cumprod(1 - p_choose_i, axis=1, exclusive=True)
+ else:
+ raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
+ return attention
+
+
+def _monotonic_probability_fn(
+ score, previous_alignments, sigmoid_noise, mode, seed=None
+):
+ """Attention probability function for monotonic attention.
+
+ Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
+ the model to make discrete attention decisions, passes them through a
+ sigmoid to obtain "choosing" probabilities, and then calls
+ monotonic_attention to obtain the attention distribution. For more
+ information, see
+
+ Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
+ "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
+ ICML 2017. https://arxiv.org/abs/1704.00784
+
+ Args:
+ score: Unnormalized attention scores, shape
+ `[batch_size, alignments_size]`
+ previous_alignments: Previous attention distribution, shape
+ `[batch_size, alignments_size]`
+ sigmoid_noise: Standard deviation of pre-sigmoid noise. Setting this
+ larger than 0 will encourage the model to produce large attention
+ scores, effectively making the choosing probabilities discrete and the
+ resulting attention distribution one-hot. It should be set to 0 at
+ test-time, and when hard attention is not desired.
+ mode: How to compute the attention distribution. Must be one of
+ 'recursive', 'parallel', or 'hard'. See the docstring for
+ `tfa.seq2seq.monotonic_attention` for more information.
+ seed: (optional) Random seed for pre-sigmoid noise.
+ + Returns: + A `[batch_size, alignments_size]`-shape tensor corresponding to the + resulting attention distribution. + """ + # Optionally add pre-sigmoid noise to the scores + if sigmoid_noise > 0: + noise = tf.random.normal(tf.shape(score), dtype=score.dtype, seed=seed) + score += sigmoid_noise * noise + # Compute "choosing" probabilities from the attention scores + if mode == "hard": + # When mode is hard, use a hard sigmoid + p_choose_i = tf.cast(score > 0, score.dtype) + else: + p_choose_i = tf.sigmoid(score) + # Convert from choosing probabilities to attention distribution + return monotonic_attention(p_choose_i, previous_alignments, mode) + + +class _BaseMonotonicAttentionMechanism(AttentionMechanism): + """Base attention mechanism for monotonic attention. + + Simply overrides the initial_alignments function to provide a dirac + distribution, which is needed in order for the monotonic attention + distributions to have the correct behavior. + """ + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the monotonic attentions. + + Initializes to dirac distributions, i.e. + [1, 0, 0, ...memory length..., 0] for all entries in the batch. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + max_time = self._alignments_size + return tf.one_hot( + tf.zeros((batch_size,), dtype=tf.int32), max_time, dtype=dtype + ) + + +class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): + """Monotonic attention mechanism with Luong-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the + memory it can't attend to any prior points at subsequence output timesteps. + It achieves this by using the `_monotonic_probability_fn` instead of `softmax` + to construct its attention distributions. Otherwise, it is equivalent to + `tfa.seq2seq.LuongAttention`. This approach is proposed in + + [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017.](https://arxiv.org/abs/1704.00784) + """ + + @typechecked + def __init__( + self, + units: TensorLike, + memory: Optional[TensorLike] = None, + memory_sequence_length: Optional[TensorLike] = None, + scale: bool = False, + sigmoid_noise: FloatTensorLike = 0.0, + sigmoid_noise_seed: Optional[FloatTensorLike] = None, + score_bias_init: FloatTensorLike = 0.0, + mode: str = "parallel", + dtype: AcceptableDTypes = None, + name: str = "LuongMonotonicAttention", + **kwargs, + ): + """Construct the attention mechanism. + + Args: + units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch + entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + scale: Python boolean. Whether to scale the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the + docstring for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. 
It's + recommended to initialize this to a negative value when the length + of the memory is large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tfa.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # Set up the monotonic probability fn with supplied parameters + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, + sigmoid_noise=sigmoid_noise, + mode=mode, + seed=sigmoid_noise_seed, + ) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = tf.keras.layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype + ) + self.units = units + self.scale = scale + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + self.attention_g = None + self.attention_score_bias = None + super().__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs, + ) + + def build(self, input_shape): + super().build(input_shape) + if self.scale and self.attention_g is None: + self.attention_g = self.add_weight( + "attention_g", initializer=tf.ones_initializer, shape=() + ) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", + shape=(), + initializer=tf.constant_initializer(self.score_bias_init), + ) + self.built = True + + def _calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as alignments + """ + score = _luong_score(query, self.keys, self.attention_g) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + config = AttentionMechanism.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects + ) + return cls(**config) + + +class AttentionWrapperState( + collections.namedtuple( + "AttentionWrapperState", + ( + "cell_state", + "attention", + "alignments", + "alignment_history", + "attention_state", + ), + ) +): + """State of a `tfa.seq2seq.AttentionWrapper`. + + Attributes: + cell_state: The state of the wrapped RNN cell at the previous time + step. + attention: The attention emitted at the previous time step. 
+ alignments: A single or tuple of `Tensor`(s) containing the + alignments emitted at the previous time step for each attention + mechanism. + alignment_history: (if enabled) a single or tuple of `TensorArray`(s) + containing alignment matrices from all time steps for each attention + mechanism. Call `stack()` on each to convert to a `Tensor`. + attention_state: A single or tuple of nested objects + containing attention mechanism state for each attention mechanism. + The objects may contain Tensors or TensorArrays. + """ + + def clone(self, **kwargs): + """Clone this object, overriding components provided by kwargs. + + The new state fields' shape must match original state fields' shape. + This will be validated, and original fields' shape will be propagated + to new fields. + + Example: + + >>> batch_size = 1 + >>> memory = tf.random.normal(shape=[batch_size, 3, 100]) + >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))] + >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, + memory_sequence_length=[3] * batch_size) + >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), + attention_mechanism, attention_layer_size=10) + >>> decoder_initial_state = + attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32) + >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) + + Args: + **kwargs: Any properties of the state object to replace in the + returned `AttentionWrapperState`. + + Returns: + A new `AttentionWrapperState` whose properties are the same as + this one, except any overridden properties as provided in `kwargs`. + """ + + def with_same_shape(old, new): + """Check and set new tensor's shape.""" + if isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor): + if not tf.executing_eagerly(): + new_shape = tf.shape(new) + old_shape = tf.shape(old) + assert_equal = tf.debugging.assert_equal(new_shape, old_shape) + with tf.control_dependencies([assert_equal]): + # Add an identity op so that control deps can kick in. + return tf.identity(new) + else: + if old.shape.as_list() != new.shape.as_list(): + raise ValueError( + "The shape of the AttentionWrapperState is " + "expected to be same as the one to clone. " + "self.shape: %s, input.shape: %s" % (old.shape, new.shape) + ) + return new + return new + + return tf.nest.map_structure(with_same_shape, self, super()._replace(**kwargs)) + + +def _prepare_memory( + memory, memory_sequence_length=None, memory_mask=None, check_inner_dims_defined=True +): + """Convert to tensor and possibly mask `memory`. + + Args: + memory: `Tensor`, shaped `[batch_size, max_time, ...]`. + memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. + memory_mask: `boolean` tensor with shape [batch_size, max_time]. The + memory should be skipped when the corresponding mask is False. + check_inner_dims_defined: Python boolean. If `True`, the `memory` + argument's shape is checked to ensure all but the two outermost + dimensions are fully defined. + + Returns: + A (possibly masked), checked, new `memory`. + + Raises: + ValueError: If `check_inner_dims_defined` is `True` and not + `memory.shape[2:].is_fully_defined()`. + """ + memory = tf.nest.map_structure( + lambda m: tf.convert_to_tensor(m, name="memory"), memory + ) + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError( + "memory_sequence_length and memory_mask can't be provided at same time." 
+ ) + if memory_sequence_length is not None: + memory_sequence_length = tf.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length" + ) + if check_inner_dims_defined: + + def _check_dims(m): + if not m.shape[2:].is_fully_defined(): + raise ValueError( + "Expected memory %s to have fully defined inner dims, " + "but saw shape: %s" % (m.name, m.shape) + ) + + tf.nest.map_structure(_check_dims, memory) + if memory_sequence_length is None and memory_mask is None: + return memory + elif memory_sequence_length is not None: + seq_len_mask = tf.sequence_mask( + memory_sequence_length, + maxlen=tf.shape(tf.nest.flatten(memory)[0])[1], + dtype=tf.nest.flatten(memory)[0].dtype, + ) + else: + # For memory_mask is not None + seq_len_mask = tf.cast(memory_mask, dtype=tf.nest.flatten(memory)[0].dtype) + + def _maybe_mask(m, seq_len_mask): + """Mask the memory based on the memory mask.""" + rank = m.shape.ndims + rank = rank if rank is not None else tf.rank(m) + extra_ones = tf.ones(rank - 2, dtype=tf.int32) + seq_len_mask = tf.reshape( + seq_len_mask, tf.concat((tf.shape(seq_len_mask), extra_ones), 0) + ) + return m * seq_len_mask + + return tf.nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) + + +def _maybe_mask_score( + score, memory_sequence_length=None, memory_mask=None, score_mask_value=None +): + """Mask the attention score based on the masks.""" + if memory_sequence_length is None and memory_mask is None: + return score + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError( + "memory_sequence_length and memory_mask can't be provided at same time." + ) + if memory_sequence_length is not None: + message = "All values in memory_sequence_length must greater than zero." + with tf.control_dependencies( + [ + tf.debugging.assert_positive( # pylint: disable=bad-continuation + memory_sequence_length, message=message + ) + ] + ): + memory_mask = tf.sequence_mask( + memory_sequence_length, maxlen=tf.shape(score)[1] + ) + score_mask_values = score_mask_value * tf.ones_like(score) + return tf.where(memory_mask, score, score_mask_values) + + +def hardmax(logits: TensorLike, name: Optional[str] = None) -> tf.Tensor: + """Returns batched one-hot vectors. + + The depth index containing the `1` is that of the maximum logit value. + + Args: + logits: A batch tensor of logit values. + name: Name to use when creating ops. + Returns: + A batched one-hot tensor. + """ + with tf.name_scope(name or "Hardmax"): + logits = tf.convert_to_tensor(logits, name="logits") + depth = logits.shape[-1] or tf.shape(logits)[-1] + return tf.one_hot(tf.argmax(logits, -1), depth, dtype=logits.dtype) + + +def _compute_attention( + attention_mechanism, cell_output, attention_state, attention_layer +): + """Computes the attention and alignments for a given + attention_mechanism.""" + alignments, next_attention_state = attention_mechanism( + [cell_output, attention_state] + ) + + # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] + expanded_alignments = tf.expand_dims(alignments, 1) + # Context is the inner product of alignments and values along the + # memory time dimension. + # alignments shape is + # [batch_size, 1, memory_time] + # attention_mechanism.values shape is + # [batch_size, memory_time, memory_size] + # the batched matmul is over memory_time, so the output shape is + # [batch_size, 1, memory_size]. + # we then squeeze out the singleton dim. 
+ context_ = tf.matmul(expanded_alignments, attention_mechanism.values) + context_ = tf.squeeze(context_, [1]) + + if attention_layer is not None: + attention = attention_layer(tf.concat([cell_output, context_], 1)) + else: + attention = context_ + + return attention, alignments, next_attention_state + + +class AttentionWrapper(AbstractRNNCell): + """Wraps another RNN cell with attention. + + Example: + + >>> batch_size = 4 + >>> max_time = 7 + >>> hidden_size = 32 + >>> + >>> memory = tf.random.uniform([batch_size, max_time, hidden_size]) + >>> memory_sequence_length = tf.fill([batch_size], max_time) + >>> + >>> attention_mechanism = tfa.seq2seq.LuongAttention(hidden_size) + >>> attention_mechanism.setup_memory(memory, memory_sequence_length) + >>> + >>> cell = tf.keras.layers.LSTMCell(hidden_size) + >>> cell = tfa.seq2seq.AttentionWrapper( + ... cell, attention_mechanism, attention_layer_size=hidden_size) + >>> + >>> inputs = tf.random.uniform([batch_size, hidden_size]) + >>> state = cell.get_initial_state(inputs) + >>> + >>> outputs, state = cell(inputs, state) + >>> outputs.shape + TensorShape([4, 32]) + """ + + @typechecked + def __init__( + self, + cell: tf.keras.layers.Layer, + attention_mechanism: Union[AttentionMechanism, List[AttentionMechanism]], + attention_layer_size: Optional[Union[Number, List[Number]]] = None, + alignment_history: bool = False, + cell_input_fn: Optional[Callable] = None, + output_attention: bool = True, + initial_cell_state: Optional[TensorLike] = None, + name: Optional[str] = None, + attention_layer: Optional[ + Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]] + ] = None, + attention_fn: Optional[Callable] = None, + **kwargs, + ): + """Construct the `AttentionWrapper`. + + **NOTE** If you are using the `tfa.seq2seq.BeamSearchDecoder` with a cell wrapped + in `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tfa.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `get_initial_state` method of + this wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `get_initial_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + >>> batch_size = 1 + >>> beam_width = 5 + >>> sequence_length = tf.convert_to_tensor([5]) + >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10)) + >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))] + >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, + multiplier=beam_width) + >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, + multiplier=beam_width) + >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, + multiplier=beam_width) + >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, + memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length) + >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), + attention_mechanism) + >>> decoder_initial_state = + attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32) + >>> decoder_initial_state = + decoder_initial_state.clone(cell_state=tiled_encoder_final_state) + + Args: + cell: A layer that implements the `tf.keras.layers.AbstractRNNCell` + interface. + attention_mechanism: A list of `tfa.seq2seq.AttentionMechanism` + instances single instance. 
+ attention_layer_size: A list of Python integers or a single Python + integer, the depth of the attention (output) layer(s). If `None` + (default), use the context as attention at each time step. + Otherwise, feed the context and cell output into the attention + layer to generate attention at each time step. If + `attention_mechanism` is a list, `attention_layer_size` must be a list + of the same length. If `attention_layer` is set, this must be `None`. + If `attention_fn` is set, it must guaranteed that the outputs of + `attention_fn` also meet the above requirements. + alignment_history: Python boolean, whether to store alignment history + from all time steps in the final output state (currently stored as + a time major `TensorArray` on which you must call `stack()`). + cell_input_fn: (optional) A `callable`. The default is: + `lambda inputs, attention: + tf.concat([inputs, attention], -1)`. + output_attention: Python bool. If `True` (default), the output at + each time step is the attention value. This is the behavior of + Luong-style attention mechanisms. If `False`, the output at each + time step is the output of `cell`. This is the behavior of + Bahdanau-style attention mechanisms. In both cases, the + `attention` tensor is propagated to the next time step via the + state and is used there. This flag only controls whether the + attention mechanism is propagated up to the next cell in an RNN + stack or to the top RNN output. + initial_cell_state: The initial state value to use for the cell when + the user calls `get_initial_state()`. Note that if this value is + provided now, and the user uses a `batch_size` argument of + `get_initial_state` which does not match the batch size of + `initial_cell_state`, proper behavior is not guaranteed. + name: Name to use when creating ops. + attention_layer: A list of `tf.keras.layers.Layer` instances or a + single `tf.keras.layers.Layer` instance taking the context + and cell output as inputs to generate attention at each time step. + If `None` (default), use the context as attention at each time step. + If `attention_mechanism` is a list, `attention_layer` must be a list of + the same length. If `attention_layer_size` is set, this must be + `None`. + attention_fn: An optional callable function that allows users to + provide their own customized attention function, which takes input + `(attention_mechanism, cell_output, attention_state, + attention_layer)` and outputs `(attention, alignments, + next_attention_state)`. If provided, the `attention_layer_size` should + be the size of the outputs of `attention_fn`. + **kwargs: Other keyword arguments for layer creation. + + Raises: + TypeError: `attention_layer_size` is not `None` and + (`attention_mechanism` is a list but `attention_layer_size` is not; + or vice versa). + ValueError: if `attention_layer_size` is not `None`, + `attention_mechanism` is a list, and its length does not match that + of `attention_layer_size`; if `attention_layer_size` and + `attention_layer` are set simultaneously. 
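For the `cell_input_fn` and `attention_fn` arguments documented above, the callables below sketch the expected signatures. The names `scaled_cell_input_fn` and `my_attention_fn` are illustrative only, not part of the API; `my_attention_fn` simply mirrors the default `_compute_attention` defined earlier in this module:

import tensorflow as tf

def scaled_cell_input_fn(inputs, attention):
    # Same contract as the default: combine the step input with the previous
    # attention before feeding the wrapped cell.
    return tf.concat([inputs, 0.5 * attention], axis=-1)

def my_attention_fn(attention_mechanism, cell_output, attention_state, attention_layer):
    # Must return (attention, alignments, next_attention_state).
    alignments, next_attention_state = attention_mechanism(
        [cell_output, attention_state]
    )
    context = tf.squeeze(
        tf.matmul(tf.expand_dims(alignments, 1), attention_mechanism.values), [1]
    )
    if attention_layer is not None:
        attention = attention_layer(tf.concat([cell_output, context], 1))
    else:
        attention = context
    return attention, alignments, next_attention_state

# Both would be passed to the constructor, e.g.
# AttentionWrapper(cell, mechanism, cell_input_fn=scaled_cell_input_fn,
#                  attention_fn=my_attention_fn)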
+ """ + super().__init__(name=name, **kwargs) + keras_utils.assert_like_rnncell("cell", cell) + if isinstance(attention_mechanism, (list, tuple)): + self._is_multi = True + attention_mechanisms = list(attention_mechanism) + else: + self._is_multi = False + attention_mechanisms = [attention_mechanism] + + if cell_input_fn is None: + + def cell_input_fn(inputs, attention): + return tf.concat([inputs, attention], -1) + + if attention_layer_size is not None and attention_layer is not None: + raise ValueError( + "Only one of attention_layer_size and attention_layer should be set" + ) + + if attention_layer_size is not None: + attention_layer_sizes = tuple( + attention_layer_size + if isinstance(attention_layer_size, (list, tuple)) + else (attention_layer_size,) + ) + if len(attention_layer_sizes) != len(attention_mechanisms): + raise ValueError( + "If provided, attention_layer_size must contain exactly " + "one integer per attention_mechanism, saw: %d vs %d" + % (len(attention_layer_sizes), len(attention_mechanisms)) + ) + dtype = kwargs.get("dtype", None) + self._attention_layers = list( + tf.keras.layers.Dense( + attention_layer_size, + name="attention_layer", + use_bias=False, + dtype=dtype, + ) + for i, attention_layer_size in enumerate(attention_layer_sizes) + ) + elif attention_layer is not None: + self._attention_layers = list( + attention_layer + if isinstance(attention_layer, (list, tuple)) + else (attention_layer,) + ) + if len(self._attention_layers) != len(attention_mechanisms): + raise ValueError( + "If provided, attention_layer must contain exactly one " + "layer per attention_mechanism, saw: %d vs %d" + % (len(self._attention_layers), len(attention_mechanisms)) + ) + else: + self._attention_layers = None + + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._attention_layer_size = None + + self._cell = cell + self._attention_mechanisms = attention_mechanisms + self._cell_input_fn = cell_input_fn + self._output_attention = output_attention + self._alignment_history = alignment_history + with tf.name_scope(name or "AttentionWrapperInit"): + if initial_cell_state is None: + self._initial_cell_state = None + else: + final_state_tensor = tf.nest.flatten(initial_cell_state)[-1] + state_batch_size = ( + final_state_tensor.shape[0] or tf.shape(final_state_tensor)[0] + ) + error_message = ( + "When constructing AttentionWrapper %s: " % self.name + + "Non-matching batch sizes between the memory " + "(encoder output) and initial_cell_state. Are you using " + "the BeamSearchDecoder? You may need to tile your " + "initial state via the tfa.seq2seq.tile_batch " + "function with argument multiple=beam_width." 
+ ) + with tf.control_dependencies( + self._batch_size_checks( # pylint: disable=bad-continuation + state_batch_size, error_message + ) + ): + self._initial_cell_state = tf.nest.map_structure( + lambda s: tf.identity(s, name="check_initial_cell_state"), + initial_cell_state, + ) + + def _attention_mechanisms_checks(self): + for attention_mechanism in self._attention_mechanisms: + if not attention_mechanism.memory_initialized: + raise ValueError( + "The AttentionMechanism instances passed to " + "this AttentionWrapper should be initialized " + "with a memory first, either by passing it " + "to the AttentionMechanism constructor or " + "calling attention_mechanism.setup_memory()" + ) + + def _batch_size_checks(self, batch_size, error_message): + self._attention_mechanisms_checks() + return [ + tf.debugging.assert_equal( + batch_size, attention_mechanism.batch_size, message=error_message + ) + for attention_mechanism in self._attention_mechanisms + ] + + def _get_attention_layer_size(self): + if self._attention_layer_size is not None: + return self._attention_layer_size + self._attention_mechanisms_checks() + attention_output_sizes = ( + attention_mechanism.values.shape[-1] + for attention_mechanism in self._attention_mechanisms + ) + if self._attention_layers is None: + self._attention_layer_size = sum(attention_output_sizes) + else: + # Compute the layer output size from its input which is the + # concatenation of the cell output and the attention mechanism + # output. + self._attention_layer_size = sum( + layer.compute_output_shape( + [None, self._cell.output_size + attention_output_size] + )[-1] + for layer, attention_output_size in zip( + self._attention_layers, attention_output_sizes + ) + ) + return self._attention_layer_size + + def _item_or_tuple(self, seq): + """Returns `seq` as tuple or the singular element. + + Which is returned is determined by how the AttentionMechanism(s) were + passed to the constructor. + + Args: + seq: A non-empty sequence of items or generator. + + Returns: + Either the values in the sequence as a tuple if + AttentionMechanism(s) were passed to the constructor as a sequence + or the singular element. + """ + t = tuple(seq) + if self._is_multi: + return t + else: + return t[0] + + @property + def output_size(self): + if self._output_attention: + return self._get_attention_layer_size() + else: + return self._cell.output_size + + @property + def state_size(self): + """The `state_size` property of `tfa.seq2seq.AttentionWrapper`. + + Returns: + A `tfa.seq2seq.AttentionWrapperState` tuple containing shapes used + by this object. + """ + return AttentionWrapperState( + cell_state=self._cell.state_size, + attention=self._get_attention_layer_size(), + alignments=self._item_or_tuple( + a.alignments_size for a in self._attention_mechanisms + ), + attention_state=self._item_or_tuple( + a.state_size for a in self._attention_mechanisms + ), + alignment_history=self._item_or_tuple( + a.alignments_size if self._alignment_history else () + for a in self._attention_mechanisms + ), + ) # sometimes a TensorArray + + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + """Return an initial (zero) state tuple for this `tfa.seq2seq.AttentionWrapper`. + + **NOTE** Please see the initializer documentation for details of how + to call `get_initial_state` if using a `tfa.seq2seq.AttentionWrapper` + with a `tfa.seq2seq.BeamSearchDecoder`. + + Args: + inputs: The inputs that will be fed to this cell. + batch_size: `0D` integer tensor: the batch size. 
+ dtype: The internal state data type. + + Returns: + An `tfa.seq2seq.AttentionWrapperState` tuple containing zeroed out tensors and, + possibly, empty `TensorArray` objects. + + Raises: + ValueError: (or, possibly at runtime, `InvalidArgument`), if + `batch_size` does not match the output size of the encoder passed + to the wrapper object at initialization time. + """ + if inputs is not None: + batch_size = tf.shape(inputs)[0] + dtype = inputs.dtype + with tf.name_scope( + type(self).__name__ + "ZeroState" + ): # pylint: disable=bad-continuation + if self._initial_cell_state is not None: + cell_state = self._initial_cell_state + else: + cell_state = self._cell.get_initial_state( + batch_size=batch_size, dtype=dtype + ) + error_message = ( + "When calling get_initial_state of AttentionWrapper %s: " % self.name + + "Non-matching batch sizes between the memory " + "(encoder output) and the requested batch size. Are you using " + "the BeamSearchDecoder? If so, make sure your encoder output " + "has been tiled to beam_width via " + "tfa.seq2seq.tile_batch, and the batch_size= argument " + "passed to get_initial_state is batch_size * beam_width." + ) + with tf.control_dependencies( + self._batch_size_checks(batch_size, error_message) + ): # pylint: disable=bad-continuation + cell_state = tf.nest.map_structure( + lambda s: tf.identity(s, name="checked_cell_state"), cell_state + ) + initial_alignments = [ + attention_mechanism.initial_alignments(batch_size, dtype) + for attention_mechanism in self._attention_mechanisms + ] + return AttentionWrapperState( + cell_state=cell_state, + attention=tf.zeros( + [batch_size, self._get_attention_layer_size()], dtype=dtype + ), + alignments=self._item_or_tuple(initial_alignments), + attention_state=self._item_or_tuple( + attention_mechanism.initial_state(batch_size, dtype) + for attention_mechanism in self._attention_mechanisms + ), + alignment_history=self._item_or_tuple( + tf.TensorArray( + dtype, size=0, dynamic_size=True, element_shape=alignment.shape + ) + if self._alignment_history + else () + for alignment in initial_alignments + ), + ) + + def call(self, inputs, state, **kwargs): + """Perform a step of attention-wrapped RNN. + + - Step 1: Mix the `inputs` and previous step's `attention` output via + `cell_input_fn`. + - Step 2: Call the wrapped `cell` with this input and its previous + state. + - Step 3: Score the cell's output with `attention_mechanism`. + - Step 4: Calculate the alignments by passing the score through the + `normalizer`. + - Step 5: Calculate the context vector as the inner product between the + alignments and the attention_mechanism's values (memory). + - Step 6: Calculate the attention output by concatenating the cell + output and context through the attention layer (a linear layer with + `attention_layer_size` outputs). + + Args: + inputs: (Possibly nested tuple of) Tensor, the input at this time + step. + state: An instance of `tfa.seq2seq.AttentionWrapperState` containing + tensors from the previous time step. + **kwargs: Dict, other keyword arguments for the cell call method. + + Returns: + A tuple `(attention_or_cell_output, next_state)`, where: + + - `attention_or_cell_output` depending on `output_attention`. + - `next_state` is an instance of `tfa.seq2seq.AttentionWrapperState` + containing the state calculated at this time step. + + Raises: + TypeError: If `state` is not an instance of `tfa.seq2seq.AttentionWrapperState`. 
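A small sketch of stepping the wrapper manually, following the step sequence described above. It assumes the vendored module also exposes `LuongAttention`, as the upstream `tfa.seq2seq` package does; sizes and the number of steps are arbitrary:

import tensorflow as tf
from opennmt.tfa.seq2seq.attention_wrapper import AttentionWrapper, LuongAttention

batch_size, max_time, units = 2, 5, 16
memory = tf.random.uniform([batch_size, max_time, units])
mechanism = LuongAttention(
    units, memory=memory, memory_sequence_length=[max_time] * batch_size
)
cell = AttentionWrapper(
    tf.keras.layers.LSTMCell(units),
    mechanism,
    attention_layer_size=units,
    alignment_history=True,
)

state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
for _ in range(3):  # decode a few steps with dummy inputs
    step_input = tf.random.uniform([batch_size, units])
    output, state = cell(step_input, state)

# alignment_history is a TensorArray; stack() yields [steps, batch, max_time].
alignments = state.alignment_history.stack()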
+ """ + if not isinstance(state, AttentionWrapperState): + try: + state = AttentionWrapperState(*state) + except TypeError: + raise TypeError( + "Expected state to be instance of AttentionWrapperState or " + "values that can construct AttentionWrapperState. " + "Received type %s instead." % type(state) + ) + + # Step 1: Calculate the true inputs to the cell based on the + # previous attention value. + cell_inputs = self._cell_input_fn(inputs, state.attention) + cell_state = state.cell_state + cell_output, next_cell_state = self._cell(cell_inputs, cell_state, **kwargs) + next_cell_state = tf.nest.pack_sequence_as( + cell_state, tf.nest.flatten(next_cell_state) + ) + + cell_batch_size = cell_output.shape[0] or tf.shape(cell_output)[0] + error_message = ( + "When applying AttentionWrapper %s: " % self.name + + "Non-matching batch sizes between the memory " + "(encoder output) and the query (decoder output). Are you using " + "the BeamSearchDecoder? You may need to tile your memory input " + "via the tfa.seq2seq.tile_batch function with argument " + "multiple=beam_width." + ) + with tf.control_dependencies( + self._batch_size_checks(cell_batch_size, error_message) + ): # pylint: disable=bad-continuation + cell_output = tf.identity(cell_output, name="checked_cell_output") + + if self._is_multi: + previous_attention_state = state.attention_state + previous_alignment_history = state.alignment_history + else: + previous_attention_state = [state.attention_state] + previous_alignment_history = [state.alignment_history] + + all_alignments = [] + all_attentions = [] + all_attention_states = [] + maybe_all_histories = [] + for i, attention_mechanism in enumerate(self._attention_mechanisms): + attention, alignments, next_attention_state = self._attention_fn( + attention_mechanism, + cell_output, + previous_attention_state[i], + self._attention_layers[i] if self._attention_layers else None, + ) + alignment_history = ( + previous_alignment_history[i].write( + previous_alignment_history[i].size(), alignments + ) + if self._alignment_history + else () + ) + + all_attention_states.append(next_attention_state) + all_alignments.append(alignments) + all_attentions.append(attention) + maybe_all_histories.append(alignment_history) + + attention = tf.concat(all_attentions, 1) + next_state = AttentionWrapperState( + cell_state=next_cell_state, + attention=attention, + attention_state=self._item_or_tuple(all_attention_states), + alignments=self._item_or_tuple(all_alignments), + alignment_history=self._item_or_tuple(maybe_all_histories), + ) + + if self._output_attention: + return attention, next_state + else: + return cell_output, next_state diff --git a/opennmt/tfa/seq2seq/beam_search_decoder.py b/opennmt/tfa/seq2seq/beam_search_decoder.py new file mode 100644 index 000000000..581837303 --- /dev/null +++ b/opennmt/tfa/seq2seq/beam_search_decoder.py @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A decoder that performs beam search.""" + +import collections + +from typing import Callable, Optional + +import numpy as np +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.seq2seq import attention_wrapper, decoder +from opennmt.tfa.utils import keras_utils +from opennmt.tfa.utils.types import FloatTensorLike, Number, TensorLike + + +def _tile_batch(t, multiplier): + """Core single-tensor implementation of tile_batch.""" + t = tf.convert_to_tensor(t, name="t") + shape_t = tf.shape(t) + if t.shape.ndims is None or t.shape.ndims < 1: + raise ValueError("t must have statically known rank") + tiling = [1] * (t.shape.ndims + 1) + tiling[1] = multiplier + tiled_static_batch_size = ( + t.shape[0] * multiplier if t.shape[0] is not None else None + ) + tiled = tf.tile(tf.expand_dims(t, 1), tiling) + tiled = tf.reshape(tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) + tiled.set_shape(tf.TensorShape([tiled_static_batch_size]).concatenate(t.shape[1:])) + return tiled + + +def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf.Tensor: + """Tiles the batch dimension of a (possibly nested structure of) tensor(s). + + For each tensor t in a (possibly nested structure) of tensors, + this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed + of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a + shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch + entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is + repeated `multiplier` times. + + Args: + t: `Tensor` shaped `[batch_size, ...]`. + multiplier: Python int. + name: Name scope for any created operations. + + Returns: + A (possibly nested structure of) `Tensor` shaped + `[batch_size * multiplier, ...]`. + + Raises: + ValueError: if tensor(s) `t` do not have a statically known rank or + the rank is < 1. + """ + with tf.name_scope(name or "tile_batch"): + return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) diff --git a/opennmt/tfa/seq2seq/decoder.py b/opennmt/tfa/seq2seq/decoder.py new file mode 100644 index 000000000..106d9fc23 --- /dev/null +++ b/opennmt/tfa/seq2seq/decoder.py @@ -0,0 +1,583 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base classes and functions for dynamic decoding.""" + +import abc + +from typing import Any, Optional, Tuple, Union + +import tensorflow as tf + +# TODO: Find public API alternatives to these +from tensorflow.python.ops import control_flow_util +from typeguard import typechecked + +from opennmt.tfa.utils.types import TensorLike + + +class Decoder(metaclass=abc.ABCMeta): + """An RNN Decoder abstract interface object. 
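Usage of the `tile_batch` helper defined above; note the contiguous repetition of each batch entry, which is what the beam search decoder expects (a plain `tf.tile` would interleave entries instead):

import tensorflow as tf
from opennmt.tfa.seq2seq.beam_search_decoder import tile_batch

beam_width = 3
encoder_outputs = tf.reshape(tf.range(8, dtype=tf.float32), [2, 4])  # [batch=2, depth=4]

tiled = tile_batch(encoder_outputs, multiplier=beam_width)
# Rows come out as [b0, b0, b0, b1, b1, b1], shape [batch * beam_width, depth].
assert tiled.shape == (2 * beam_width, 4)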
+ + Concepts used by this interface: + - `inputs`: (structure of) tensors and TensorArrays that is passed as input + to the RNN cell composing the decoder, at each time step. + - `state`: (structure of) tensors and TensorArrays that is passed to the + RNN cell instance as the state. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `training`: boolean whether it should behave in training mode or in + inference mode. + - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the decoding, at + each time step. + """ + + @property + def batch_size(self): + """The batch size of input values.""" + raise NotImplementedError + + @property + def output_size(self): + """A (possibly nested tuple of...) integer[s] or `TensorShape` + object[s].""" + raise NotImplementedError + + @property + def output_dtype(self): + """A (possibly nested tuple of...) dtype[s].""" + raise NotImplementedError + + @abc.abstractmethod + def initialize(self, name=None): + """Called before any decoding iterations. + + This methods must compute initial input values and initial state. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. + """ + raise NotImplementedError + + @abc.abstractmethod + def step(self, time, inputs, state, training=None, name=None): + """Called per step of decoding (but only once for dynamic decoding). + + Args: + time: Scalar `int32` tensor. Current step number. + inputs: RNN cell input (possibly nested tuple of) tensor[s] for this + time step. + state: RNN cell state (possibly nested tuple of) tensor[s] from + previous time step. + training: Python boolean. Indicates whether the layer should behave + in training mode or in inference mode. Only relevant + when `dropout` or `recurrent_dropout` is used. + name: Name scope for any created operations. + + Returns: + `(outputs, next_state, next_inputs, finished)`: `outputs` is an + object containing the decoder output, `next_state` is a (structure + of) state tensors and TensorArrays, `next_inputs` is the tensor that + should be used as input for the next step, `finished` is a boolean + tensor telling whether the sequence is complete, for each sequence in + the batch. + """ + raise NotImplementedError + + def finalize(self, outputs, final_state, sequence_lengths): + raise NotImplementedError + + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `tfa.seq2seq.dynamic_decode` function keeps + track of which batch entries are already finished, and performs a + logical OR to insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `tfa.seq2seq.dynamic_decode` will mix up the finished state across these entries + because it does not track the reshuffle across time steps. In this + case, it is up to the decoder to declare that it will keep track of its + own finished state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + + +class BaseDecoder(tf.keras.layers.Layer): + """An RNN Decoder that is based on a Keras layer. + + Concepts used by this interface: + - `inputs`: (structure of) Tensors and TensorArrays that is passed as input + to the RNN cell composing the decoder, at each time step. 
+ - `state`: (structure of) Tensors and TensorArrays that is passed to the + RNN cell instance as the state. + - `memory`: tensor that is usually the full output of the encoder, which + will be used for the attention wrapper for the RNN cell. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `training`: boolean whether it should behave in training mode or in + inference mode. + - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the decoding, at + each time step. + """ + + @typechecked + def __init__( + self, + output_time_major: bool = False, + impute_finished: bool = False, + maximum_iterations: Optional[TensorLike] = None, + parallel_iterations: int = 32, + swap_memory: bool = False, + **kwargs, + ): + self.output_time_major = output_time_major + self.impute_finished = impute_finished + self.maximum_iterations = maximum_iterations + self.parallel_iterations = parallel_iterations + self.swap_memory = swap_memory + super().__init__(**kwargs) + + def call(self, inputs, initial_state=None, training=None, **kwargs): + init_kwargs = kwargs + init_kwargs["initial_state"] = initial_state + return dynamic_decode( + self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + training=training, + decoder_init_input=inputs, + decoder_init_kwargs=init_kwargs, + ) + + @property + def batch_size(self): + """The batch size of input values.""" + raise NotImplementedError + + @property + def output_size(self): + """A (possibly nested tuple of...) integer[s] or `TensorShape` + object[s].""" + raise NotImplementedError + + @property + def output_dtype(self): + """A (possibly nested tuple of...) dtype[s].""" + raise NotImplementedError + + def initialize(self, inputs, initial_state=None, **kwargs): + """Called before any decoding iterations. + + This methods must compute initial input values and initial state. + + Args: + inputs: (structure of) tensors that contains the input for the + decoder. In the normal case, it's a tensor with shape + [batch, timestep, embedding]. + initial_state: (structure of) tensors that contains the initial state + for the RNN cell. + **kwargs: Other arguments that are passed in from layer.call() + method. It could contains item like input `sequence_length`, or + masking for input. + + Returns: + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. + """ + raise NotImplementedError + + def step(self, time, inputs, state, training): + """Called per step of decoding (but only once for dynamic decoding). + + Args: + time: Scalar `int32` tensor. Current step number. + inputs: RNN cell input (possibly nested tuple of) tensor[s] for this + time step. + state: RNN cell state (possibly nested tuple of) tensor[s] from + previous time step. + training: Python boolean. Indicates whether the layer should + behave in training mode or in inference mode. + + Returns: + `(outputs, next_state, next_inputs, finished)`: `outputs` is an + object containing the decoder output, `next_state` is a + (structure of) state tensors and TensorArrays, `next_inputs` is the + tensor that should be used as input for the next step, `finished` is + a boolean tensor telling whether the sequence is complete, for each + sequence in the batch. 
+ """ + raise NotImplementedError + + def finalize(self, outputs, final_state, sequence_lengths): + raise NotImplementedError + + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `tfa.seq2seq.dynamic_decode` function keeps + track of which batch entries are already finished, and performs a + logical OR to insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `tfa.seq2seq.dynamic_decode` will mix up the finished state across these entries + because it does not track the reshuffle across time steps. In this + case, it is up to the decoder to declare that it will keep track of its + own finished state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + + # TODO(scottzhu): Add build/get_config/from_config and other layer methods. + + +@typechecked +def dynamic_decode( + decoder: Union[Decoder, BaseDecoder], + output_time_major: bool = False, + impute_finished: bool = False, + maximum_iterations: Optional[TensorLike] = None, + parallel_iterations: int = 32, + swap_memory: bool = False, + training: Optional[bool] = None, + scope: Optional[str] = None, + enable_tflite_convertible: bool = False, + **kwargs, +) -> Tuple[Any, Any, Any]: + """Runs dynamic decoding with a decoder. + + Calls `initialize()` once and `step()` repeatedly on the decoder object. + + Args: + decoder: A `tfa.seq2seq.Decoder` or `tfa.seq2seq.BaseDecoder` instance. + output_time_major: Python boolean. Default: `False` (batch major). If + `True`, outputs are returned as time major tensors (this mode is + faster). Otherwise, outputs are returned as batch major tensors (this + adds extra time to the computation). + impute_finished: Python boolean. If `True`, then states for batch + entries which are marked as finished get copied through and the + corresponding outputs get zeroed out. This causes some slowdown at + each time step, but ensures that the final state and outputs have + the correct values and that backprop ignores time steps that were + marked as finished. + maximum_iterations: A strictly positive `int32` scalar, the maximum + allowed number of decoding steps. Default is `None` (decode until the + decoder is fully done). + parallel_iterations: Argument passed to `tf.while_loop`. + swap_memory: Argument passed to `tf.while_loop`. + training: Python boolean. Indicates whether the layer should behave + in training mode or in inference mode. Only relevant + when `dropout` or `recurrent_dropout` is used. + scope: Optional name scope to use. + enable_tflite_convertible: Python boolean. If `True`, then the variables + of `TensorArray` become of 1-D static shape. Also zero pads in the + output tensor will be discarded. Default: `False`. + **kwargs: dict, other keyword arguments for dynamic_decode. It might + contain arguments for `BaseDecoder` to initialize, which takes all + tensor inputs during call(). + + Returns: + `(final_outputs, final_state, final_sequence_lengths)`. + + Raises: + ValueError: if `maximum_iterations` is provided but is not a scalar. 
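The contract described above is easiest to see with a toy decoder. The sketch below (a hypothetical `CountingDecoder`, illustration only) implements the minimum of the `Decoder` interface and runs it through `dynamic_decode`:

import tensorflow as tf
from opennmt.tfa.seq2seq.decoder import Decoder, dynamic_decode

class CountingDecoder(Decoder):
    """Toy decoder that feeds its own output back as the next input."""

    def __init__(self, batch_size, depth, max_steps):
        self._batch_size = batch_size
        self._depth = depth
        self._max_steps = max_steps

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def output_size(self):
        return tf.TensorShape([self._depth])

    @property
    def output_dtype(self):
        return tf.float32

    def initialize(self, name=None):
        finished = tf.tile([False], [self._batch_size])
        initial_inputs = tf.zeros([self._batch_size, self._depth])
        initial_state = tf.zeros([self._batch_size, self._depth])
        return finished, initial_inputs, initial_state

    def step(self, time, inputs, state, training=None, name=None):
        outputs = state + 1.0  # placeholder computation
        finished = tf.fill([self._batch_size], time + 1 >= self._max_steps)
        # Return order is (outputs, next_state, next_inputs, finished).
        return outputs, outputs, outputs, finished

outputs, final_state, lengths = dynamic_decode(
    CountingDecoder(batch_size=2, depth=3, max_steps=4)
)
# outputs: [batch=2, time=4, depth=3]; lengths: [4, 4]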
+ """ + with tf.name_scope(scope or "decoder"): + is_xla = ( + not tf.executing_eagerly() + and control_flow_util.GraphOrParentsInXlaContext( + tf.compat.v1.get_default_graph() + ) + ) + + if maximum_iterations is not None: + maximum_iterations = tf.convert_to_tensor( + maximum_iterations, dtype=tf.int32, name="maximum_iterations" + ) + if maximum_iterations.shape.ndims != 0: + raise ValueError("maximum_iterations must be a scalar") + tf.debugging.assert_greater( + maximum_iterations, + 0, + message="maximum_iterations should be greater than 0", + ) + elif is_xla: + raise ValueError("maximum_iterations is required for XLA compilation.") + + if isinstance(decoder, Decoder): + initial_finished, initial_inputs, initial_state = decoder.initialize() + else: + # For BaseDecoder that takes tensor inputs during call. + decoder_init_input = kwargs.pop("decoder_init_input", None) + decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {}) + initial_finished, initial_inputs, initial_state = decoder.initialize( + decoder_init_input, **decoder_init_kwargs + ) + + if enable_tflite_convertible: + # Assume the batch_size = 1 for inference. + # So we can change 2-D TensorArray into 1-D by reshaping it. + tf.debugging.assert_equal( + decoder.batch_size, + 1, + message="TFLite conversion requires a batch size of 1", + ) + zero_outputs = tf.nest.map_structure( + lambda shape, dtype: tf.reshape( + tf.zeros(_prepend_batch(decoder.batch_size, shape), dtype=dtype), + [-1], + ), + decoder.output_size, + decoder.output_dtype, + ) + else: + zero_outputs = tf.nest.map_structure( + lambda shape, dtype: tf.zeros( + _prepend_batch(decoder.batch_size, shape), dtype=dtype + ), + decoder.output_size, + decoder.output_dtype, + ) + + if maximum_iterations is not None: + initial_finished = tf.logical_or(initial_finished, 0 >= maximum_iterations) + initial_sequence_lengths = tf.zeros_like(initial_finished, dtype=tf.int32) + initial_time = tf.constant(0, dtype=tf.int32) + + def _shape(batch_size, from_shape): + if not isinstance(from_shape, tf.TensorShape) or from_shape.ndims == 0: + return None + else: + batch_size = tf.get_static_value( + tf.convert_to_tensor(batch_size, name="batch_size") + ) + return tf.TensorShape([batch_size]).concatenate(from_shape) + + dynamic_size = maximum_iterations is None or not is_xla + # The dynamic shape `TensorArray` is not allowed in TFLite yet. + dynamic_size = dynamic_size and (not enable_tflite_convertible) + + def _create_ta(s, d): + if enable_tflite_convertible: + # TFLite requires 1D element_shape. + if isinstance(s, tf.TensorShape) and s.ndims == 0: + s = (1,) + element_shape = s + else: + element_shape = _shape(decoder.batch_size, s) + return tf.TensorArray( + dtype=d, + size=0 if dynamic_size else maximum_iterations, + dynamic_size=dynamic_size, + element_shape=element_shape, + ) + + initial_outputs_ta = tf.nest.map_structure( + _create_ta, decoder.output_size, decoder.output_dtype + ) + + def condition( + unused_time, + unused_outputs_ta, + unused_state, + unused_inputs, + finished, + unused_sequence_lengths, + ): + return tf.logical_not(tf.reduce_all(finished)) + + def body(time, outputs_ta, state, inputs, finished, sequence_lengths): + """Internal while_loop body. + + Args: + time: scalar int32 tensor. + outputs_ta: structure of TensorArray. + state: (structure of) state tensors and TensorArrays. + inputs: (structure of) input tensors. + finished: bool tensor (keeping track of what's finished). + sequence_lengths: int32 tensor (keeping track of time of finish). 
+ + Returns: + `(time + 1, outputs_ta, next_state, next_inputs, next_finished, + next_sequence_lengths)`. + ``` + """ + (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step( + time, inputs, state, training + ) + decoder_state_sequence_lengths = False + if decoder.tracks_own_finished: + next_finished = decoder_finished + lengths = getattr(decoder_state, "lengths", None) + if lengths is not None: + # sequence lengths are provided by decoder_state.lengths; + # overwrite our sequence lengths. + decoder_state_sequence_lengths = True + sequence_lengths = tf.cast(lengths, tf.int32) + else: + next_finished = tf.logical_or(decoder_finished, finished) + + if decoder_state_sequence_lengths: + # Just pass something through the loop; at the next iteration + # we'll pull the sequence lengths from the decoder_state again. + next_sequence_lengths = sequence_lengths + else: + next_sequence_lengths = tf.where( + tf.logical_not(finished), + tf.fill(tf.shape(sequence_lengths), time + 1), + sequence_lengths, + ) + + tf.nest.assert_same_structure(state, decoder_state) + tf.nest.assert_same_structure(outputs_ta, next_outputs) + tf.nest.assert_same_structure(inputs, next_inputs) + + # Zero out output values past finish + if impute_finished: + + def zero_out_finished(out, zero): + if finished.shape.rank < zero.shape.rank: + broadcast_finished = tf.broadcast_to( + tf.expand_dims(finished, axis=-1), zero.shape + ) + return tf.where(broadcast_finished, zero, out) + else: + return tf.where(finished, zero, out) + + emit = tf.nest.map_structure( + zero_out_finished, next_outputs, zero_outputs + ) + else: + emit = next_outputs + + # Copy through states past finish + def _maybe_copy_state(new, cur): + # TensorArrays and scalar states get passed through. + if isinstance(cur, tf.TensorArray): + pass_through = True + else: + new.set_shape(cur.shape) + pass_through = new.shape.ndims == 0 + if not pass_through: + broadcast_finished = tf.broadcast_to( + tf.expand_dims(finished, axis=-1), new.shape + ) + return tf.where(broadcast_finished, cur, new) + else: + return new + + if impute_finished: + next_state = tf.nest.map_structure( + _maybe_copy_state, decoder_state, state + ) + else: + next_state = decoder_state + + if enable_tflite_convertible: + # Reshape to 1-D. + emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit) + + outputs_ta = tf.nest.map_structure( + lambda ta, out: ta.write(time, out), outputs_ta, emit + ) + return ( + time + 1, + outputs_ta, + next_state, + next_inputs, + next_finished, + next_sequence_lengths, + ) + + res = tf.while_loop( + condition, + body, + loop_vars=( + initial_time, + initial_outputs_ta, + initial_state, + initial_inputs, + initial_finished, + initial_sequence_lengths, + ), + parallel_iterations=parallel_iterations, + maximum_iterations=maximum_iterations, + swap_memory=swap_memory, + ) + + final_outputs_ta = res[1] + final_state = res[2] + final_sequence_lengths = res[5] + + final_outputs = tf.nest.map_structure(lambda ta: ta.stack(), final_outputs_ta) + + try: + final_outputs, final_state = decoder.finalize( + final_outputs, final_state, final_sequence_lengths + ) + except NotImplementedError: + pass + + if not output_time_major: + if enable_tflite_convertible: + # Reshape the output to the original shape. 
+ def _restore_batch(x): + return tf.expand_dims(x, [1]) + + final_outputs = tf.nest.map_structure(_restore_batch, final_outputs) + + final_outputs = tf.nest.map_structure(_transpose_batch_time, final_outputs) + + return final_outputs, final_state, final_sequence_lengths + + +def _prepend_batch(batch_size, shape): + """Prepends the batch dimension to the shape. + + If the batch_size value is known statically, this function returns a + TensorShape, otherwise a Tensor. + """ + if isinstance(batch_size, tf.Tensor): + static_batch_size = tf.get_static_value(batch_size) + else: + static_batch_size = batch_size + if static_batch_size is None: + return tf.concat(([batch_size], shape), axis=0) + return [static_batch_size] + shape + + +def _transpose_batch_time(tensor): + """Transposes the batch and time dimension of tensor if its rank is at + least 2.""" + shape = tensor.shape + if shape.rank is not None and shape.rank < 2: + return tensor + perm = tf.concat(([1, 0], tf.range(2, tf.rank(tensor))), axis=0) + return tf.transpose(tensor, perm) diff --git a/opennmt/tfa/seq2seq/loss.py b/opennmt/tfa/seq2seq/loss.py new file mode 100644 index 000000000..790b965a9 --- /dev/null +++ b/opennmt/tfa/seq2seq/loss.py @@ -0,0 +1,211 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions for sequence models.""" + +from typing import Callable, Optional + +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.utils.types import TensorLike + + +def sequence_loss( + logits: TensorLike, + targets: TensorLike, + weights: TensorLike, + average_across_timesteps: bool = True, + average_across_batch: bool = True, + sum_over_timesteps: bool = False, + sum_over_batch: bool = False, + softmax_loss_function: Optional[Callable] = None, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the weighted cross-entropy loss for a sequence of logits. + + Depending on the values of `average_across_timesteps` / + `sum_over_timesteps` and `average_across_batch` / `sum_over_batch`, the + return Tensor will have rank 0, 1, or 2 as these arguments reduce the + cross-entropy at each target, which has shape + `[batch_size, sequence_length]`, over their respective dimensions. For + example, if `average_across_timesteps` is `True` and `average_across_batch` + is `False`, then the return Tensor will have shape `[batch_size]`. + + Note that `average_across_timesteps` and `sum_over_timesteps` cannot be + True at same time. Same for `average_across_batch` and `sum_over_batch`. + + The recommended loss reduction in tf 2.0 has been changed to sum_over, + instead of weighted average. User are recommend to use `sum_over_timesteps` + and `sum_over_batch` for reduction. + + Args: + logits: A Tensor of shape + `[batch_size, sequence_length, num_decoder_symbols]` and dtype float. + The logits correspond to the prediction across all classes at each + timestep. 
+ targets: A Tensor of shape `[batch_size, sequence_length]` and dtype + int. The target represents the true class at each timestep. + weights: A Tensor of shape `[batch_size, sequence_length]` and dtype + float. `weights` constitutes the weighting of each prediction in the + sequence. When using `weights` as masking, set all valid timesteps to 1 + and all padded timesteps to 0, e.g. a mask returned by + `tf.sequence_mask`. + average_across_timesteps: If set, sum the cost across the sequence + dimension and divide the cost by the total label weight across + timesteps. + average_across_batch: If set, sum the cost across the batch dimension and + divide the returned cost by the batch size. + sum_over_timesteps: If set, sum the cost across the sequence dimension + and divide the size of the sequence. Note that any element with 0 + weights will be excluded from size calculation. + sum_over_batch: if set, sum the cost across the batch dimension and + divide the total cost by the batch size. Not that any element with 0 + weights will be excluded from size calculation. + softmax_loss_function: Function (labels, logits) -> loss-batch + to be used instead of the standard softmax (the default if this is + None). **Note that to avoid confusion, it is required for the function + to accept named arguments.** + name: Optional name for this operation, defaults to "sequence_loss". + + Returns: + A float Tensor of rank 0, 1, or 2 depending on the + `average_across_timesteps` and `average_across_batch` arguments. By + default, it has rank 0 (scalar) and is the weighted average cross-entropy + (log-perplexity) per symbol. + + Raises: + ValueError: logits does not have 3 dimensions or targets does not have 2 + dimensions or weights does not have 2 dimensions. + """ + if len(logits.shape) != 3: + raise ValueError( + "Logits must be a [batch_size x sequence_length x logits] tensor" + ) + + targets_rank = len(targets.shape) + if targets_rank != 2 and targets_rank != 3: + raise ValueError( + "Targets must be either a [batch_size x sequence_length] tensor " + + "where each element contains the labels' index" + + "or a [batch_size x sequence_length x num_classes] tensor " + + "where the third axis is a one-hot representation of the labels" + ) + + if len(weights.shape) != 2: + raise ValueError("Weights must be a [batch_size x sequence_length] tensor") + + if average_across_timesteps and sum_over_timesteps: + raise ValueError( + "average_across_timesteps and sum_over_timesteps cannot " + "be set to True at same time." + ) + if average_across_batch and sum_over_batch: + raise ValueError( + "average_across_batch and sum_over_batch cannot be set " + "to True at same time." + ) + if average_across_batch and sum_over_timesteps: + raise ValueError( + "average_across_batch and sum_over_timesteps cannot be set " + "to True at same time because of ambiguous order." + ) + if sum_over_batch and average_across_timesteps: + raise ValueError( + "sum_over_batch and average_across_timesteps cannot be set " + "to True at same time because of ambiguous order." 
+ ) + with tf.name_scope(name or "sequence_loss"): + num_classes = tf.shape(input=logits)[2] + logits_flat = tf.reshape(logits, [-1, num_classes]) + if softmax_loss_function is None: + if targets_rank == 2: + targets = tf.reshape(targets, [-1]) + crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=targets, logits=logits_flat + ) + else: + targets = tf.reshape(targets, [-1, num_classes]) + crossent = tf.nn.softmax_cross_entropy_with_logits( + labels=targets, logits=logits_flat + ) + else: + targets = tf.reshape(targets, [-1]) + crossent = softmax_loss_function(labels=targets, logits=logits_flat) + crossent *= tf.reshape(weights, [-1]) + if average_across_timesteps and average_across_batch: + crossent = tf.reduce_sum(input_tensor=crossent) + total_size = tf.reduce_sum(input_tensor=weights) + crossent = tf.math.divide_no_nan(crossent, total_size) + elif sum_over_timesteps and sum_over_batch: + crossent = tf.reduce_sum(input_tensor=crossent) + total_count = tf.cast(tf.math.count_nonzero(weights), crossent.dtype) + crossent = tf.math.divide_no_nan(crossent, total_count) + else: + crossent = tf.reshape(crossent, tf.shape(input=logits)[0:2]) + if average_across_timesteps or average_across_batch: + reduce_axis = [0] if average_across_batch else [1] + crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis) + total_size = tf.reduce_sum(input_tensor=weights, axis=reduce_axis) + crossent = tf.math.divide_no_nan(crossent, total_size) + elif sum_over_timesteps or sum_over_batch: + reduce_axis = [0] if sum_over_batch else [1] + crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis) + total_count = tf.cast( + tf.math.count_nonzero(weights, axis=reduce_axis), + dtype=crossent.dtype, + ) + crossent = tf.math.divide_no_nan(crossent, total_count) + return crossent + + +class SequenceLoss(tf.keras.losses.Loss): + """Weighted cross-entropy loss for a sequence of logits.""" + + @typechecked + def __init__( + self, + average_across_timesteps: bool = False, + average_across_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_batch: bool = True, + softmax_loss_function: Optional[Callable] = None, + name: Optional[str] = None, + ): + super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name) + self.average_across_timesteps = average_across_timesteps + self.average_across_batch = average_across_batch + self.sum_over_timesteps = sum_over_timesteps + self.sum_over_batch = sum_over_batch + self.softmax_loss_function = softmax_loss_function + + def __call__(self, y_true, y_pred, sample_weight=None): + """Override the parent __call__ to have a customized reduce + behavior.""" + return sequence_loss( + y_pred, + y_true, + sample_weight, + average_across_timesteps=self.average_across_timesteps, + average_across_batch=self.average_across_batch, + sum_over_timesteps=self.sum_over_timesteps, + sum_over_batch=self.sum_over_batch, + softmax_loss_function=self.softmax_loss_function, + name=self.name, + ) + + def call(self, y_true, y_pred): + # Skip this method since the __call__ contains real implementation. + pass diff --git a/opennmt/tfa/seq2seq/sampler.py b/opennmt/tfa/seq2seq/sampler.py new file mode 100644 index 000000000..06fc56070 --- /dev/null +++ b/opennmt/tfa/seq2seq/sampler.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
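[Editor's note - usage sketch, not part of the patch] The vendored loss can be driven either through the functional `sequence_loss` or the Keras-style `SequenceLoss` object defined above; with the `SequenceLoss` defaults (`sum_over_timesteps=True`, `sum_over_batch=True`) the result is a scalar, matching the sum-over reduction the docstring recommends. Shapes below are illustrative; the module path `opennmt.tfa.seq2seq.loss` is the file added by this patch.

    import tensorflow as tf

    from opennmt.tfa.seq2seq.loss import SequenceLoss, sequence_loss

    batch, time, classes = 2, 3, 5
    logits = tf.random.normal([batch, time, classes])
    targets = tf.random.uniform([batch, time], maxval=classes, dtype=tf.int32)
    weights = tf.ones([batch, time])  # 1.0 for valid timesteps, 0.0 for padding

    # Functional form: average over timesteps only -> one value per batch entry.
    per_example = sequence_loss(
        logits,
        targets,
        weights,
        average_across_timesteps=True,
        average_across_batch=False,
    )

    # Keras-style object with the default sum_over_* reduction -> scalar loss.
    loss_obj = SequenceLoss()
    scalar_loss = loss_obj(targets, logits, weights)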
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Objects sampling from the decoder output distribution and producing the next input.""" + +import abc + +from typing import Callable, Optional + +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.seq2seq import decoder +from opennmt.tfa.utils import types +from opennmt.tfa.utils.types import Initializer, TensorLike + +_transpose_batch_time = decoder._transpose_batch_time + + +class Sampler(metaclass=abc.ABCMeta): + """Interface for implementing sampling in seq2seq decoders. + + Sampler classes implement the logic of sampling from the decoder output distribution + and producing the inputs for the next decoding step. In most cases, they should not be + used directly but passed to a `tfa.seq2seq.BasicDecoder` instance that will manage the + sampling. + + """ + + @abc.abstractmethod + def initialize(self, inputs, **kwargs): + """initialize the sampler with the input tensors. + + This method must be invoked exactly once before calling other + methods of the Sampler. + + Args: + inputs: A (structure of) input tensors, it could be a nested tuple or + a single tensor. + **kwargs: Other kwargs for initialization. It could contain tensors + like mask for inputs, or non tensor parameter. + + Returns: + `(initial_finished, initial_inputs)`. + """ + pass + + @abc.abstractmethod + def sample(self, time, outputs, state): + """Returns `sample_ids`.""" + pass + + @abc.abstractmethod + def next_inputs(self, time, outputs, state, sample_ids): + """Returns `(finished, next_inputs, next_state)`.""" + pass + + @abc.abstractproperty + def batch_size(self): + """Batch size of tensor returned by `sample`. + + Returns a scalar int32 tensor. The return value might not + available before the invocation of initialize(), in this case, + ValueError is raised. + """ + raise NotImplementedError("batch_size has not been implemented") + + @abc.abstractproperty + def sample_ids_shape(self): + """Shape of tensor returned by `sample`, excluding the batch dimension. + + Returns a `TensorShape`. The return value might not available + before the invocation of initialize(). + """ + raise NotImplementedError("sample_ids_shape has not been implemented") + + @abc.abstractproperty + def sample_ids_dtype(self): + """DType of tensor returned by `sample`. + + Returns a DType. The return value might not available before the + invocation of initialize(). + """ + raise NotImplementedError("sample_ids_dtype has not been implemented") diff --git a/opennmt/tfa/seq2seq/tests/attention_wrapper_test.py b/opennmt/tfa/seq2seq/tests/attention_wrapper_test.py new file mode 100644 index 000000000..604789fe2 --- /dev/null +++ b/opennmt/tfa/seq2seq/tests/attention_wrapper_test.py @@ -0,0 +1,365 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
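[Editor's note - illustration only] To make the `Sampler` contract above concrete, here is a minimal, hypothetical subclass; the name `GreedyArgmaxSampler` and the `embedding_fn` argument are placeholders and are not part of this patch. It feeds the argmax of the previous decoder output back through an embedding function and finishes when the end token is produced.

    import tensorflow as tf

    from opennmt.tfa.seq2seq.sampler import Sampler

    class GreedyArgmaxSampler(Sampler):
        """Sketch of a sampler that greedily feeds back the argmax prediction."""

        def __init__(self, embedding_fn, start_tokens, end_token):
            self._embedding_fn = embedding_fn
            self._start_tokens = tf.convert_to_tensor(start_tokens, dtype=tf.int32)
            self._end_token = tf.convert_to_tensor(end_token, dtype=tf.int32)

        @property
        def batch_size(self):
            return tf.size(self._start_tokens)

        @property
        def sample_ids_shape(self):
            return tf.TensorShape([])

        @property
        def sample_ids_dtype(self):
            return tf.int32

        def initialize(self, inputs=None, **kwargs):
            # Nothing is finished yet; the first inputs are the embedded start tokens.
            finished = tf.fill([self.batch_size], False)
            return finished, self._embedding_fn(self._start_tokens)

        def sample(self, time, outputs, state):
            return tf.argmax(outputs, axis=-1, output_type=tf.int32)

        def next_inputs(self, time, outputs, state, sample_ids):
            finished = tf.equal(sample_ids, self._end_token)
            return finished, self._embedding_fn(sample_ids), state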
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tfa.seq2seq.attention_wrapper.""" + +import collections + +import numpy as np +import pytest +import tensorflow as tf + +from packaging.version import Version + +from opennmt.tfa.seq2seq import attention_wrapper as wrapper +from opennmt.tfa.seq2seq import sampler as sampler_py + + +class DummyData: + def __init__(self): + self.batch = 10 + self.timestep = 5 + self.memory_size = 6 + self.units = 8 + + self.memory = np.random.randn( + self.batch, self.timestep, self.memory_size + ).astype(np.float32) + self.memory_length = np.random.randint( + low=1, high=self.timestep + 1, size=(self.batch,) + ) + self.query = np.random.randn(self.batch, self.units).astype(np.float32) + self.state = np.random.randn(self.batch, self.timestep).astype(np.float32) + + +attention_classes = [ + wrapper.LuongAttention, +] + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_attention_shape_inference(attention_cls): + dummy_data = DummyData() + attention = attention_cls(dummy_data.units, dummy_data.memory) + attention_score = attention([dummy_data.query, dummy_data.state]) + assert len(attention_score) == 2 + assert attention_score[0].shape == (dummy_data.batch, dummy_data.timestep) + assert attention_score[1].shape == (dummy_data.batch, dummy_data.timestep) + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_get_config(attention_cls): + dummy_data = DummyData() + attention = attention_cls(dummy_data.units, dummy_data.memory) + config = attention.get_config() + + attention_from_config = attention_cls.from_config(config) + config_from_clone = attention_from_config.get_config() + + assert config == config_from_clone + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_layer_output(attention_cls): + dummy_data = DummyData() + attention = attention_cls(dummy_data.units, dummy_data.memory) + score = attention([dummy_data.query, dummy_data.state]) + + assert len(score) == 2 + assert score[0].shape == (dummy_data.batch, dummy_data.timestep) + assert score[1].shape == (dummy_data.batch, dummy_data.timestep) + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_passing_memory_from_call(attention_cls): + dummy_data = DummyData() + attention = attention_cls(dummy_data.units, dummy_data.memory) + weights_before_query = attention.get_weights() + ref_score = attention([dummy_data.query, dummy_data.state]) + + all_weights = attention.get_weights() + config = attention.get_config() + # Simulate the twice invocation of calls here. 
+ attention_from_config = attention_cls.from_config(config) + attention_from_config.build(dummy_data.memory.shape) + attention_from_config.set_weights(weights_before_query) + attention_from_config(dummy_data.memory, setup_memory=True) + attention_from_config.build([dummy_data.query.shape, dummy_data.state.shape]) + attention_from_config.set_weights(all_weights) + score = attention_from_config([dummy_data.query, dummy_data.state]) + + np.testing.assert_allclose(ref_score, score) + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_save_load_layer(attention_cls): + dummy_data = DummyData() + vocab = 20 + embedding_dim = 6 + inputs = tf.keras.Input(shape=[dummy_data.timestep]) + encoder_input = tf.keras.layers.Embedding(vocab, embedding_dim, mask_zero=True)( + inputs + ) + encoder_output = tf.keras.layers.LSTM( + dummy_data.memory_size, return_sequences=True + )(encoder_input) + + attention = attention_cls(dummy_data.units, encoder_output) + query = tf.keras.Input(shape=[dummy_data.units]) + state = tf.keras.Input(shape=[dummy_data.timestep]) + + score = attention([query, state]) + + x_test = np.random.randint(vocab, size=(dummy_data.batch, dummy_data.timestep)) + model = tf.keras.Model([inputs, query, state], score) + # Fall back to v1 style Keras training loop until issue with + # using outputs of a layer in another layer's constructor. + model.compile("rmsprop", "mse") + y_ref = model.predict_on_batch([x_test, dummy_data.query, dummy_data.state]) + + if Version(tf.__version__) >= Version("2.13"): + model.use_legacy_config = True + + config = model.get_config() + weights = model.get_weights() + loaded_model = tf.keras.Model.from_config( + config, custom_objects={attention_cls.__name__: attention_cls} + ) + loaded_model.set_weights(weights) + + # Fall back to v1 style Keras training loop until issue with + # using outputs of a layer in another layer's constructor. + loaded_model.compile("rmsprop", "mse") + + y = loaded_model.predict_on_batch([x_test, dummy_data.query, dummy_data.state]) + + np.testing.assert_allclose(y_ref, y) + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_manual_memory_reset(attention_cls): + dummy_data = DummyData() + attention = attention_cls(dummy_data.units) + + def _compute_score(batch_size=None): + if batch_size is None: + batch_size = dummy_data.batch + memory = dummy_data.memory[:batch_size] + attention.setup_memory( + memory, memory_sequence_length=dummy_data.memory_length[:batch_size] + ) + assert attention.values.shape.as_list() == list(memory.shape) + assert attention.keys.shape.as_list() == list(memory.shape)[:-1] + [ + dummy_data.units + ] + return attention([dummy_data.query[:batch_size], dummy_data.state[:batch_size]]) + + _compute_score(batch_size=dummy_data.batch) + variables = list(attention.variables) + _compute_score(batch_size=dummy_data.batch - 1) + + # No new variables were created. 
+ for var_1, var_2 in zip(variables, list(attention.variables)): + assert var_1 is var_2 + + +def test_masking(): + memory = tf.ones([4, 4, 5], dtype=tf.float32) + memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32) + query = tf.ones([4, 5], dtype=tf.float32) + state = None + attention = wrapper.LuongAttention(5, memory, memory_sequence_length) + alignment, _ = attention([query, state]) + assert np.sum(np.triu(alignment, k=1)) == 0 + + +@pytest.mark.parametrize("attention_cls", attention_classes) +def test_memory_re_setup(attention_cls): + class MyModel(tf.keras.models.Model): + def __init__(self, vocab, embedding_dim, memory_size, units): + super().__init__() + self.emb = tf.keras.layers.Embedding(vocab, embedding_dim, mask_zero=True) + self.encoder = tf.keras.layers.LSTM(memory_size, return_sequences=True) + self.attn_mch = attention_cls(units) + + def call(self, inputs): + enc_input, query, state = inputs + mask = self.emb.compute_mask(enc_input) + enc_input = self.emb(enc_input) + enc_output = self.encoder(enc_input, mask=mask) + # To ensure manual resetting also works in the graph mode, + # we call the attention mechanism twice. + self.attn_mch(enc_output, mask=mask, setup_memory=True) + self.attn_mch(enc_output, mask=mask, setup_memory=True) + score = self.attn_mch([query, state]) + return score + + vocab = 20 + embedding_dim = 6 + num_batches = 5 + + dummy_data = DummyData() + model = MyModel(vocab, embedding_dim, dummy_data.memory_size, dummy_data.units) + model.compile("rmsprop", "mse") + + x = np.random.randint( + vocab, size=(num_batches * dummy_data.batch, dummy_data.timestep) + ) + x_test = np.random.randint( + vocab, size=(num_batches * dummy_data.batch, dummy_data.timestep) + ) + y = np.random.randn(num_batches * dummy_data.batch, dummy_data.timestep) + + query = np.tile(dummy_data.query, [num_batches, 1]) + state = np.tile(dummy_data.state, [num_batches, 1]) + + model.fit([x, query, state], (y, y), batch_size=dummy_data.batch) + model.predict_on_batch([x_test, query, state]) + + +class ResultSummary( + collections.namedtuple("ResultSummary", ("shape", "dtype", "mean")) +): + pass + + +def get_result_summary(x): + if isinstance(x, np.ndarray): + return ResultSummary(x.shape, x.dtype, x.mean()) + return x + + +def assert_allclose_or_equal(x, y, **kwargs): + if isinstance(x, np.ndarray) or isinstance(x, float): + np.testing.assert_allclose(x, y, atol=1e-3, **kwargs) + else: + assert x == y + + +class DummyData2: + def __init__(self): + self.batch = 64 + self.units = 128 + self.encoder_timestep = 10 + self.encoder_dim = 256 + self.decoder_timestep = 12 + self.encoder_outputs = np.random.randn( + self.batch, self.encoder_timestep, self.encoder_dim + ) + self.encoder_sequence_length = np.random.randint( + 1, high=self.encoder_timestep, size=(self.batch,) + ).astype(np.int32) + self.decoder_inputs = np.random.randn( + self.batch, self.decoder_timestep, self.units + ) + self.decoder_sequence_length = np.random.randint( + self.decoder_timestep, size=(self.batch,) + ).astype(np.int32) + + +def test_custom_attention_layer(): + dummy_data = DummyData2() + attention_mechanism = wrapper.LuongAttention(dummy_data.units) + cell = tf.keras.layers.LSTMCell(dummy_data.units) + attention_layer = tf.keras.layers.Dense( + dummy_data.units * 2, use_bias=False, activation=tf.math.tanh + ) + attention_wrapper = wrapper.AttentionWrapper( + cell, attention_mechanism, attention_layer=attention_layer + ) + with pytest.raises(ValueError): + # Should fail because the attention mechanism 
has not been + # initialized. + attention_wrapper.get_initial_state( + batch_size=dummy_data.batch, dtype=tf.float32 + ) + attention_mechanism.setup_memory( + dummy_data.encoder_outputs.astype(np.float32), + memory_sequence_length=dummy_data.encoder_sequence_length, + ) + initial_state = attention_wrapper.get_initial_state( + batch_size=dummy_data.batch, dtype=tf.float32 + ) + assert initial_state.attention.shape[-1] == dummy_data.units * 2 + first_input = dummy_data.decoder_inputs[:, 0].astype(np.float32) + output, _ = attention_wrapper(first_input, initial_state) + assert output.shape[-1] == dummy_data.units * 2 + + +def set_random_state_for_tf_and_np(): + """Since the results of the tests have been hardcoded, we need to make sure, + when we refactor code that the random state is the same. Meaning that all + random functions should be called in the same order. + """ + tf.random.set_seed(87654321) + np.random.seed(87654321) + DummyData2() + + +def test_attention_state_with_keras_rnn(): + # See https://github.com/tensorflow/addons/issues/1095. + cell = tf.keras.layers.LSTMCell(8) + + mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8))) + + cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism) + + layer = tf.keras.layers.RNN(cell) + _ = layer(inputs=tf.ones((2, 4, 8))) + + # Make sure the explicit initial_state also works. + initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32) + _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state) + + +def test_attention_state_with_variable_length_input(): + cell = tf.keras.layers.LSTMCell(3) + mechanism = wrapper.LuongAttention(units=3) + cell = wrapper.AttentionWrapper(cell, mechanism) + + var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32) + lengths = tf.random.uniform( + shape=(var_len,), minval=1, maxval=var_len + 1, dtype=tf.int32 + ) + data = tf.ones(shape=(var_len, var_len, 3)) + mask = tf.sequence_mask(lengths, maxlen=var_len) + + mechanism.setup_memory(data) + layer = tf.keras.layers.RNN(cell) + + _ = layer(data, mask=mask) + + +def test_attention_wrapper_with_gru_cell(): + mechanism = wrapper.LuongAttention(units=3) + cell = tf.keras.layers.GRUCell(3) + cell = wrapper.AttentionWrapper(cell, mechanism) + memory = tf.ones([2, 5, 3]) + inputs = tf.ones([2, 3]) + mechanism.setup_memory(memory) + initial_state = cell.get_initial_state(inputs=inputs) + _, state = cell(inputs, initial_state) + tf.nest.assert_same_structure(initial_state, state) + + +def test_attention_wrapper_with_multiple_attention_mechanisms(): + cell = tf.keras.layers.LSTMCell(5) + mechanisms = [wrapper.LuongAttention(units=3), wrapper.LuongAttention(units=3)] + # We simply test that the wrapper creation makes no error. + wrapper.AttentionWrapper(cell, mechanisms, attention_layer_size=[4, 5]) + wrapper.AttentionWrapper( + cell, + mechanisms, + attention_layer=[tf.keras.layers.Dense(4), tf.keras.layers.Dense(5)], + ) diff --git a/opennmt/tfa/seq2seq/tests/loss_test.py b/opennmt/tfa/seq2seq/tests/loss_test.py new file mode 100644 index 000000000..3f38dcb62 --- /dev/null +++ b/opennmt/tfa/seq2seq/tests/loss_test.py @@ -0,0 +1,314 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
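[Editor's note - usage sketch, not part of the patch] The attention classes exercised by the tests above are typically assembled as follows: a `LuongAttention` mechanism receives its memory via `setup_memory`, is wrapped around a Keras cell by `AttentionWrapper`, and the wrapped cell is then driven by a standard `tf.keras.layers.RNN`. Shapes are illustrative.

    import tensorflow as tf

    from opennmt.tfa.seq2seq import attention_wrapper as wrapper

    batch, src_time, tgt_time, units = 2, 4, 5, 8
    memory = tf.random.normal([batch, src_time, units])     # encoder outputs
    memory_length = tf.constant([3, 4], dtype=tf.int32)     # true source lengths

    mechanism = wrapper.LuongAttention(units=units)
    cell = wrapper.AttentionWrapper(tf.keras.layers.LSTMCell(units), mechanism)

    # The memory can be set (or reset) outside of the constructor.
    mechanism.setup_memory(memory, memory_sequence_length=memory_length)

    decoder_inputs = tf.random.normal([batch, tgt_time, units])
    layer = tf.keras.layers.RNN(cell)
    final_output = layer(decoder_inputs)  # attention-augmented output of the last step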
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.addons.seq2seq.python.loss_ops.""" + +import numpy as np +import pytest +import tensorflow as tf + +from opennmt.tfa.seq2seq import loss + + +def get_test_data(): + batch_size = 2 + sequence_length = 3 + number_of_classes = 5 + logits = [ + tf.constant(i + 0.5, shape=[batch_size, number_of_classes]) + for i in range(sequence_length) + ] + logits = tf.stack(logits, axis=1) + targets = [ + tf.constant(i, tf.int32, shape=[batch_size]) for i in range(sequence_length) + ] + targets = tf.stack(targets, axis=1) + + weights = [tf.constant(1.0, shape=[batch_size]) for _ in range(sequence_length)] + weights = tf.stack(weights, axis=1) + # expected_loss = sparse_softmax_cross_entropy_with_logits(targets, + # logits) where targets = [0, 1, 2], + # and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] + expected_loss = 1.60944 + return ( + batch_size, + sequence_length, + number_of_classes, + logits, + targets, + weights, + expected_loss, + ) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("average_across_timesteps", [True, False]) +@pytest.mark.parametrize("average_across_batch", [True, False]) +@pytest.mark.parametrize("zero_weights", [True, False]) +def test_sequence_loss(average_across_timesteps, average_across_batch, zero_weights): + ( + batch_size, + sequence_length, + _, + logits, + targets, + weights, + expected_loss, + ) = get_test_data() + + if zero_weights: + weights = [tf.constant(0.0, shape=[batch_size]) for _ in range(sequence_length)] + weights = tf.stack(weights, axis=1) + + computed = loss.sequence_loss( + logits, + targets, + weights, + average_across_timesteps=average_across_timesteps, + average_across_batch=average_across_batch, + ) + computed = computed.numpy() + if average_across_timesteps and average_across_batch and zero_weights: + expected = 0.0 + elif not average_across_timesteps and average_across_batch and zero_weights: + expected = np.zeros(sequence_length) + elif average_across_timesteps and not average_across_batch and zero_weights: + expected = np.zeros(batch_size) + elif not average_across_timesteps and not average_across_batch and zero_weights: + expected = np.zeros((batch_size, sequence_length)) + elif average_across_timesteps and average_across_batch and not zero_weights: + expected = expected_loss + elif not average_across_timesteps and average_across_batch and not zero_weights: + expected = np.full(sequence_length, expected_loss) + elif average_across_timesteps and not average_across_batch and not zero_weights: + expected = np.full(batch_size, expected_loss) + else: + expected = np.full((batch_size, sequence_length), expected_loss) + + np.testing.assert_allclose(computed, expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("average_across_timesteps", [True, False]) +@pytest.mark.parametrize("average_across_batch", [True, False]) +def test_sequence_loss_class(average_across_timesteps, average_across_batch): + + ( + batch_size, + sequence_length, + _, + logits, + 
targets, + weights, + expected_loss, + ) = get_test_data() + seq_loss = loss.SequenceLoss( + average_across_timesteps=average_across_timesteps, + average_across_batch=average_across_batch, + sum_over_timesteps=False, + sum_over_batch=False, + ) + average_loss_per_example = seq_loss(targets, logits, weights) + res = average_loss_per_example.numpy() + if average_across_timesteps and average_across_batch: + expected = expected_loss + elif not average_across_timesteps and average_across_batch: + expected = np.full(sequence_length, expected_loss) + elif average_across_timesteps and not average_across_batch: + expected = np.full(batch_size, expected_loss) + elif not average_across_timesteps and not average_across_batch: + expected = np.full((batch_size, sequence_length), expected_loss) + + np.testing.assert_allclose(res, expected, atol=1e-6, rtol=1e-6) + + +def test_sum_reduction(): + ( + batch_size, + sequence_length, + _, + logits, + targets, + weights, + expected_loss, + ) = get_test_data() + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + ) + average_loss_per_example = seq_loss(targets, logits, weights) + res = average_loss_per_example.numpy() + np.testing.assert_allclose(expected_loss, res, atol=1e-6, rtol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True, + ) + average_loss_per_sequence = seq_loss(targets, logits, weights) + res = average_loss_per_sequence.numpy() + compare_per_sequence = np.full((sequence_length), expected_loss) + np.testing.assert_allclose(compare_per_sequence, res, atol=1e-6, rtol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False, + ) + average_loss_per_batch = seq_loss(targets, logits, weights) + res = average_loss_per_batch.numpy() + compare_per_batch = np.full((batch_size), expected_loss) + np.testing.assert_allclose(compare_per_batch, res, atol=1e-6, rtol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False, + ) + total_loss = seq_loss(targets, logits, weights) + res = total_loss.numpy() + compare_total = np.full((batch_size, sequence_length), expected_loss) + np.testing.assert_allclose(compare_total, res, atol=1e-6, rtol=1e-6) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_weighted_sum_reduction(): + ( + batch_size, + sequence_length, + _, + logits, + targets, + _, + expected_loss, + ) = get_test_data() + weights = [tf.constant(1.0, shape=[batch_size]) for _ in range(sequence_length)] + # Make the last element in the sequence to have zero weights. 
+ weights[-1] = tf.constant(0.0, shape=[batch_size]) + weights = tf.stack(weights, axis=1) + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + ) + average_loss_per_example = seq_loss(targets, logits, weights) + res = average_loss_per_example.numpy() + np.testing.assert_allclose(expected_loss, res, rtol=1e-6, atol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True, + ) + average_loss_per_sequence = seq_loss(targets, logits, weights) + res = average_loss_per_sequence.numpy() + compare_per_sequence = np.full(sequence_length, expected_loss) + # The last element in every sequence are zeros, which will be + # filtered. + compare_per_sequence[-1] = 0.0 + np.testing.assert_allclose(compare_per_sequence, res, rtol=1e-6, atol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False, + ) + average_loss_per_batch = seq_loss(targets, logits, weights) + res = average_loss_per_batch.numpy() + compare_per_batch = np.full(batch_size, expected_loss) + np.testing.assert_allclose(compare_per_batch, res, rtol=1e-6, atol=1e-6) + + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False, + ) + total_loss = seq_loss(targets, logits, weights) + res = total_loss.numpy() + compare_total = np.full((batch_size, sequence_length), expected_loss) + # The last element in every sequence are zeros, which will be + # filtered. + compare_total[:, -1] = 0 + np.testing.assert_allclose(compare_total, res, rtol=1e-6, atol=1e-6) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_ambiguous_order(): + with pytest.raises(ValueError, match="because of ambiguous order"): + _, _, _, logits, targets, weights, _ = get_test_data() + seq_loss = loss.SequenceLoss( + average_across_timesteps=False, + average_across_batch=True, + sum_over_timesteps=True, + sum_over_batch=False, + ) + seq_loss(targets, logits, weights).numpy() + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_keras_compatibility(): + """To test the compatibility of SequenceLoss with Keras's built-in + training loops, we create a fake model which always outputs a pre- + defined set of logits. + + Then we check the calculated loss to be equal to the expected + loss. Note that since the fake model doesn't have any trainable + parameters, no matter how many steps we train it, it always + outputs the same loss value. + """ + ( + batch_size, + sequence_length, + number_of_classes, + logits, + targets, + weights, + expected_loss, + ) = get_test_data() + targets = tf.one_hot(targets, depth=number_of_classes) + + def return_logits(x): + logits_single_row = logits[0, :, :] + logits_batch = tf.tile( + tf.expand_dims(logits_single_row, 0), [tf.shape(x)[0], 1, 1] + ) + return logits_batch + + inp = tf.keras.layers.Input(shape=(sequence_length,)) + out = tf.keras.layers.Lambda( + return_logits, output_shape=(sequence_length, number_of_classes) + )(inp) + model = tf.keras.models.Model(inp, out) + + loss_obj = loss.SequenceLoss() + model.compile(optimizer="adam", loss=loss_obj, sample_weight_mode="temporal") + + # This is a fake input. 
+ x = tf.ones(shape=(batch_size, sequence_length)) + + h = model.fit( + x, targets, sample_weight=weights, batch_size=batch_size, steps_per_epoch=1 + ) + + calculated_loss = h.history["loss"][0] + np.testing.assert_allclose(calculated_loss, expected_loss, rtol=1e-6, atol=1e-6) diff --git a/opennmt/tfa/text/__init__.py b/opennmt/tfa/text/__init__.py new file mode 100644 index 000000000..7bc260563 --- /dev/null +++ b/opennmt/tfa/text/__init__.py @@ -0,0 +1 @@ +from opennmt.tfa.text.crf import crf_decode, crf_log_likelihood diff --git a/opennmt/tfa/text/crf.py b/opennmt/tfa/text/crf.py new file mode 100644 index 000000000..e6aee72c7 --- /dev/null +++ b/opennmt/tfa/text/crf.py @@ -0,0 +1,567 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import warnings + +from typing import Optional, Tuple + +import numpy as np +import tensorflow as tf + +from typeguard import typechecked + +from opennmt.tfa.rnn.abstract_rnn_cell import AbstractRNNCell +from opennmt.tfa.utils.types import TensorLike + +# TODO: Wrap functions in @tf.function once +# https://github.com/tensorflow/tensorflow/issues/29075 is resolved + + +def crf_filtered_inputs(inputs: TensorLike, tag_bitmap: TensorLike) -> tf.Tensor: + """Constrains the inputs to filter out certain tags at each time step. + + tag_bitmap limits the allowed tags at each input time step. + This is useful when an observed output at a given time step needs to be + constrained to a selected set of tags. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + Returns: + filtered_inputs: A [batch_size] vector of unnormalized sequence scores. + """ + + # set scores of filtered out inputs to be -inf. + filtered_inputs = tf.where( + tag_bitmap, + inputs, + tf.fill(tf.shape(inputs), tf.cast(float("-inf"), inputs.dtype)), + ) + return filtered_inputs + + +def crf_sequence_score( + inputs: TensorLike, + tag_indices: TensorLike, + sequence_lengths: TensorLike, + transition_params: TensorLike, +) -> tf.Tensor: + """Computes the unnormalized score for a tag sequence. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. 
+ """ + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = tf.shape(inputs, out_type=tf.int32)[0] + batch_inds = tf.reshape(tf.range(batch_size), [-1, 1]) + indices = tf.concat([batch_inds, tf.zeros_like(batch_inds)], axis=1) + + tag_inds = tf.gather_nd(tag_indices, indices) + tag_inds = tf.reshape(tag_inds, [-1, 1]) + indices = tf.concat([indices, tag_inds], axis=1) + + sequence_scores = tf.gather_nd(inputs, indices) + + sequence_scores = tf.where( + tf.less_equal(sequence_lengths, 0), + tf.zeros_like(sequence_scores), + sequence_scores, + ) + return sequence_scores + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score( + tag_indices, sequence_lengths, transition_params + ) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn) + + +def crf_multitag_sequence_score( + inputs: TensorLike, + tag_bitmap: TensorLike, + sequence_lengths: TensorLike, + transition_params: TensorLike, +) -> tf.Tensor: + """Computes the unnormalized score of all tag sequences matching + tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + filtered_inputs = crf_filtered_inputs(inputs, tag_bitmap) + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + return tf.reduce_logsumexp(filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences + # matching the given tags. + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params, + ) + + return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn) + + +def crf_log_norm( + inputs: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike +) -> tf.Tensor: + """Computes the normalization for a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + log_norm: A [batch_size] vector of normalizers for a CRF. 
+ """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + # Split up the first and rest of the inputs in preparation for the forward + # algorithm. + first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) + first_input = tf.squeeze(first_input, [1]) + + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp + # over the "initial state" (the unary potentials). + def _single_seq_fn(): + log_norm = tf.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), log_norm + ) + return log_norm + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + + alphas = crf_forward( + rest_of_input, first_input, transition_params, sequence_lengths + ) + log_norm = tf.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), log_norm + ) + return log_norm + + return tf.cond(tf.equal(tf.shape(inputs)[1], 1), _single_seq_fn, _multi_seq_fn) + + +def crf_log_likelihood( + inputs: TensorLike, + tag_indices: TensorLike, + sequence_lengths: TensorLike, + transition_params: Optional[TensorLike] = None, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the log-likelihood of tag sequences in a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the log-likelihood. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix, + if available. + Returns: + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. + transition_params: A [num_tags, num_tags] transition matrix. This is + either provided by the caller or created in this function. + """ + inputs = tf.convert_to_tensor(inputs) + + num_tags = inputs.shape[2] + + # cast type to handle different types + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + + # TODO(windqaq): re-evaluate if `transition_params` can be `None`. + if transition_params is None: + initializer = tf.keras.initializers.GlorotUniform() + transition_params = tf.Variable( + initializer([num_tags, num_tags]), "transitions" + ) + transition_params = tf.cast(transition_params, inputs.dtype) + sequence_scores = crf_sequence_score( + inputs, tag_indices, sequence_lengths, transition_params + ) + log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) + + # Normalize the scores to get the log-likelihood per example. + log_likelihood = sequence_scores - log_norm + return log_likelihood, transition_params + + +def crf_unary_score( + tag_indices: TensorLike, sequence_lengths: TensorLike, inputs: TensorLike +) -> tf.Tensor: + """Computes the unary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. + Returns: + unary_scores: A [batch_size] vector of unary scores. 
+ """ + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + + batch_size = tf.shape(inputs)[0] + max_seq_len = tf.shape(inputs)[1] + num_tags = tf.shape(inputs)[2] + + flattened_inputs = tf.reshape(inputs, [-1]) + + offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1) + offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == tf.int64: + offsets = tf.cast(offsets, tf.int64) + flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1]) + + unary_scores = tf.reshape( + tf.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len] + ) + + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=unary_scores.dtype + ) + + unary_scores = tf.reduce_sum(unary_scores * masks, 1) + return unary_scores + + +def crf_binary_score( + tag_indices: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike +) -> tf.Tensor: + """Computes the binary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + Returns: + binary_scores: A [batch_size] vector of binary scores. + """ + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + + num_tags = tf.shape(transition_params)[0] + num_transitions = tf.shape(tag_indices)[1] - 1 + + # Truncate by one on each side of the sequence to get the start and end + # indices of each transition. + start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions]) + end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) + + # Encode the indices in a flattened representation. + flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices + flattened_transition_params = tf.reshape(transition_params, [-1]) + + # Get the binary scores based on the flattened representation. + binary_scores = tf.gather(flattened_transition_params, flattened_transition_indices) + + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=binary_scores.dtype + ) + truncated_masks = tf.slice(masks, [0, 1], [-1, -1]) + binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1) + return binary_scores + + +def crf_forward( + inputs: TensorLike, + state: TensorLike, + transition_params: TensorLike, + sequence_lengths: TensorLike, +) -> tf.Tensor: + """Computes the alpha values in a linear-chain CRF. + + See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous alpha + values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + This matrix is expanded into a [1, num_tags, num_tags] in preparation + for the broadcast summation occurring within the cell. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + new_alphas: A [batch_size, num_tags] matrix containing the + new alpha values. 
+ """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + + last_index = tf.maximum( + tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1 + ) + inputs = tf.transpose(inputs, [1, 0, 2]) + transition_params = tf.expand_dims(transition_params, 0) + + def _scan_fn(_state, _inputs): + _state = tf.expand_dims(_state, 2) + transition_scores = _state + transition_params + new_alphas = _inputs + tf.reduce_logsumexp(transition_scores, [1]) + return new_alphas + + all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + # add first state for sequences of length 1 + all_alphas = tf.concat([tf.expand_dims(state, 1), all_alphas], 1) + + idxs = tf.stack([tf.range(tf.shape(last_index)[0]), last_index], axis=1) + return tf.gather_nd(all_alphas, idxs) + + +class CrfDecodeForwardRnnCell(AbstractRNNCell): + """Computes the forward decoding in a linear-chain CRF.""" + + @typechecked + def __init__(self, transition_params: TensorLike, **kwargs): + """Initialize the CrfDecodeForwardRnnCell. + + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the broadcast + summation occurring within the cell. + """ + super().__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + + @property + def state_size(self): + return self._num_tags + + @property + def output_size(self): + return self._num_tags + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs, state): + """Build the CrfDecodeForwardRnnCell. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ + state = tf.expand_dims(state[0], 2) + transition_scores = state + tf.cast( + self._transition_params, self._compute_dtype + ) + new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) + backpointers = tf.cast(backpointers, dtype=tf.int32) + return backpointers, new_state + + def get_config(self) -> dict: + config = { + "transition_params": tf.squeeze(self._transition_params, 0).numpy().tolist() + } + base_config = super(CrfDecodeForwardRnnCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config: dict) -> "CrfDecodeForwardRnnCell": + config["transition_params"] = np.array( + config["transition_params"], dtype=np.float32 + ) + return cls(**config) + + +def crf_decode_forward( + inputs: TensorLike, + state: TensorLike, + transition_params: TensorLike, + sequence_lengths: TensorLike, +) -> tf.Tensor: + """Computes forward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. 
+ """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype) + crf_fwd_layer = tf.keras.layers.RNN( + crf_fwd_cell, + return_sequences=True, + return_state=True, + dtype=inputs.dtype, + zero_output_for_mask=True, # See: https://github.com/tensorflow/addons/issues/2639 + ) + return crf_fwd_layer(inputs, state, mask=mask) + + +def crf_decode_backward(inputs: TensorLike, state: TensorLike) -> tf.Tensor: + """Computes backward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). + state: A [batch_size, 1] matrix of tag index of next step. + + Returns: + new_tags: A [batch_size, num_tags] + tensor containing the new tag indices. + """ + inputs = tf.transpose(inputs, [1, 0, 2]) + + def _scan_fn(state, inputs): + state = tf.squeeze(state, axis=[1]) + idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) + return new_tags + + return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + + +def crf_decode( + potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike +) -> tf.Tensor: + """Decode the highest scoring sequence of tags. + + Args: + potentials: A [batch_size, max_seq_len, num_tags] tensor of + unary potentials. + transition_params: A [num_tags, num_tags] matrix of + binary potentials. + sequence_length: A [batch_size] vector of true sequence lengths. + + Returns: + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + best_score: A [batch_size] vector, containing the score of `decode_tags`. + """ + if tf.__version__[:3] == "2.4": + warnings.warn( + "CRF Decoding does not work with KerasTensors in TF2.4. " + "The bug has since been fixed in tensorflow/tensorflow##45534" + ) + + sequence_length = tf.cast(sequence_length, dtype=tf.int32) + + # If max_seq_len is 1, we skip the algorithm and simply return the + # argmax tag and the max activation. + def _single_seq_fn(): + decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32) + best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1]) + return decode_tags, best_score + + def _multi_seq_fn(): + # Computes forward decoding. Get last score and backpointers. 
+ initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = tf.squeeze(initial_state, axis=[1]) + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) + + sequence_length_less_one = tf.maximum( + tf.constant(0, dtype=tf.int32), sequence_length - 1 + ) + + backpointers, last_score = crf_decode_forward( + inputs, initial_state, transition_params, sequence_length_less_one + ) + + backpointers = tf.reverse_sequence( + backpointers, sequence_length_less_one, seq_axis=1 + ) + + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) + initial_state = tf.expand_dims(initial_state, axis=-1) + + decode_tags = crf_decode_backward(backpointers, initial_state) + decode_tags = tf.squeeze(decode_tags, axis=[2]) + decode_tags = tf.concat([initial_state, decode_tags], axis=1) + decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1) + + best_score = tf.reduce_max(last_score, axis=1) + return decode_tags, best_score + + if potentials.shape[1] is not None: + # shape is statically know, so we just execute + # the appropriate code path + if potentials.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + else: + return tf.cond( + tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn + ) diff --git a/opennmt/tfa/text/tests/crf_test.py b/opennmt/tfa/text/tests/crf_test.py new file mode 100644 index 000000000..ea726e6a5 --- /dev/null +++ b/opennmt/tfa/text/tests/crf_test.py @@ -0,0 +1,404 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CRF.""" + +import itertools + +import numpy as np +import pytest +import tensorflow as tf + +from numpy.testing import assert_array_equal +from packaging.version import Version + +from opennmt.tfa import text +from opennmt.tfa.text.crf import ( + crf_binary_score, + crf_decode, + crf_decode_forward, + crf_filtered_inputs, + crf_log_likelihood, + crf_log_norm, + crf_multitag_sequence_score, + crf_sequence_score, + crf_unary_score, +) +from opennmt.tfa.utils import test_utils + + +def calculate_sequence_score(inputs, transition_params, tag_indices, sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths) + ) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1) + ) + return expected_unary_score + expected_binary_score + + +def brute_force_decode(sequence_lengths, inputs, transition_params): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + tag_indices_iterator = itertools.product(range(num_tags), repeat=sequence_lengths) + inputs = tf.expand_dims(inputs, 0) + sequence_lengths = tf.expand_dims(sequence_lengths, 0) + transition_params = tf.constant(transition_params) + + # Compare the dynamic program with brute force computation. 
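[Editor's note - condensed usage sketch of the CRF API vendored above, not part of the patch; shapes are illustrative] Training usually minimizes the negative log-likelihood returned by `crf_log_likelihood`; when `transition_params` is omitted it is created and returned, so it can be reused for Viterbi decoding with `crf_decode` at inference time.

    import tensorflow as tf

    from opennmt.tfa.text.crf import crf_decode, crf_log_likelihood

    batch, max_len, num_tags = 4, 7, 5
    unary_potentials = tf.random.normal([batch, max_len, num_tags])  # emission scores
    gold_tags = tf.random.uniform([batch, max_len], maxval=num_tags, dtype=tf.int32)
    lengths = tf.constant([7, 5, 6, 3], dtype=tf.int32)              # true lengths

    # Training objective: negative mean log-likelihood of the gold tag sequences.
    log_likelihood, transition_params = crf_log_likelihood(
        unary_potentials, gold_tags, lengths
    )
    loss = -tf.reduce_mean(log_likelihood)

    # Inference: highest scoring tag sequence per example and its score.
    decoded_tags, best_score = crf_decode(unary_potentials, transition_params, lengths)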
+ for tag_indices in tag_indices_iterator: + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf_sequence_score( + inputs=inputs, + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=sequence_lengths, + transition_params=transition_params, + ) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + expected_max_sequence_index = np.argmax(all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = all_sequence_scores[expected_max_sequence_index] + return expected_max_sequence, expected_max_score + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_filtered_inputs(dtype): + # Test both the length-1 and regular cases. + sequence_lengths_list = [np.array(3, dtype=np.int32), np.array(1, dtype=np.int32)] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype), + np.array([[4, 5, -3]], dtype=dtype), + ] + tag_bitmap_list = [ + np.array( + [ + [True, False, False], + [False, True, True], + [False, True, True], + [False, True, True], + ], + dtype=bool, + ), + np.array([[False, True, True]], dtype=bool), + ] + neg_inf = float("-inf") + expected_filtered_inputs_list = [ + np.array( + [[4, neg_inf, neg_inf], [neg_inf, -1, 3], [neg_inf, 2, 1], [neg_inf, 0, 0]], + dtype=dtype, + ), + np.array([[neg_inf, 5, -3]], dtype=dtype), + ] + for sequence_lengths, inputs, tag_bitmap, expected_filtered_inputs in zip( + sequence_lengths_list, + inputs_list, + tag_bitmap_list, + expected_filtered_inputs_list, + ): + filtered_inputs = crf_filtered_inputs( + inputs=tf.expand_dims(inputs, 0), tag_bitmap=tf.expand_dims(tag_bitmap, 0) + ) + filtered_inputs = tf.squeeze(filtered_inputs, [0]) + + test_utils.assert_allclose_according_to_type( + filtered_inputs, expected_filtered_inputs + ) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_sequence_score(dtype): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32), + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype), + np.array([[4, 5, -3]], dtype=dtype), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32), + ] + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list + ): + sequence_score = crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + sequence_score = tf.squeeze(sequence_score, [0]) + + expected_sequence_score = calculate_sequence_score( + inputs, transition_params, tag_indices, sequence_lengths + ) + test_utils.assert_allclose_according_to_type( + sequence_score, expected_sequence_score + ) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_multi_tag_sequence_score(dtype): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32), + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype), + np.array([[4, 5, -3]], dtype=dtype), + ] + tag_bitmap_list = [ + np.array( + [ + [True, True, False], + [True, False, True], + [False, True, True], + [True, False, True], + ], + dtype=bool, + ), + np.array([[True, True, False]], dtype=bool), + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list + ): + sequence_score = crf_multitag_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_bitmap=tf.expand_dims(tag_bitmap, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + sequence_score = tf.squeeze(sequence_score, [0]) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + calculate_sequence_score( + inputs, transition_params, indices, sequence_lengths + ) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores + ) + test_utils.assert_allclose_according_to_type( + sequence_score, expected_log_sum_exp_sequence_scores + ) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_unary_score(dtype): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + unary_score = crf_unary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + inputs=tf.expand_dims(inputs, 0), + ) + unary_score = tf.squeeze(unary_score, [0]) + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths) + ) + test_utils.assert_allclose_according_to_type(unary_score, expected_unary_score) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_binary_score(dtype): + tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + binary_score = crf_binary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + binary_score = tf.squeeze(binary_score, [0]) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1) + ) + test_utils.assert_allclose_according_to_type(binary_score, expected_binary_score) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_log_norm(dtype): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64), + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype), + np.array([[3, -1, 3]], dtype=dtype), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32), + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list + ): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product(range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + ) + + brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores) + log_norm = crf_log_norm( + inputs=tf.expand_dims(inputs, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + log_norm = tf.squeeze(log_norm, [0]) + + test_utils.assert_allclose_according_to_type(log_norm, brute_force_log_norm) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_log_norm_zero_seq_length(dtype): + """Test `crf_log_norm` when `sequence_lengths` contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=dtype)) + transition_params = tf.constant(np.ones([5, 5], dtype=dtype)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=dtype) + log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) + test_utils.assert_allclose_according_to_type(log_norm, expected_log_norm) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_log_likelihood(dtype): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_log_likelihoods = [] + + # Make sure all probabilities sum to 1. + for tag_indices in itertools.product(range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + sequence_log_likelihood, _ = crf_log_likelihood( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params), + ) + all_sequence_log_likelihoods.append(sequence_log_likelihood) + total_log_likelihood = tf.reduce_logsumexp(all_sequence_log_likelihoods) + test_utils.assert_allclose_according_to_type( + total_log_likelihood, 0.0, rtol=1e-6, atol=1e-6, half_rtol=2e-3, half_atol=2e-3 + ) + + # check if `transition_params = None` raises an error + crf_log_likelihood( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + ) + + +@pytest.mark.parametrize("dtype", [np.float16, np.float32]) +def test_crf_decode(dtype): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64), + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype), + np.array([[-1, 2, 1]], dtype=dtype), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32), + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list + ): + expected_max_sequence, expected_max_score = brute_force_decode( + sequence_lengths, inputs, transition_params + ) + + actual_max_sequence, actual_max_score = crf_decode( + tf.expand_dims(inputs, 0), + tf.constant(transition_params), + tf.expand_dims(sequence_lengths, 0), + ) + actual_max_sequence = tf.squeeze(actual_max_sequence, [0]) + actual_max_score = tf.squeeze(actual_max_score, [0]) + + test_utils.assert_allclose_according_to_type( + actual_max_score, expected_max_score, 1e-6, 1e-6 + ) + assert ( + list(actual_max_sequence[:sequence_lengths]) + == expected_max_sequence[:sequence_lengths] + ) + + +def test_crf_decode_zero_seq_length(): + """Test that crf_decode works when sequence_length contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + tags, scores = crf_decode(inputs, transition_params, sequence_lengths) + assert len(tags.shape) == 2 + assert len(scores.shape) == 1 + + +def test_different_dtype(): + inputs = np.ones([16, 20, 5], dtype=np.float32) + tags = tf.convert_to_tensor(np.ones([16, 20], dtype=np.int64)) + seq_lens = np.ones([16], dtype=np.int64) * 20 + + loss, _ = crf_log_likelihood( + inputs=inputs, tag_indices=tags, sequence_lengths=seq_lens + ) diff --git a/opennmt/tfa/utils/__init__.py b/opennmt/tfa/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/opennmt/tfa/utils/keras_utils.py b/opennmt/tfa/utils/keras_utils.py new file mode 100644 index 000000000..58042c929 --- /dev/null +++ b/opennmt/tfa/utils/keras_utils.py @@ -0,0 +1,66 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for tf.keras.""" + +import tensorflow as tf + + +def assert_like_rnncell(cell_name, cell): + """Raises a TypeError if cell is not like a + tf.keras.layers.AbstractRNNCell. + + Args: + cell_name: A string to give a meaningful error referencing to the name + of the function argument. + cell: The object which should behave like a + tf.keras.layers.AbstractRNNCell. + + Raises: + TypeError: A human-friendly exception. 
+    """
+    conditions = [
+        _hasattr(cell, "output_size"),
+        _hasattr(cell, "state_size"),
+        _hasattr(cell, "get_initial_state"),
+        callable(cell),
+    ]
+
+    errors = [
+        "'output_size' property is missing",
+        "'state_size' property is missing",
+        "'get_initial_state' method is required",
+        "is not callable",
+    ]
+
+    if not all(conditions):
+        errors = [error for error, cond in zip(errors, conditions) if not cond]
+        raise TypeError(
+            "The argument {!r} ({}) is not an RNNCell: {}.".format(
+                cell_name, cell, ", ".join(errors)
+            )
+        )
+
+
+def _hasattr(obj, attr_name):
+    # If possible, avoid retrieving the attribute as the object might run some
+    # lazy computation in it.
+    if attr_name in dir(obj):
+        return True
+    try:
+        getattr(obj, attr_name)
+    except AttributeError:
+        return False
+    else:
+        return True
diff --git a/opennmt/tfa/utils/test_utils.py b/opennmt/tfa/utils/test_utils.py
new file mode 100644
index 000000000..f87280a2d
--- /dev/null
+++ b/opennmt/tfa/utils/test_utils.py
@@ -0,0 +1,275 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for testing Addons."""
+
+import inspect
+import os
+import random
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+# from opennmt.tfa.utils.tf_test_utils import layer_test  # noqa
+
+NUMBER_OF_WORKERS = int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))
+WORKER_ID = int(os.environ.get("PYTEST_XDIST_WORKER", "gw0")[2])
+NUMBER_OF_GPUS = len(tf.config.list_physical_devices("GPU"))
+
+
+def is_gpu_available():
+    return NUMBER_OF_GPUS >= 1
+
+
+# Some configuration before starting the tests.
+
+# We only need one core per worker.
+# This avoids context switching for speed, and it also prevents TensorFlow from
+# going crazy on systems with many cores (kokoro has 30+ cores).
+tf.config.threading.set_intra_op_parallelism_threads(1)
+tf.config.threading.set_inter_op_parallelism_threads(1)
+
+if is_gpu_available():
+    # We split each physical GPU into 2 logical GPUs, and use only the
+    # first GPU at the moment. That's enough for most use cases.
+    # Split the first GPU into chunks of 100MB per virtual device.
+    # It's the user's job to limit the number of pytest workers depending
+    # on the available memory.
+    # In practice, each process takes a bit more memory.
+    # There must be some kind of overhead, but it's not very big (~200MB more).
+    # Each worker has two virtual devices.
+    # When running on GPU, only the first device is used. The other one is used
+    # in distributed strategies.
+ physical_gpus = tf.config.list_physical_devices("GPU") + tf.config.set_visible_devices(physical_gpus[0], "GPU") + for physical_gpu in physical_gpus: + virtual_gpus = [ + tf.config.LogicalDeviceConfiguration(memory_limit=100) for _ in range(2) + ] + tf.config.set_logical_device_configuration(physical_gpu, virtual_gpus) + + +def finalizer(): + tf.config.run_functions_eagerly(False) + + +def pytest_make_parametrize_id(config, val, argname): + if isinstance(val, tf.DType): + return val.name + if val is False: + return "no_" + argname + if val is True: + return argname + + +@pytest.fixture(scope="function", params=["eager_mode", "tf_function"]) +def maybe_run_functions_eagerly(request): + if request.param == "eager_mode": + tf.config.run_functions_eagerly(True) + elif request.param == "tf_function": + tf.config.run_functions_eagerly(False) + + request.addfinalizer(finalizer) + + +@pytest.fixture(scope="function") +def only_run_functions_eagerly(request): + tf.config.run_functions_eagerly(True) + request.addfinalizer(finalizer) + + +@pytest.fixture(scope="function", params=["float32", "mixed_float16"]) +def run_with_mixed_precision_policy(request): + tf.keras.mixed_precision.set_global_policy(request.param) + yield + tf.keras.mixed_precision.set_global_policy("float32") + + +@pytest.fixture(scope="function", params=["channels_first", "channels_last"]) +def data_format(request): + return request.param + + +@pytest.fixture(scope="function", autouse=True) +def set_seeds(): + random.seed(0) + np.random.seed(0) + tf.random.set_seed(0) + + +def pytest_addoption(parser): + parser.addoption( + "--skip-custom-ops", + action="store_true", + help="When a custom op is being loaded in a test, skip this test.", + ) + + +def gpus_for_testing(): + """For the moment it's very simple, but it might change in the future, + with multiple physical gpus for example. So it's better if this function + is called rather than hardcoding the gpu devices in the tests. + """ + if not is_gpu_available(): + raise SystemError( + "You are trying to get some gpus for testing but no gpu is available on " + "your system. \nDid you forget to use `@pytest.mark.needs_gpu` on your test" + " so that it's skipped automatically when no gpu is available?" + ) + return ["gpu:0", "gpu:1"] + + +@pytest.fixture(scope="session", autouse=True) +def set_global_variables(request): + pass + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "with_device(devices): mark test to run on specific devices." + ) + config.addinivalue_line("markers", "needs_gpu: mark test that needs a gpu.") + + +@pytest.fixture(autouse=True, scope="function") +def device(request): + try: + requested_device = request.param + except Exception: + # workaround for DocTestItem + # https://github.com/pytest-dev/pytest/issues/5070 + requested_device = "no_device" + if requested_device == "no_device": + yield requested_device + elif requested_device == tf.distribute.MirroredStrategy: + strategy = requested_device(gpus_for_testing()) + with strategy.scope(): + yield strategy + elif isinstance(requested_device, str): + if requested_device == "gpu": + # we use GPU:0 because the virtual device we created is the + # only one in the first GPU (so first in the list of virtual devices). 
+ requested_device += ":0" + elif requested_device == "cpu": + requested_device = "cpu" + else: + raise KeyError("Invalid device: " + requested_device) + with tf.device(requested_device): + yield requested_device + + +def get_marks(device_name): + if device_name == "gpu" or device_name == tf.distribute.MirroredStrategy: + return [pytest.mark.needs_gpu] + else: + return [] + + +def pytest_generate_tests(metafunc): + marker = metafunc.definition.get_closest_marker("with_device") + if marker is None: + # tests which don't have the "with_device" mark are executed on CPU + # to ensure reproducibility. We can't let TensorFlow decide + # where to place the ops. + devices = ["cpu"] + else: + devices = marker.args[0] + + parameters = [pytest.param(x, marks=get_marks(x)) for x in devices] + metafunc.parametrize("device", parameters, indirect=True) + + +def pytest_collection_modifyitems(items): + for item in items: + if item.get_closest_marker("needs_gpu") is not None: + if not is_gpu_available(): + item.add_marker(pytest.mark.skip("The gpu is not available.")) + + +def assert_not_allclose(a, b, **kwargs): + """Assert that two numpy arrays, do not have near values. + + Args: + a: the first value to compare. + b: the second value to compare. + **kwargs: additional keyword arguments to be passed to the underlying + `np.testing.assert_allclose` call. + + Raises: + AssertionError: If `a` and `b` are unexpectedly close at all elements. + """ + try: + np.testing.assert_allclose(a, b, **kwargs) + except AssertionError: + return + raise AssertionError("The two values are close at all elements") + + +def assert_allclose_according_to_type( + a, + b, + rtol=1e-6, + atol=1e-6, + float_rtol=1e-6, + float_atol=1e-6, + half_rtol=1e-3, + half_atol=1e-3, + bfloat16_rtol=1e-2, + bfloat16_atol=1e-2, +): + """ + Similar to tf.test.TestCase.assertAllCloseAccordingToType() + but this doesn't need a subclassing to run. + """ + a = np.array(a) + b = np.array(b) + # types with lower tol are put later to overwrite previous ones. + if ( + a.dtype == np.float32 + or b.dtype == np.float32 + or a.dtype == np.complex64 + or b.dtype == np.complex64 + ): + rtol = max(rtol, float_rtol) + atol = max(atol, float_atol) + if a.dtype == np.float16 or b.dtype == np.float16: + rtol = max(rtol, half_rtol) + atol = max(atol, half_atol) + if a.dtype == tf.bfloat16.as_numpy_dtype or b.dtype == tf.bfloat16.as_numpy_dtype: + rtol = max(rtol, bfloat16_rtol) + atol = max(atol, bfloat16_atol) + + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + + +def discover_classes(module, parent, class_exceptions): + """ + Args: + module: a module in which to search for classes that inherit from the parent class + parent: the parent class that identifies classes in the module that should be tested + class_exceptions: a list of specific classes that should be excluded when + discovering classes in a module + + Returns: + a list of classes for testing using pytest for parameterized tests + """ + + classes = [ + class_info[1] + for class_info in inspect.getmembers(module, inspect.isclass) + if issubclass(class_info[1], parent) and not class_info[0] in class_exceptions + ] + + return classes diff --git a/opennmt/tfa/utils/types.py b/opennmt/tfa/utils/types.py new file mode 100644 index 000000000..300471c22 --- /dev/null +++ b/opennmt/tfa/utils/types.py @@ -0,0 +1,82 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Types for typing functions signatures.""" + +import importlib + +from typing import Callable, List, Union + +import numpy as np +import tensorflow as tf + +from packaging.version import Version + +# Find KerasTensor. +if Version(tf.__version__).release >= Version("2.16").release: + # Determine if loading keras 2 or 3. + if ( + hasattr(tf.keras, "version") + and Version(tf.keras.version()).release >= Version("3.0").release + ): + from keras import KerasTensor + else: + from tf_keras.src.engine.keras_tensor import KerasTensor +elif Version(tf.__version__).release >= Version("2.13").release: + from keras.src.engine.keras_tensor import KerasTensor +elif Version(tf.__version__).release >= Version("2.5").release: + from keras.engine.keras_tensor import KerasTensor +else: + from tensorflow.python.keras.engine.keras_tensor import KerasTensor + + +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +Initializer = Union[None, dict, str, Callable, tf.keras.initializers.Initializer] +Regularizer = Union[None, dict, str, Callable, tf.keras.regularizers.Regularizer] +Constraint = Union[None, dict, str, Callable, tf.keras.constraints.Constraint] +Activation = Union[None, str, Callable] +if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None: + Optimizer = Union[ + tf.keras.optimizers.Optimizer, tf.keras.optimizers.legacy.Optimizer, str + ] +else: + Optimizer = Union[tf.keras.optimizers.Optimizer, str] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, + KerasTensor, +] +FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64] +AcceptableDTypes = Union[tf.DType, np.dtype, type, int, str, None] diff --git a/opennmt/utils/decoding.py b/opennmt/utils/decoding.py index 3c69a98f2..4078ff090 100644 --- a/opennmt/utils/decoding.py +++ b/opennmt/utils/decoding.py @@ -4,7 +4,8 @@ import collections import tensorflow as tf -import tensorflow_addons as tfa + +import opennmt.tfa as tfa from opennmt import constants from opennmt.utils import misc diff --git a/setup.py b/setup.py index 1a1100897..c87720dd5 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ def get_project_version(): "pyyaml>=5.3,<7", "rouge>=1.0,<2", "sacrebleu>=1.5.0,<3", - "tensorflow-addons>=0.16,<0.22", + "typeguard>=2.7,<3.0.0", ], extras_require={ "tensorflow": [
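The decoding.py and setup.py hunks above summarize the intent of the change for downstream code: tensorflow-addons is dropped from install_requires (typeguard stays, since the vendored code still uses it), and callers switch from `import tensorflow_addons as tfa` to `import opennmt.tfa as tfa`. The sketch below is illustrative only and is not part of the patch; it assumes the vendored module created here, `opennmt.tfa.text.crf`, exposes `crf_log_likelihood` and `crf_decode` with the same signatures as tensorflow_addons, which is what the crf_test.py hunks above rely on.

import tensorflow as tf

# Assumed import path: the vendored opennmt/tfa/text/crf.py added by this patch.
from opennmt.tfa.text.crf import crf_decode, crf_log_likelihood

# Toy batch: one sequence of length 3 over 3 tags, padded to 4 time steps,
# mirroring the fixtures used in the CRF tests above.
inputs = tf.constant(
    [[[4.0, 5.0, -3.0], [3.0, -1.0, 3.0], [-1.0, 2.0, 1.0], [0.0, 0.0, 0.0]]]
)
tag_indices = tf.constant([[1, 2, 1, 0]], dtype=tf.int32)
sequence_lengths = tf.constant([3], dtype=tf.int32)

# Training side: log-likelihood of the gold tags; omitting transition_params
# lets the function create the transition matrix variable itself.
log_likelihood, transition_params = crf_log_likelihood(
    inputs=inputs, tag_indices=tag_indices, sequence_lengths=sequence_lengths
)
loss = -tf.reduce_mean(log_likelihood)

# Inference side: Viterbi decoding, as exercised by test_crf_decode above.
decoded_tags, best_scores = crf_decode(inputs, transition_params, sequence_lengths)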