From cc1799404b1cc2987679980d2b46e7780360a608 Mon Sep 17 00:00:00 2001
From: Zachary Garrett
Date: Fri, 6 Sep 2024 05:44:54 -0700
Subject: [PATCH] Add support for Tensor learning rates and gradients with
 mixed types.

PiperOrigin-RevId: 671726026
---
 RELEASE.md                                    |  6 ++--
 .../python/learning/optimizers/adagrad.py     | 27 ++++++++-------
 .../learning/optimizers/adagrad_test.py       | 20 +++++++++--
 .../python/learning/optimizers/adam.py        | 34 +++++++++----------
 .../python/learning/optimizers/adam_test.py   | 25 +++++++++++---
 .../python/learning/optimizers/sgdm.py        | 15 ++++----
 .../python/learning/optimizers/sgdm_test.py   | 20 +++++++++--
 7 files changed, 96 insertions(+), 51 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index d275652d3a..fd98838442 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -26,8 +26,10 @@ and this project adheres to
 
 ### Fixed
 
-* A bug where `tff.learning.optimizers.build_adafactor(...)` would update its
-  step counter twice upon every invocation of `.next()`.
+* A bug where `tff.learning.optimizers.build_adafactor` would update its step
+  counter twice upon every invocation of `.next()`.
+* A bug where tensor learning rates for `tff.learning.optimizers.build_sgdm`
+  would fail with mixed dtype gradients.
 
 ### Removed
 
diff --git a/tensorflow_federated/python/learning/optimizers/adagrad.py b/tensorflow_federated/python/learning/optimizers/adagrad.py
index 256e129931..52fd57874e 100644
--- a/tensorflow_federated/python/learning/optimizers/adagrad.py
+++ b/tensorflow_federated/python/learning/optimizers/adagrad.py
@@ -27,7 +27,7 @@
 _HPARAMS_KEYS = [optimizer.LEARNING_RATE_KEY, _EPSILON_KEY]
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _Adagrad(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -40,16 +40,19 @@ def __init__(
       self,
       epsilon: optimizer.Float = 1e-7,
   ):
     """Initializes SGD optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'Adagrad `learning_rate` must be nonnegative, found {learning_rate}.'
       )
-    if initial_preconditioner_value < 0.0:
+    if (
+        not tf.is_symbolic_tensor(initial_preconditioner_value)
+        and initial_preconditioner_value < 0.0
+    ):
       raise ValueError(
           'Adagrad `initial_preconditioner_value` must be nonnegative, found '
           f'{initial_preconditioner_value}.'
       )
-    if epsilon < 0.0:
+    if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
       raise ValueError(f'Adagrad epsilon must be nonnegative, found {epsilon}.')
     self._lr = learning_rate
     self._initial_precond = initial_preconditioner_value
@@ -57,14 +60,15 @@ def __init__(
 
   def initialize(self, specs: Any) -> State:
     initial_preconditioner = tf.nest.map_structure(
-        lambda s: tf.ones(s.shape, s.dtype) * self._initial_precond, specs
+        lambda s: tf.ones(s.shape, s.dtype)
+        * tf.cast(self._initial_precond, s.dtype),
+        specs,
     )
-    state = collections.OrderedDict([
+    return collections.OrderedDict([
         (optimizer.LEARNING_RATE_KEY, self._lr),
         (_EPSILON_KEY, self._epsilon),
         (_PRECONDITIONER_KEY, initial_preconditioner),
     ])
-    return state
 
   def next(
       self, state: State, weights: optimizer.Weights, gradients: Any
@@ -82,7 +86,9 @@ def _adagrad_update(w, p, g):
       if g is None:
         return w, p
       p = p + tf.math.square(g)
-      w = w - lr * g / tf.math.sqrt(p + epsilon)
+      w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(
+          p + tf.cast(epsilon, p.dtype)
+      )
       return w, p
 
     updated_weights, updated_preconditioner = nest_utils.map_at_leaves(
@@ -99,11 +105,6 @@ def get_hparams(self, state: State) -> Hparams:
     return collections.OrderedDict([(k, state[k]) for k in _HPARAMS_KEYS])
 
   def set_hparams(self, state: State, hparams: Hparams) -> State:
-    # TODO: b/245962555 - Find an alternative to `update_struct` if it
-    # interferes with typing guarantees.
-    # We use `tff.structure.update_struct` (rather than something like
-    # `copy.deepcopy`) to ensure that this can be called within a
-    # `tff.Computation`.
     return structure.update_struct(state, **hparams)
 
 
diff --git a/tensorflow_federated/python/learning/optimizers/adagrad_test.py b/tensorflow_federated/python/learning/optimizers/adagrad_test.py
index 4501f132e7..d2f30cf3ef 100644
--- a/tensorflow_federated/python/learning/optimizers/adagrad_test.py
+++ b/tensorflow_federated/python/learning/optimizers/adagrad_test.py
@@ -145,8 +145,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: adagrad.build_adagrad(0.01)
     keras_optimizer_fn = lambda: tf.keras.optimizers.Adagrad(0.01)
@@ -227,6 +227,22 @@ def test_set_get_hparams_is_no_op(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    adagrad_optimizer = adagrad.build_adagrad(
+        learning_rate=tf.constant(0.1, dtype=tf.float32),
+        initial_preconditioner_value=tf.constant(0.1, dtype=tf.float32),
+        epsilon=tf.constant(0.1, dtype=tf.float64),
+    )
+    state = adagrad_optimizer.initialize(weights)
+    adagrad_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow_federated/python/learning/optimizers/adam.py b/tensorflow_federated/python/learning/optimizers/adam.py
index e45a09283e..928655e5ec 100644
--- a/tensorflow_federated/python/learning/optimizers/adam.py
+++ b/tensorflow_federated/python/learning/optimizers/adam.py
@@ -36,7 +36,7 @@
 ]
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _Adam(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -50,19 +50,19 @@ def __init__(
       epsilon: optimizer.Float = 1e-7,
   ):
     """Initializes Adam optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'Adam `learning_rate` must be nonnegative, found {learning_rate}.'
       )
-    if beta_1 < 0.0 or beta_1 > 1.0:
+    if not tf.is_symbolic_tensor(beta_1) and (beta_1 < 0.0 or beta_1 > 1.0):
       raise ValueError(
           f'Adam `beta_1` must be in the range [0.0, 1.0], found {beta_1}.'
       )
-    if beta_2 < 0.0 or beta_2 > 1.0:
+    if not tf.is_symbolic_tensor(beta_2) and (beta_2 < 0.0 or beta_2 > 1.0):
       raise ValueError(
           f'Adam `beta_2` must be in the range [0.0, 1.0], found {beta_2}.'
       )
-    if epsilon < 0.0:
+    if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
       raise ValueError(f'Adam `epsilon` must be nonnegative, found {epsilon}.')
     self._lr = learning_rate
     self._beta_1 = beta_1
@@ -76,7 +76,7 @@ def initialize(self, specs: Any) -> State:
     initial_preconditioner = tf.nest.map_structure(
         lambda s: tf.zeros(s.shape, s.dtype), specs
     )
-    state = collections.OrderedDict([
+    return collections.OrderedDict([
         (optimizer.LEARNING_RATE_KEY, self._lr),
         (_BETA_1_KEY, self._beta_1),
         (_BETA_2_KEY, self._beta_2),
@@ -85,7 +85,6 @@ def initialize(self, specs: Any) -> State:
         (_ACCUMULATOR_KEY, initial_accumulator),
         (_PRECONDITIONER_KEY, initial_preconditioner),
     ])
-    return state
 
   def next(
       self, state: State, weights: optimizer.Weights, gradients: Any
@@ -103,18 +102,24 @@ def next(
     optimizer.check_weights_state_match(
         weights, preconditioner, 'preconditioner'
     )
+    if tf.is_tensor(beta_1):
+      casted_step = tf.cast(step, beta_1.dtype)
+    else:
+      casted_step = step
     normalized_lr = (
         lr
-        * tf.math.sqrt((1 - tf.math.pow(beta_2, tf.cast(step, tf.float32))))
-        / (1 - tf.math.pow(beta_1, tf.cast(step, tf.float32)))
+        * tf.math.sqrt((1.0 - tf.math.pow(beta_2, casted_step)))
+        / (1.0 - tf.math.pow(beta_1, casted_step))
     )
 
     def _adam_update(w, a, p, g):
       if g is None:
         return w, a, p
-      a = a + (g - a) * (1 - beta_1)
-      p = p + (tf.math.square(g) - p) * (1 - beta_2)
-      w = w - normalized_lr * a / (tf.math.sqrt(p) + epsilon)
+      a = a + (g - a) * (1 - tf.cast(beta_1, a.dtype))
+      p = p + (tf.math.square(g) - p) * (1 - tf.cast(beta_2, p.dtype))
+      w = w - tf.cast(normalized_lr, a.dtype) * a / (
+          tf.math.sqrt(p) + tf.cast(epsilon, p.dtype)
+      )
       return w, a, p
 
     updated_weights, updated_accumulator, updated_preconditioner = (
@@ -142,11 +147,6 @@ def get_hparams(self, state: State) -> Hparams:
     return collections.OrderedDict([(k, state[k]) for k in _HPARAMS_KEYS])
 
   def set_hparams(self, state: State, hparams: Hparams) -> State:
-    # TODO: b/245962555 - Find an alternative to `update_struct` if it
-    # interferes with typing guarantees.
-    # We use `tff.structure.update_struct` (rather than something like
-    # `copy.deepcopy`) to ensure that this can be called within a
-    # `tff.Computation`.
     return structure.update_struct(state, **hparams)
 
 
diff --git a/tensorflow_federated/python/learning/optimizers/adam_test.py b/tensorflow_federated/python/learning/optimizers/adam_test.py
index 3fcac74d61..1d6b0cc358 100644
--- a/tensorflow_federated/python/learning/optimizers/adam_test.py
+++ b/tensorflow_federated/python/learning/optimizers/adam_test.py
@@ -55,9 +55,7 @@ def test_math(self):
     for _ in range(4):
       state, weights = optimizer.next(state, weights, gradients)
       history.append(weights)
-    self.assertAllClose(
-        [[1.0], [0.9000007], [0.8000017], [0.700002], [0.600003]], history
-    )
+    self.assertAllClose([[1.0], [0.9], [0.8], [0.7], [0.6]], history)
 
   @parameterized.named_parameters(
       ('scalar_spec', _SCALAR_SPEC),
@@ -142,8 +140,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: adam.build_adam(0.01, 0.9, 0.999)
     keras_optimizer_fn = lambda: tf.keras.optimizers.Adam(0.01, 0.9, 0.999)
@@ -225,6 +223,23 @@ def test_set_get_hparams_is_no_op(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    adam_optimizer = adam.build_adam(
+        learning_rate=tf.constant(0.1, dtype=tf.float32),
+        beta_1=tf.constant(0.1, dtype=tf.float32),
+        beta_2=tf.constant(0.1, dtype=tf.float32),
+        epsilon=tf.constant(0.1, dtype=tf.float64),
+    )
+    state = adam_optimizer.initialize(weights)
+    adam_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow_federated/python/learning/optimizers/sgdm.py b/tensorflow_federated/python/learning/optimizers/sgdm.py
index d4fcc9872b..78dfe27e27 100644
--- a/tensorflow_federated/python/learning/optimizers/sgdm.py
+++ b/tensorflow_federated/python/learning/optimizers/sgdm.py
@@ -26,7 +26,7 @@
 _ACCUMULATOR_KEY = 'accumulator'
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _SGD(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -38,14 +38,16 @@ def __init__(
       momentum: Optional[optimizer.Float] = None,
   ):
     """Initializes SGD optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'SGD `learning_rate` must be nonnegative, found {learning_rate}.'
       )
     if momentum:
       # We should only track momentum as a hparam in the case that it is both
       # specified and nonzero.
-      if momentum < 0.0 or momentum > 1.0:
+      if not tf.is_symbolic_tensor(momentum) and (
+          momentum < 0.0 or momentum > 1.0
+      ):
         raise ValueError(
             'SGD `momentum` must be `None` or in the range [0, 1], found '
             f'{momentum}.'
@@ -77,7 +79,7 @@ def next(
     def _sgd_update(w, g):
       if g is None:
         return w
-      return w - lr * g
+      return w - tf.cast(lr, dtype=g.dtype) * g
 
     updated_weights = nest_utils.map_at_leaves(
         _sgd_update, weights, gradients
@@ -111,11 +113,6 @@ def get_hparams(self, state: State) -> Hparams:
     return collections.OrderedDict([(k, state[k]) for k in self._hparams_keys])
 
   def set_hparams(self, state: State, hparams: Hparams) -> State:
-    # TODO: b/245962555 - Find an alternative to `update_struct` if it
-    # interferes with typing guarantees.
-    # We use `tff.structure.update_struct` (rather than something like
-    # `copy.deepcopy`) to ensure that this can be called within a
-    # `tff.Computation`.
     return structure.update_struct(state, **hparams)
 
 
diff --git a/tensorflow_federated/python/learning/optimizers/sgdm_test.py b/tensorflow_federated/python/learning/optimizers/sgdm_test.py
index 8b75ba654c..fd7613fb1b 100644
--- a/tensorflow_federated/python/learning/optimizers/sgdm_test.py
+++ b/tensorflow_federated/python/learning/optimizers/sgdm_test.py
@@ -43,7 +43,7 @@ def test_get_hparams_momentum(self, momentum_value):
     optimizer = sgdm.build_sgdm(0.01, momentum=momentum_value)
     state = optimizer.initialize(_SCALAR_SPEC)
     hparams = optimizer.get_hparams(state)
-    # Whether we specify None momentum or momentum 0.0, we shouldnt track the
+    # Whether we specify None momentum or momentum 0.0, we shouldn't track the
    # extra accumulator state. The implementation of next checks for the
     # presence or absence of momentum key--it should not be there in either
     # case.
@@ -177,8 +177,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: sgdm.build_sgdm(learning_rate, momentum)
 
@@ -306,6 +306,20 @@ def test_set_get_hparams_is_no_op_with_momentum(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    sgdm_optimizer = sgdm.build_sgdm(
+        learning_rate=tf.constant(0.1, dtype=tf.float32)
+    )
+    state = sgdm_optimizer.initialize(weights)
+    sgdm_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 if __name__ == '__main__':
   tf.test.main()
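
Usage sketch (not part of the applied diff; it assumes the internal TFF optimizer API exercised by the new tests, and the variable names are illustrative): after this change, the learning rate may be a `tf.Tensor` and the weights and gradients may mix float dtypes, because the hyperparameters are cast to each leaf's dtype inside the update.

    # Minimal sketch, assuming the tensorflow_federated optimizer API used in
    # the tests above (sgdm.build_sgdm, initialize, next).
    import tensorflow as tf

    from tensorflow_federated.python.learning.optimizers import sgdm

    # Weights whose leaves have different float dtypes.
    weights = (
        tf.constant([0.1], dtype=tf.float32),
        tf.constant(1.0, dtype=tf.float64),
        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
    )
    # Zero gradients with matching structure and dtypes, for illustration only.
    gradients = tf.nest.map_structure(tf.zeros_like, weights)

    # A tensor learning rate; the update casts it to each gradient's dtype.
    opt = sgdm.build_sgdm(learning_rate=tf.constant(0.1, dtype=tf.float32))
    state = opt.initialize(weights)
    state, updated_weights = opt.next(state, weights, gradients)

The same pattern applies to `adagrad.build_adagrad` and `adam.build_adam`, whose tensor hyperparameters are likewise cast, as covered by the `test_lr_with_different_weight_dtypes` tests added in this patch.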