add cautious option to RAdamScheduleFree #54

Open · wants to merge 6 commits into base: main

156 changes: 156 additions & 0 deletions .gitignore
@@ -0,0 +1,156 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints
# *.ipynb

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
89 changes: 54 additions & 35 deletions schedulefree/radam_schedulefree.py
@@ -3,15 +3,10 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any
from typing_extensions import TypeAlias
from typing import Tuple, Union, Optional, Callable
import torch
import torch.optim
try:
from torch.optim.optimizer import ParamsT
except ImportError:
ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
import math
from torch.optim.optimizer import ParamsT

class RAdamScheduleFree(torch.optim.Optimizer):
r"""
@@ -27,7 +22,7 @@ class RAdamScheduleFree(torch.optim.Optimizer):
Iterable of parameters to optimize or dicts defining
parameter groups.
lr (float):
Learning rate parameter (default 0.0025)
Learning rate parameter (default: 0.0025)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999)).
eps (float):
@@ -36,19 +31,29 @@
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
r (float): Use polynomial weighting in the average
with power r (default 0).
with power r (default: 0).
weight_lr_power (float): During warmup, the weights in the average will
be equal to lr raised to this power. Set to 0 for no weighting
(default 2.0).
(default: 2.0).
foreach (bool): Use a foreach-backed implementation of the optimizer.
Should be significantly faster, but will have higher peak memory
usage (default True if supported in your PyTorch version).
usage (default: True, if supported in your PyTorch version).
silent_sgd_phase (bool): If True, the optimizer will not use the first SGD phase of RAdam.
This means that the optimizer will not update model parameters during the early training
steps (e.g., < 5 when β_2 = 0.999), but just update the momentum values of the optimizer.
This helps stabilize training by ensuring smoother warmup behavior and more reliable
calculation of the moving average coefficient (`ckp1`). Recommended to set to True
(default True).
(default: True).
cautious (bool, experimental): If True, applies a cautious update strategy as proposed in
https://arxiv.org/abs/2411.16085 and implemented in https://github.com/kyleliang919/C-Optim.
While the original cautious optimizer aligns momentum updates with gradient directions
for faster convergence, our implementation differs in its combination with Schedule-Free
optimization. Since the z-update in Schedule-Free doesn't contain momentum terms, directly
applying the cautious mask to the z-update is meaningless. Instead, we apply the cautious
operations to the y-update (after the implicit x contraction), as y represents the training
parameters where the cautious update is more appropriate. Our preliminary experiments suggest
this adaptation can lead to slightly faster convergence, though the theoretical implications
of combining these approaches remain to be fully understood (default: False).
"""

def __init__(self,
@@ -60,7 +65,8 @@ def __init__(self,
r: float = 0.0,
weight_lr_power: float = 2.0,
foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"),
silent_sgd_phase: bool = True
silent_sgd_phase: bool = True,
cautious: bool = False,
):

defaults = dict(lr=lr,
@@ -75,7 +81,8 @@ def __init__(self,
weight_lr_power=weight_lr_power,
weight_decay=weight_decay,
foreach=foreach,
silent_sgd_phase=silent_sgd_phase)
silent_sgd_phase=silent_sgd_phase,
cautious=cautious)
super().__init__(params, defaults)

@torch.no_grad()
@@ -129,6 +136,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
beta1, beta2 = group["betas"]
decay = group["weight_decay"]
silent_sgd_phase = group["silent_sgd_phase"]
cautious = group["cautious"]
k = group["k"] # current steps
step = k + 1
r = group['r']
@@ -155,12 +163,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
weight = (step**r) * (lr_max**weight_lr_power)
weight_sum = group["weight_sum"] = group["weight_sum"] + weight

try:
ckp1 = weight / weight_sum
except ZeroDivisionError:
ckp1 = 0
ckp1 = weight / weight_sum if weight_sum > 0 else 0

adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1)
adaptive_y_lr = lr * (1 - beta1 * (1 - ckp1))
active_p = [p for p in group["params"] if p.grad is not None]

for p in active_p:
@@ -189,11 +194,21 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
# Weight decay calculated at y
if decay != 0:
torch._foreach_add_(grad, y, alpha=decay)

# These operations update y in-place,
# without computing x explicitly.
torch._foreach_lerp_(y, z, weight=ckp1)
torch._foreach_add_(y, grad, alpha=adaptive_y_lr)

if cautious:
u = torch._foreach_sub(y, z)
torch._foreach_mul_(u, ckp1)
torch._foreach_add_(u, grad, alpha=adaptive_y_lr)
mask = torch._foreach_mul(u, grad)
mask = [(m > 0).to(g.dtype) for m, g in zip(mask, grad)]
torch._foreach_mul_(mask, [m.numel() / (m.sum() + 1) for m in mask])
torch._foreach_mul_(u, mask)
torch._foreach_sub_(y, u)
else:
# These operations update y in-place,
# without computing x explicitly.
torch._foreach_lerp_(y, z, weight=ckp1)
torch._foreach_sub_(y, grad, alpha=adaptive_y_lr)

Review comment:

Hey @nhamanasu, I might be missing something, but is the subtraction correct here (it also appears in the non-foreach and closure versions)? I'm wondering if this is an error that might have been introduced unintentionally.

Contributor Author:

Thank you for the comment!

You're exactly right. In my test branch, I reversed the sign of adaptive_y_lr and used sub functions, but I somehow forgot to reflect these changes in this c-radam branch. This might have led to completely opposite results. Thank you for catching this critical issue!

Contributor Author:

...sorry, after considering the combination with the cautious update, I'm not sure which sign is correct for this part. Let me re-think this block!

Contributor Author:

In the end, I concluded your concern was right. Thank you again for the valuable comments!


# z step
torch._foreach_sub_(z, grad, alpha=lr)
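
For reference, a small standalone sketch (not part of the diff) of the single-tensor cautious y-step discussed in the thread above. With an all-ones mask it collapses to the plain Schedule-Free step y <- lerp(y, z, ckp1) - adaptive_y_lr * grad, which is why the two sign conventions (add_ with lr * (beta1 * (1 - ckp1) - 1) versus sub_ with lr * (1 - beta1 * (1 - ckp1))) describe the same update. The helper name cautious_y_update is illustrative only.

# Illustrative only -- not part of the PR. Single-tensor version of the
# cautious y-step, mirroring the non-foreach branch of the diff.
import torch

def cautious_y_update(y, z, grad, ckp1, adaptive_y_lr):
    # u is the full y-step: ckp1 * (y - z) + adaptive_y_lr * grad
    u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr)
    # Keep only coordinates where the step agrees in sign with the gradient,
    # then rescale so the mask's mean stays close to 1.
    mask = (u * grad > 0).to(grad.dtype)
    mask.mul_(mask.numel() / (mask.sum() + 1))
    return y.sub_(u.mul_(mask))

# With mask == 1 everywhere this reduces to
#   y <- (1 - ckp1) * y + ckp1 * z - adaptive_y_lr * grad,
# i.e. y.lerp_(z, ckp1); y.sub_(grad, alpha=adaptive_y_lr) -- the same update the
# pre-PR code wrote as y.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1)).
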
@@ -214,22 +229,26 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps)

# Reuse grad buffer for memory efficiency
grad_normalized = grad.div_(denom)
else:
# Fall back to SGD (or nothing)
grad_normalized = grad
grad.div_(denom)

# Weight decay calculated at y
if decay != 0:
grad_normalized.add_(y, alpha=decay)

# These operations update y in-place,
# without computing x explicitly.
y.lerp_(end=z, weight=ckp1)
y.add_(grad_normalized, alpha=adaptive_y_lr)
grad.add_(y, alpha=decay)

if cautious:
u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr)
mask = (u * grad > 0).to(grad.dtype)
mask.mul_(mask.numel() / (mask.sum() + 1))
u.mul_(mask)
y.sub_(u)
else:
# These operations update y in-place,
# without computing x explicitly.
y.lerp_(end=z, weight=ckp1)
y.sub_(grad, alpha=adaptive_y_lr)

# z step
z.sub_(grad_normalized, alpha=lr)
z.sub_(grad, alpha=lr)

group["k"] = k + 1
return loss
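
A minimal usage sketch, assuming this branch is installed in place of the released schedulefree package (the model, data, and hyperparameters below are placeholders). As with the other Schedule-Free optimizers, the optimizer has to be switched with optimizer.train() / optimizer.eval() around training and evaluation.

# Usage sketch -- assumes this PR's branch provides schedulefree.RAdamScheduleFree.
import torch
from schedulefree import RAdamScheduleFree

model = torch.nn.Linear(10, 2)
optimizer = RAdamScheduleFree(model.parameters(), lr=2.5e-3, cautious=True)

optimizer.train()  # use the training (y) parameters
for _ in range(3):
    x, target = torch.randn(4, 10), torch.randint(0, 2, (4,))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), target)
    loss.backward()
    optimizer.step()

optimizer.eval()  # switch to the averaged (x) parameters for evaluation / checkpointing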