Remove tensorflow_addons dependency and bring needed classes
jordimas committed Aug 22, 2024
1 parent 6f3b952 commit b5fd468
Showing 38 changed files with 6,888 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -69,6 +69,7 @@ jobs:
- name: Run tests
run: |
pytest --cov=opennmt --cov-report xml opennmt/tests
+pytest opennmt/tfa/
- name: Upload coverage report
if: matrix.tensorflow == '2.13'
1 change: 1 addition & 0 deletions docs/generate-apidoc.py
@@ -35,6 +35,7 @@ def document_function(output_dir, function_path):
def module_is_public(module):
return (
module.__name__.startswith("opennmt")
+and not module.__name__.startswith("opennmt.tfa")
and hasattr(module, "__file__")
and module.__file__.endswith("__init__.py")
)
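With the added clause, the vendored code is excluded from the generated API documentation. A minimal sketch of the effect, assuming the submodules below are importable (the predicate is copied from the patched helper; the two calls are illustrative):

import opennmt.encoders
import opennmt.tfa.optimizers as vendored_optimizers

def module_is_public(module):
    # Same predicate as docs/generate-apidoc.py after this change.
    return (
        module.__name__.startswith("opennmt")
        and not module.__name__.startswith("opennmt.tfa")
        and hasattr(module, "__file__")
        and module.__file__.endswith("__init__.py")
    )

print(module_is_public(opennmt.encoders))      # True: still documented
print(module_is_public(vendored_optimizers))   # False: vendored code stays out of the docs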
3 changes: 2 additions & 1 deletion opennmt/decoders/rnn_decoder.py
@@ -1,7 +1,8 @@
"""Define RNN-based decoders."""

import tensorflow as tf
-import tensorflow_addons as tfa
+
+import opennmt.tfa as tfa

from opennmt.decoders import decoder
from opennmt.layers import bridge, common, rnn, transformer
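The same one-line swap repeats in the encoder, catalog, and model modules below. Because the local alias stays `tfa`, existing call sites do not change; a minimal sketch of the pattern (the seq2seq reference is an illustrative assumption, not taken from this diff):

# Before: external dependency
# import tensorflow_addons as tfa
# After: vendored copy that mirrors the tensorflow_addons layout (see opennmt/tfa/__init__.py below)
import opennmt.tfa as tfa

# Call sites keep the same attribute paths, for example:
sampler = tfa.seq2seq.TrainingSampler()  # illustrative; any tfa.* symbol OpenNMT-tf uses resolves the same way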
3 changes: 2 additions & 1 deletion opennmt/encoders/rnn_encoder.py
@@ -1,7 +1,8 @@
"""Define RNN-based encoders."""

import tensorflow as tf
-import tensorflow_addons as tfa
+
+import opennmt.tfa as tfa

from opennmt.encoders.encoder import Encoder, SequentialEncoder
from opennmt.layers import common, rnn
3 changes: 2 additions & 1 deletion opennmt/models/catalog.py
@@ -1,7 +1,8 @@
"""Catalog of predefined models."""

import tensorflow as tf
-import tensorflow_addons as tfa
+
+import opennmt.tfa as tfa

from opennmt import config as config_util
from opennmt import decoders, encoders, inputters, layers
3 changes: 2 additions & 1 deletion opennmt/models/sequence_tagger.py
@@ -2,7 +2,8 @@

import numpy as np
import tensorflow as tf
-import tensorflow_addons as tfa
+
+import opennmt.tfa as tfa

from opennmt import inputters
from opennmt.models.model import Model
3 changes: 2 additions & 1 deletion opennmt/models/sequence_to_sequence.py
@@ -1,7 +1,8 @@
"""Standard sequence-to-sequence model."""

import tensorflow as tf
-import tensorflow_addons as tfa
+
+import opennmt.tfa as tfa

from opennmt import config as config_util
from opennmt import constants, inputters
7 changes: 3 additions & 4 deletions opennmt/optimizers/utils.py
@@ -3,13 +3,12 @@
import inspect

import tensorflow as tf
-import tensorflow_addons as tfa

from packaging.version import Version
-from tensorflow_addons.optimizers.weight_decay_optimizers import (
-    DecoupledWeightDecayExtension,
-)

+import opennmt.tfa as tfa
+
+from opennmt.tfa.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension
from opennmt.utils import misc

if Version(tf.__version__) >= Version("2.11.0"):
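The extension class gives utils.py a way to tell weight-decay-aware optimizers apart. A hedged sketch of that kind of check, not copied verbatim from utils.py (the helper name is hypothetical):

import opennmt.tfa as tfa
from opennmt.tfa.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension

def accepts_weight_decay(optimizer_class):
    # Hypothetical helper: only optimizers that mix in the decoupled weight decay
    # extension (such as the vendored AdamW) expect a weight_decay argument.
    return issubclass(optimizer_class, DecoupledWeightDecayExtension)

print(accepts_weight_decay(tfa.optimizers.AdamW))     # True
print(accepts_weight_decay(tfa.optimizers.LazyAdam))  # False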
6 changes: 2 additions & 4 deletions opennmt/tests/optimizer_test.py
@@ -1,12 +1,10 @@
import tensorflow as tf
-import tensorflow_addons as tfa

-from tensorflow_addons.optimizers.weight_decay_optimizers import (
-    DecoupledWeightDecayExtension,
-)
+import opennmt.tfa as tfa

from opennmt.optimizers import utils
from opennmt.tests import test_util
+from opennmt.tfa.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension


class OptimizerTest(tf.test.TestCase):
2 changes: 2 additions & 0 deletions opennmt/tfa/__init__.py
@@ -0,0 +1,2 @@
from opennmt.tfa import optimizers, rnn, seq2seq, text
from opennmt.tfa.utils import types
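A quick sketch of what the vendored namespace exposes after this commit (the comments are informal descriptions, not module docstrings; the CRF mention is an assumption based on the sequence tagger's use of tfa.text):

import opennmt.tfa as tfa

print(tfa.optimizers)  # LazyAdam, AdamW, extend_with_decoupled_weight_decay, ...
print(tfa.seq2seq)     # attention and decoding utilities used by the RNN decoders
print(tfa.rnn)         # recurrent cell variants
print(tfa.text)        # text ops, e.g. CRF helpers for sequence tagging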
20 changes: 20 additions & 0 deletions opennmt/tfa/conftest.py
@@ -0,0 +1,20 @@
import numpy as np
import pytest
import tensorflow as tf

import opennmt.tfa as tfa

from opennmt.tfa.utils.test_utils import ( # noqa: F401
data_format,
device,
maybe_run_functions_eagerly,
only_run_functions_eagerly,
pytest_addoption,
pytest_collection_modifyitems,
pytest_configure,
pytest_generate_tests,
pytest_make_parametrize_id,
run_with_mixed_precision_policy,
set_global_variables,
set_seeds,
)
6 changes: 6 additions & 0 deletions opennmt/tfa/optimizers/__init__.py
@@ -0,0 +1,6 @@
from opennmt.tfa.optimizers.lazy_adam import LazyAdam
from opennmt.tfa.optimizers.weight_decay_optimizers import (
AdamW,
DecoupledWeightDecayExtension,
extend_with_decoupled_weight_decay,
)
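These re-exports keep the names that OpenNMT-tf previously imported from tensorflow_addons.optimizers. A minimal usage sketch, assuming the vendored package is available with OpenNMT-tf (hyperparameter values are placeholders):

import tensorflow as tf

import opennmt.tfa as tfa

# AdamW is Adam extended with decoupled weight decay.
adamw = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3)

# The extension can also wrap another Keras optimizer class
# (on newer TensorFlow releases the legacy classes, e.g.
# tf.keras.optimizers.legacy.SGD, may be required here).
SGDW = tfa.optimizers.extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)
sgdw = SGDW(weight_decay=1e-4, learning_rate=0.1)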
156 changes: 156 additions & 0 deletions opennmt/tfa/optimizers/lazy_adam.py
@@ -0,0 +1,156 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variant of the Adam optimizer that handles sparse updates more efficiently.
Compared with the original Adam optimizer, the one in this file can
provide a large improvement in model training throughput for some
applications. However, it provides slightly different semantics than the
original Adam algorithm, and may lead to different empirical results.
"""

import importlib

from typing import Callable, Union

import tensorflow as tf

from typeguard import typechecked

from opennmt.tfa.utils.types import FloatTensorLike

if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
adam_optimizer_class = tf.keras.optimizers.legacy.Adam
else:
adam_optimizer_class = tf.keras.optimizers.Adam


@tf.keras.utils.register_keras_serializable(package="Addons")
class LazyAdam(adam_optimizer_class):
"""Variant of the Adam optimizer that handles sparse updates more
efficiently.
The original Adam algorithm maintains two moving-average accumulators for
each trainable variable; the accumulators are updated at every step.
This class provides lazier handling of gradient updates for sparse
variables. It only updates moving-average accumulators for sparse variable
indices that appear in the current batch, rather than updating the
accumulators for all indices. Compared with the original Adam optimizer,
it can provide large improvements in model training throughput for some
applications. However, it provides slightly different semantics than the
original Adam algorithm, and may lead to different empirical results.
Note, amsgrad is currently not supported and the argument can only be
False.
"""

@typechecked
def __init__(
self,
learning_rate: Union[FloatTensorLike, Callable] = 0.001,
beta_1: FloatTensorLike = 0.9,
beta_2: FloatTensorLike = 0.999,
epsilon: FloatTensorLike = 1e-7,
amsgrad: bool = False,
name: str = "LazyAdam",
**kwargs,
):
"""Constructs a new LazyAdam optimizer.
Args:
learning_rate: A `Tensor`, a floating point value, or a schedule
that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
The learning rate.
beta_1: A `float` value or a constant `float` tensor.
The exponential decay rate for the 1st moment estimates.
beta_2: A `float` value or a constant `float` tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability.
This epsilon is "epsilon hat" in
[Adam: A Method for Stochastic Optimization. Kingma et al., 2014]
(http://arxiv.org/abs/1412.6980) (in the formula just
before Section 2.1), not the epsilon in Algorithm 1 of the paper.
amsgrad: `boolean`. Whether to apply AMSGrad variant of this
algorithm from the paper "On the Convergence of Adam and beyond".
Note that this argument is currently not supported and the
argument can only be `False`.
name: Optional name for the operations created when applying
gradients. Defaults to "LazyAdam".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
is clip gradients by value, `decay` is included for backward
compatibility to allow time inverse decay of learning rate. `lr`
is included for backward compatibility, recommended to use
`learning_rate` instead.
"""
super().__init__(
learning_rate=learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon,
amsgrad=amsgrad,
name=name,
**kwargs,
)

def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
beta_1_t = self._get_hyper("beta_1", var_dtype)
beta_2_t = self._get_hyper("beta_2", var_dtype)
local_step = tf.cast(self.iterations + 1, var_dtype)
beta_1_power = tf.math.pow(beta_1_t, local_step)
beta_2_power = tf.math.pow(beta_2_t, local_step)
epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
lr = lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power)

# \\(m := beta1 * m + (1 - beta1) * g_t\\)
m = self.get_slot(var, "m")
m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad
m_update_op = self._resource_scatter_update(m, indices, m_t_slice)

# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
v = self.get_slot(var, "v")
v_t_slice = beta_2_t * tf.gather(v, indices) + (1 - beta_2_t) * tf.math.square(
grad
)
v_update_op = self._resource_scatter_update(v, indices, v_t_slice)

# \\(variable += -learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)
var_update_op = self._resource_scatter_sub(var, indices, var_slice)

return tf.group(*[var_update_op, m_update_op, v_update_op])

def _resource_scatter_update(self, resource, indices, update):
return self._resource_scatter_operate(
resource, indices, update, tf.raw_ops.ResourceScatterUpdate
)

def _resource_scatter_sub(self, resource, indices, update):
return self._resource_scatter_operate(
resource, indices, update, tf.raw_ops.ResourceScatterSub
)

def _resource_scatter_operate(self, resource, indices, update, resource_scatter_op):
resource_update_kwargs = {
"resource": resource.handle,
"indices": indices,
"updates": update,
}

return resource_scatter_op(**resource_update_kwargs)

def get_config(self):
return super().get_config()
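A short usage sketch for the vendored optimizer, assuming an embedding-heavy model where each batch only touches a few embedding rows (the toy model and data below are illustrative placeholders):

import numpy as np
import tensorflow as tf

import opennmt.tfa as tfa

# LazyAdam only updates the Adam slot rows gathered in the current batch,
# which helps when gradients are sparse (e.g. for a large embedding table).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Embedding(input_dim=10000, output_dim=64),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1),
    ]
)
model.compile(optimizer=tfa.optimizers.LazyAdam(learning_rate=1e-3), loss="mse")

tokens = np.random.randint(0, 10000, size=(32, 20))
targets = np.random.rand(32, 1).astype("float32")
model.fit(tokens, targets, epochs=1, verbose=0)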
(The diffs for the remaining changed files in this commit are not shown.)
