AdaProp Optimizer in Optax/KitchenSink
PiperOrigin-RevId: 572989782
init2winit Team authored and copybara-github committed Oct 12, 2023
1 parent 9601a98 commit ee3bf7a
Showing 3 changed files with 164 additions and 8 deletions.
4 changes: 4 additions & 0 deletions init2winit/optimizer_lib/kitchen_sink/__init__.py
@@ -15,6 +15,7 @@

"""Kitchen Sink: decomposing optimizers in JAX."""

from init2winit.optimizer_lib.kitchen_sink._src.alias import adapropw
from init2winit.optimizer_lib.kitchen_sink._src.alias import nadamw
from init2winit.optimizer_lib.kitchen_sink._src.core import kitchen_sink
from init2winit.optimizer_lib.kitchen_sink._src.transform import add_decayed_weights
@@ -37,6 +38,7 @@
from init2winit.optimizer_lib.kitchen_sink._src.transform import PreconditionBySecondMomentCoordinateWiseState
from init2winit.optimizer_lib.kitchen_sink._src.transform import sanitize_values
from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adam
from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_adaprop
from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_amsgrad
from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_learning_rate
from init2winit.optimizer_lib.kitchen_sink._src.transform import scale_by_nadam
@@ -49,6 +51,7 @@

__all__ = (
'nadamw',
'adapropw',
'kitchen_sink',
'bias_correction',
'BiasCorrectionState',
Expand All @@ -71,6 +74,7 @@
'sanitize_values',
'scale_by_adam',
'scale_by_amsgrad',
'scale_by_adaprop',
'scale_by_learning_rate',
'scale_by_nadam',
'ScaleByAdamState',
54 changes: 54 additions & 0 deletions init2winit/optimizer_lib/kitchen_sink/_src/alias.py
@@ -16,6 +16,7 @@
"""Aliases for optimizers not found in optax."""
from typing import Any, Callable, Optional, Union
from init2winit.optimizer_lib.kitchen_sink._src import transform
import jax.numpy as jnp
import optax


@@ -67,3 +68,56 @@ def nadamw(
transform.scale_by_nadam(b1, b2, eps, eps_root, debias),
optax.add_decayed_weights(weight_decay, weight_decay_mask),
transform.scale_by_learning_rate(learning_rate))


def adapropw(
learning_rate: optax.ScalarOrSchedule,
alpha: float = 1.0,
b1: float = 0.9,
b3: float = 1.0,
b4: float = 0.999,
eps: float = 1e-8,
use_nesterov: bool = False,
quantized_dtype: str = 'float32',
weight_decay: float = 0.0,
weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the AdaProp algorithm.
Args:
learning_rate: this is a fixed global scaling factor.
alpha: upper bound on bet.
b1: decay rate for the exponentially weighted average of grads.
b3: decay rate for the exponentially weighted average of max grads.
b4: decay rate for the exponentially weighted average of reward.
eps: term added to the denominator to improve numerical stability.
use_nesterov: Whether to use Nesterov-style update.
quantized_dtype: type of the quantized input. Allowed options are
'bfloat16' and 'float32'. If floating-point type is specified,
accumulators are stored as such type, instead of quantized integers.
weight_decay: strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent with
other frameworks such as PyTorch, but different from (Loshchilov et al,
2019) where the weight decay is only multiplied with the "schedule
multiplier", but not the base learning rate.
weight_decay_mask: a tree with same structure as (or a prefix of) the params
PyTree, or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the weight decay to, and `False` for those you want to skip. Note
that the Nadam gradient transformations are applied to all parameters.
Returns:
An (init_fn, update_fn) tuple.
"""
if quantized_dtype == 'float32':
q_dtype = jnp.float32
else:
q_dtype = jnp.bfloat16
return optax.chain(
transform.scale_by_adaprop(alpha=alpha, b1=b1, b3=b3, b4=b4,
eps=eps, use_nesterov=use_nesterov,
quantized_dtype=q_dtype),
optax.add_decayed_weights(weight_decay, weight_decay_mask),
transform.scale_by_learning_rate(learning_rate, flip_sign=True),
)
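
For context, here is a minimal usage sketch that is not part of the commit: it assumes the adapropw export added to kitchen_sink/__init__.py above and the standard optax GradientTransformation API, and the toy model, loss, and batch are made up for illustration.

import jax
import jax.numpy as jnp
import optax
from init2winit.optimizer_lib.kitchen_sink import adapropw


def loss_fn(params, batch):
  # Toy least-squares loss on a linear model.
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)


params = {'w': jnp.zeros((3,))}
tx = adapropw(learning_rate=1e-3, weight_decay=1e-4)
opt_state = tx.init(params)

batch = {'x': jnp.ones((8, 3)), 'y': jnp.ones((8,))}
grads = jax.grad(loss_fn)(params, batch)
# add_decayed_weights needs the current params, so pass them to update().
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)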
114 changes: 106 additions & 8 deletions init2winit/optimizer_lib/kitchen_sink/_src/transform.py
@@ -62,18 +62,18 @@ def _update_moment(updates, moments, decay, order):


def _update_first_moment_variance_preserved(updates, moments, decay):
"""Applies variance preserved EMA.
Multiplies incoming gradient by sqrt{1-beta^2} as opposed to 1-beta.
"""Applies variance preserved EMA.
Multiplies incoming gradient by sqrt{1-beta^2} as opposed to 1-beta.
Introduces bias.
Args:
updates: updates.
moments: moments,
decay: the decay parameter.
Returns:
Variance Preserved EMA.
Variance Preserved EMA.
"""
return jax.tree_map(
lambda g, t: ((1 - decay**2) ** 0.5) * g + decay * t,
@@ -571,7 +571,7 @@ def scale_by_adaptive_gd(
Args:
init_r_squared: initial guess for r^2.
Returns:
An (init_fn, update_fn) tuple.
"""
@@ -634,7 +634,7 @@ def scale_by_layerwise_adaptive_gd(
init_r_squared: float = 1.0,
) -> optax.GradientTransformation:
"""Rescale updates according to LAYER-WISE Adaptive GD.
Args:
init_r_squared: initial guess for r^2.
@@ -891,7 +891,7 @@ def scale_by_coordinate_wise_adaptive_gd_simple(
eps: float = 1e-8,
) -> optax.GradientTransformation:
"""Rescale updates according to simpler COORDINATE-WISE Adaptive GD.
Args:
init_r_squared: Initial guess for r^2.
eps: Initial value for mu_sum.
@@ -1249,6 +1249,104 @@ def update_fn(updates, state, params=None):
return optax.GradientTransformation(init_fn, update_fn)


class ScaleByAdapropState(NamedTuple):
"""State for the AdaProp algorithm."""
  count: chex.Array  # shape=(), dtype=jnp.int32.
  pp: optax.Updates  # EMA of past parameter values.
  mu: optax.Updates  # EMA of gradients (first moment).
  nu: optax.Updates  # EMA of absolute gradients.
  gain: optax.Updates  # Per-coordinate accumulated gain.


def scale_by_adaprop(
alpha: float = 1.0,
b1: float = 0.9,
b3: float = 1.0,
b4: float = 0.9,
eps: float = 1e-8,
use_nesterov: bool = False,
quantized_dtype: jnp.dtype = jnp.float32,
) -> optax.GradientTransformation:
"""Rescale updates according to the AdaProp algorithm.
Args:
alpha: upper bound on bet.
b1: decay rate for the exponentially weighted average of grads.
# b2: decay rate for the exponentially weighted average of absolute grads
# is omitted because it is calculated from alpha and b1.
b3: decay rate for the exponentially weighted average of max grads.
b4: decay rate for the exponentially weighted average of reward.
eps: term added to the denominator to improve numerical stability.
use_nesterov: Whether to use Nesterov-style update.
quantized_dtype: type of the quantized input. Allowed options are
jnp.bfloat16 and jnp.float32. If floating-point type is specified,
accumulators are stored as such type, instead of quantized integers.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
prev_params = jax.tree_map(
lambda p: jnp.zeros_like(p, dtype=quantized_dtype), params
)
mu = jax.tree_map(
lambda p: jnp.zeros_like(p, dtype=quantized_dtype), params
)
nu = jax.tree_map(
lambda p: jnp.zeros_like(p, dtype=quantized_dtype), params
)
gain = jax.tree_map(
lambda p: jnp.ones_like(p, dtype=quantized_dtype), params
)

return ScaleByAdapropState(
count=jnp.zeros([], jnp.int32),
pp=prev_params,
mu=mu,
nu=nu,
gain=gain,
)

  def update_fn(updates, state, params):
    new_count = optax.safe_int32_increment(state.count)
    # b2, the decay for the EMA of absolute gradients, is derived from alpha
    # and b1 rather than passed in (see the docstring).
    b2 = 1.0 - (1.0 - b1)/alpha
    # First moment: EMA of gradients, optionally with a Nesterov-style
    # look-ahead, followed by bias correction.
    mu = jax.tree_map(lambda g, t: (1-b1)*g + b1*t, updates, state.mu)
    if use_nesterov:
      mu2 = jax.tree_map(lambda g, t: (1-b1)*g + b1*t, updates, mu)
      mu_hat = _bias_correction(mu2, b1, new_count)
    else:
      mu_hat = _bias_correction(mu, b1, new_count)
    # EMA of absolute gradients, used to normalize the bet.
    nu = jax.tree_map(lambda g, t: (1-b2)*jnp.abs(g) + b2*t, updates, state.nu)
    nu_hat = _bias_correction(nu, b2, new_count)
    # EMA of past parameters; param_change measures how far the current
    # parameters have drifted from that average.
    pp = jax.tree_map(lambda p, t: (1-b4)*p + b4*t, params, state.pp)
    pp_hat = _bias_correction(pp, b4, new_count)
    param_change = jax.tree_map(lambda p, i: p - i, params, pp_hat)
    g_max = jax.tree_map(lambda g, n: jnp.maximum(jnp.abs(g), n),
                         updates, nu_hat)
    # Per-coordinate gain, floored at zero; the wealth 1 + gain scales the bet.
    gain = jax.tree_map(
        lambda r, p, g, x: jnp.maximum(b3*r - p*g/(x + eps), 0.0),
        state.gain, param_change, updates, g_max)
    wealth = jax.tree_map(lambda g: 1.0 + g, gain)

    # Bet factor: first moment normalized by the absolute-gradient EMA.
    bet_factor = jax.tree_map(
        lambda m, n: m / (n + eps),
        mu_hat,
        nu_hat,
    )
    new_updates = jax.tree_map(lambda b, w: b * w,
                               bet_factor, wealth)
    return new_updates, ScaleByAdapropState(
        count=new_count,
        pp=pp,
        mu=mu,
        nu=nu,
        gain=gain,
    )

return optax.GradientTransformation(init_fn, update_fn)


class PreconditionByRssState(NamedTuple):
"""State holding the sum of gradient squares to date."""

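As a closing note, and not part of the commit itself: a small sketch of exercising transform.scale_by_adaprop directly to inspect the ScaleByAdapropState it carries. The toy params and grads are invented, and the import path matches the one used in alias.py above.

import jax.numpy as jnp
from init2winit.optimizer_lib.kitchen_sink._src import transform

params = {'w': jnp.array([1.0, -2.0, 3.0])}
grads = {'w': jnp.array([0.1, 0.0, -0.2])}

tx = transform.scale_by_adaprop(alpha=1.0, b1=0.9, b3=1.0, b4=0.9)
state = tx.init(params)  # count=0, pp=mu=nu=zeros, gain=ones.
updates, state = tx.update(grads, state, params)

# With alpha=1.0 the implied decay for the absolute-gradient EMA is
# b2 = 1 - (1 - b1) / alpha = 0.9. After one step, bias correction makes
# mu_hat == g and nu_hat == |g|, so each update is roughly sign(g) scaled by
# the per-coordinate wealth 1 + gain.
print(state.count, updates['w'])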
