Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588089542
  • Loading branch information
sourabh2k15 authored and copybara-github committed Dec 5, 2023
1 parent 48201f6 commit 9b7bb9f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions init2winit/optimizer_lib/kitchen_sink/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ def scale_by_adaprop(
b3: float = 1.0,
b4: float = 0.9,
eps: float = 1e-8,
use_nesterov: bool = False,
use_nesterov: str = 'False',
quantized_dtype: jnp.dtype = jnp.float32,
) -> optax.GradientTransformation:
"""Rescale updates according to the AdaProp algorithm.
Expand Down Expand Up @@ -1307,7 +1307,7 @@ def update_fn(updates, state, params):
new_count = optax.safe_int32_increment(state.count)
b2 = 1.0 - (1.0 - b1)/alpha
mu = jax.tree_map(lambda g, t: (1-b1)*g + b1*t, updates, state.mu)
if use_nesterov:
if use_nesterov == 'True':
mu2 = jax.tree_map(lambda g, t: (1-b1)*g + b1*t, updates, mu)
mu_hat = _bias_correction(mu2, b1, new_count)
else:
Expand Down

0 comments on commit 9b7bb9f

Please sign in to comment.