Add PAL optimizer. #626

Open · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions init2winit/optimizer_lib/optimizers.py
@@ -21,6 +21,7 @@
from init2winit.optimizer_lib import gradient_accumulator
from init2winit.optimizer_lib import kitchen_sink
from init2winit.optimizer_lib import online_newton_step
from init2winit.optimizer_lib import parabolic_approximation_line_search
from init2winit.optimizer_lib import pax_adafactor
from init2winit.optimizer_lib import samuel
from init2winit.optimizer_lib import sharpness_aware_minimization
@@ -36,6 +37,7 @@




def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
r"""A customizable gradient descent optimizer.

161 changes: 161 additions & 0 deletions init2winit/optimizer_lib/parabolic_approximation_line_search.py
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2024 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of Parabolic Approximation Line Search (PAL).

Paper: https://arxiv.org/abs/1903.11991
Code: https://github.com/cogsys-tuebingen/PAL
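
At each step, the base optimizer's update is normalized to a unit search
direction, the loss is measured at the current parameters and at a point a
distance mu along that direction, and a one-dimensional parabola is fit to
the two measured losses and the directional derivative. The (scaled and
clipped) minimum of that parabola is then used as the step size along the
search direction.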
"""

from typing import NamedTuple

from init2winit.model_lib import model_utils

import jax
from jax import lax
import jax.numpy as jnp
import optax


class ParabolicApproximationLineSearchState(NamedTuple):
step: jnp.ndarray # shape=(), dtype=jnp.int32.
base_state: NamedTuple # The state of the base optimizer.
hyperparams: dict[str, jnp.ndarray] # The base optimizer's hyperparams.


def parabolic_approximation_line_search(
mu: float,
alpha: float,
s_max: float,
start_step: int,
stop_step: int,
batch_axis_name: str,
base_opt_init_fn,
base_opt_update_fn,
) -> optax.GradientTransformation:
"""Implementation of Parabolic Approximation Line Search (PAL).

Paper: https://arxiv.org/abs/1903.11991
Code: https://github.com/cogsys-tuebingen/PAL

References:
    Mutschler and Zell, 2021: https://arxiv.org/abs/1903.11991

  Args:
    mu: The measuring step size to use when computing the loss at the projected
      point.
alpha: The update step adaptation used when computing the update.
    s_max: The upper bound on the step size; larger estimated steps are clipped
      to this value.
start_step: The step to start using PAL at.
stop_step: The step to stop using PAL at.
    batch_axis_name: The name of the axis to pmap over. Used to run a pmean
      over the measured losses before fitting the parabola.
    base_opt_init_fn: The initialization function for the base optimizer whose
      updates provide the search direction.
    base_opt_update_fn: The update function for the base optimizer used to
      generate updates given the total gradient.

Returns:
The corresponding `GradientTransformation`.
"""

def init_fn(params):
base_state = base_opt_init_fn(params)
return ParabolicApproximationLineSearchState(
step=jnp.zeros([], dtype=jnp.int32),
base_state=base_state,
hyperparams=base_state.hyperparams)

def update_fn(updates, state, cost_fn_params_tuple):
(cost_fn, params) = cost_fn_params_tuple

def pal_update(updates, state, params):

def loss_fn(params):
loss, _ = cost_fn(params)
return lax.pmean(loss, axis_name=batch_axis_name)

loss = loss_fn(params)

grad = updates
updates, state = base_opt_update_fn(updates, state, params)

      # Normalize the base optimizer's update to a unit search direction.
      updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0))
      updates = jax.tree_util.tree_map(lambda u: u / updates_norm, updates)
      # Measure the loss at a second point, a distance mu along the direction.
      new_params = optax.apply_updates(
          params, jax.tree_util.tree_map(lambda u: mu * u, updates))
      new_loss = loss_fn(new_params)

      # Fit a parabola l(s) ~= a * s**2 + b * s + l(0) along the direction:
      # b is the directional derivative of the loss at the current point and
      # a is recovered from the loss measured at s = mu.
      b = jax.tree_util.tree_reduce(
          lambda a, b: a + b,
          jax.tree_util.tree_map(lambda g, u: jnp.sum(g * u), grad, updates))
      a = (new_loss - loss - b * mu) / (mu**2)

      # Positive curvature and negative slope: step to the minimum of the
      # parabola, scaled by the update step adaptation alpha.
      def line_search_update(mu, alpha, a, b):
        del mu
        return (-1.0 * alpha * b) / (2.0 * a)

      # Non-positive curvature but negative slope: fall back to the measuring
      # step size mu.
      def mu_update(mu, alpha, a, b):
        del alpha, a, b
        return mu

      # Otherwise, take no step.
      def noop_update(mu, alpha, a, b):
        del mu, alpha, a, b
        return 0.0

      s_upd_1 = lax.cond(
          jnp.logical_and(jnp.greater(a, 0), jnp.less(b, 0)),
          line_search_update, noop_update, mu, alpha, a, b)

      s_upd_2 = lax.cond(
          jnp.logical_and(jnp.less_equal(a, 0), jnp.less(b, 0)), mu_update,
          noop_update, mu, alpha, a, b)

      # The two cases are mutually exclusive, so at most one branch is
      # non-zero. Clip the resulting step size to s_max.
      s_upd = jnp.maximum(s_upd_1, s_upd_2)
      s_upd = lax.cond(jnp.greater(s_upd, s_max), lambda: s_max, lambda: s_upd)
      # Record the estimated step size as the base optimizer's learning rate.
      state.hyperparams['learning_rate'] = s_upd

      def scale_update(updates, lr):
        return jax.tree_util.tree_map(lambda u: u * lr, updates)

      def scale_by_zeros_update(updates, lr):
        del lr
        return jax.tree_util.tree_map(jnp.zeros_like, updates)

      # Scale the unit search direction by the estimated step size, or emit a
      # zero update if no step should be taken.
      updates = lax.cond(
          jnp.greater(s_upd, 0.0), scale_update, scale_by_zeros_update, updates,
          s_upd)

return updates, state

def base_optimizer_update(updates, state, params):
return base_opt_update_fn(updates, state, params)

    # Apply PAL only between start_step and stop_step; otherwise fall through
    # to the plain base optimizer update.
    updates, base_state = lax.cond(
        jnp.logical_and(
            jnp.greater_equal(state.step, start_step),
            jnp.less_equal(state.step, stop_step)),
        pal_update,
        base_optimizer_update,
        updates,
        state.base_state,
        params)

step = state.step + jnp.ones([], dtype=jnp.int32)
state = ParabolicApproximationLineSearchState(
step=step, base_state=base_state, hyperparams=base_state.hyperparams)

return updates, state

return optax.GradientTransformation(init_fn, update_fn)
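
For context, a minimal sketch of how this transformation might be wired up with a base optax optimizer. It assumes the base optimizer is wrapped in optax.inject_hyperparams so its state exposes the hyperparams dict (with a 'learning_rate' entry) that PAL overwrites; the optimizer choice, hyperparameter values, and the cost_fn plumbing are illustrative assumptions, not part of this PR.

# Illustrative sketch only: optimizer choice, hyperparameter values, and the
# cost_fn plumbing are assumptions, not part of this PR.
import optax

from init2winit.optimizer_lib import parabolic_approximation_line_search as pal

# Wrap the base optimizer with inject_hyperparams so its state carries a
# 'hyperparams' dict with a 'learning_rate' entry for PAL to overwrite.
base = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1, momentum=0.9)

tx = pal.parabolic_approximation_line_search(
    mu=0.1,        # measuring step size
    alpha=1.0,     # update step adaptation
    s_max=1.0,     # maximum step size
    start_step=0,
    stop_step=10_000,
    batch_axis_name='batch',
    base_opt_init_fn=base.init,
    base_opt_update_fn=base.update)

# Inside a training step that is pmapped with axis_name='batch', where
# cost_fn(params) returns (loss, aux):
#   grads = jax.grad(lambda p: cost_fn(p)[0])(params)
#   updates, opt_state = tx.update(grads, opt_state, (cost_fn, params))
#   params = optax.apply_updates(params, updates)

Note that the third argument to tx.update is the (cost_fn, params) tuple this PR's update_fn expects, rather than the plain params pytree of a standard optax transformation.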