Commit

Better support for optimizers that require access to the gradient before aggregation.

PiperOrigin-RevId: 575870781
init2winit Team authored and copybara-github committed Oct 23, 2023
1 parent b7c5fc2 commit 0fe51a8
Showing 4 changed files with 70 additions and 9 deletions.
2 changes: 2 additions & 0 deletions init2winit/optimizer_lib/gradient_accumulator.py
@@ -44,6 +44,7 @@
"""
from typing import NamedTuple, Optional

from init2winit.optimizer_lib import utils as optimizer_utils
import jax
import jax.numpy as jnp
import optax
@@ -110,6 +111,7 @@ def init_fn(params):
num_per_step_batches=jnp.zeros([], jnp.int32),
accumulations=jax.tree_map(jnp.zeros_like, params))

@optimizer_utils.no_gradient_aggregation
def update_fn(updates, state, params=None):
zeros_params = jax.tree_map(jnp.zeros_like, state.accumulations)

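With the decorator applied above, the trainer hands this update_fn the raw per-device gradient, and the accumulator synchronizes across devices only when it actually applies an update (see the trainer.py change below). Deferring the sync is safe because the cross-device mean commutes with the per-device accumulation, so aggregating once per accumulation window gives the same result as aggregating every micro-batch, at a fraction of the communication cost. A standalone toy check of that identity, not part of the commit (the shape [devices, micro_batches, params] is illustrative):

import jax.numpy as jnp

# 2 devices, 3 micro-batches, 4 parameters.
micro_grads = jnp.arange(24.0).reshape(2, 3, 4)
# Sync on every micro-batch: mean over devices first, then accumulate.
sync_every_step = micro_grads.mean(axis=0).sum(axis=0)
# Sync once per window: accumulate locally, then mean over devices.
sync_once_at_end = micro_grads.sum(axis=1).mean(axis=0)
assert jnp.allclose(sync_every_step, sync_once_at_end)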
27 changes: 22 additions & 5 deletions init2winit/optimizer_lib/test_utils.py
@@ -19,7 +19,7 @@
from absl.testing import absltest
import chex
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib.utils import extract_field
from init2winit.optimizer_lib import utils
import jax.numpy as jnp
from ml_collections.config_dict import ConfigDict

@@ -46,13 +46,30 @@ def test_adam(self):
del update_fn
optimizer_state = init_fn({'foo': jnp.ones(10)})
# Test that we can extract 'count'.
chex.assert_type(extract_field(optimizer_state, 'count'), int)
chex.assert_type(utils.extract_field(optimizer_state, 'count'), int)
# Test that we can extract 'nu'.
chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
chex.assert_shape(utils.extract_field(optimizer_state, 'nu')['foo'], (10,))
# Test that we can extract 'mu'.
chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
chex.assert_shape(utils.extract_field(optimizer_state, 'mu')['foo'], (10,))
# Test that attempting to extract a nonexistent field "abc" returns None.
chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
chex.assert_equal(utils.extract_field(optimizer_state, 'abc'), None)


class AggregationDecoratorTest(chex.TestCase):
"""Test the requires_gradient_aggregation decorator."""

def test_no_aggregation(self):
@utils.no_gradient_aggregation
def dummy_update_fn(updates, state, params):
del updates, state, params

self.assertFalse(utils.requires_gradient_aggregation(dummy_update_fn))

def test_with_aggregation(self):
def dummy_update_fn(updates, state, params):
del updates, state, params

self.assertTrue(utils.requires_gradient_aggregation(dummy_update_fn))


if __name__ == '__main__':
41 changes: 41 additions & 0 deletions init2winit/optimizer_lib/utils.py
@@ -103,3 +103,44 @@ def extract_field(state, field_name):

# If we didn't find anything, return None.
return None


def no_gradient_aggregation(
update_fn: optax.TransformUpdateFn,
) -> optax.TransformUpdateFn:
"""Decorator for signaling that the optimizer requires access to the gradient before aggregation.
Standard Optax optimizers assume that the `update` argument passed to the
`update_fn` function has already been aggregated over the batch dimension,
however, in some cases the full jacobian is useful (e.g., in order to
calculate batch statistics). This decorator is used to signal the trainer not
to perform any gradient aggregation before calling the optimizer's update
function. It will then be the responsibility of the optimizer to aggregate the
updates before returning it.
Args:
update_fn: An optax update transformation function.
Returns:
The same callable, with an additional attribute that signals that no
aggregation should be performed.
"""
setattr(update_fn, 'init2winit_requires_gradient_aggregation', False)
return update_fn


def requires_gradient_aggregation(
update_fn: optax.TransformUpdateFn,
) -> bool:
"""Returns whether the given update_fn requires the gradient to be aggregated.
See no_gradient_aggregation() above for additional details.
Args:
update_fn: An Optax update transformation function.
Returns:
True if the gradient should be aggregated before invoking the update
function, False otherwise.
"""
return getattr(update_fn, 'init2winit_requires_gradient_aggregation', True)
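
The two helpers form a small opt-in protocol: the decorator clears a flag on the update function, and requires_gradient_aggregation reads it back, defaulting to True so that ordinary Optax transforms keep the current trainer-side behavior. A minimal usage sketch (my_update_fn is hypothetical; only the two utilities above come from this file):

import optax
from init2winit.optimizer_lib import utils as optimizer_utils

@optimizer_utils.no_gradient_aggregation
def my_update_fn(updates, state, params=None):
  # Hypothetical optimizer that expects per-device (unaggregated) gradients
  # and would aggregate them itself before returning the update.
  del params
  return updates, state

assert not optimizer_utils.requires_gradient_aggregation(my_update_fn)
# Undecorated update functions default to trainer-side aggregation.
assert optimizer_utils.requires_gradient_aggregation(optax.sgd(0.1).update)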
9 changes: 5 additions & 4 deletions init2winit/trainer_lib/trainer.py
@@ -20,7 +20,7 @@
from absl import logging
from init2winit import utils
from init2winit.model_lib import model_utils
from init2winit.optimizer_lib import gradient_accumulator
from init2winit.optimizer_lib import utils as optimizer_utils
from init2winit.trainer_lib import base_trainer
from init2winit.trainer_lib import trainer_utils
import jax
@@ -104,10 +104,11 @@ def opt_cost(params):
# If we are accumulating gradients, we handle gradient synchronization inside
# the optimizer so that we can only sync when actually updating the model.
if axis_name is not None:
if isinstance(optimizer_state, gradient_accumulator.GradientAccumulatorState): # pylint:disable=g-line-too-long
cost_value = lax.pmean(cost_value, axis_name=axis_name)
if optimizer_utils.requires_gradient_aggregation(optimizer_update_fn):
cost_value, grad = lax.pmean((cost_value, grad), axis_name='batch')
else:
cost_value, grad = lax.pmean((cost_value, grad), axis_name=axis_name)
# Skip gradient aggregation; only aggregate the cost value.
cost_value = lax.pmean(cost_value, axis_name='batch')

grad_norm = jnp.sqrt(model_utils.l2_regularization(grad, 0))
# TODO(znado): move to inside optax gradient clipping.
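The branch above changes who performs the pmean, not whether it happens: either the trainer aggregates the gradient before calling the optimizer, or a decorated update function does it internally. A self-contained sketch of the protocol under pmap (self_aggregating_update and apply_gradients are illustrative stand-ins, not init2winit APIs):

import jax
import jax.numpy as jnp
from init2winit.optimizer_lib import utils as optimizer_utils

@optimizer_utils.no_gradient_aggregation
def self_aggregating_update(updates, state, params=None):
  del params
  # Receives the raw per-device gradient and aggregates it itself.
  return jax.lax.pmean(updates, axis_name='batch'), state

def apply_gradients(update_fn, grad):
  # Mirrors the trainer.py branch above.
  if optimizer_utils.requires_gradient_aggregation(update_fn):
    grad = jax.lax.pmean(grad, axis_name='batch')
  updates, _ = update_fn(grad, None)
  return updates

n = jax.local_device_count()
per_device_grads = jnp.arange(float(n * 4)).reshape(n, 4)
# Every device ends up with the same aggregated update either way.
synced = jax.pmap(lambda g: apply_gradients(self_aggregating_update, g),
                  axis_name='batch')(per_device_grads)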
