From 4c7faab66552af294eac7648b4b85ca288c75f38 Mon Sep 17 00:00:00 2001
From: nhamanasu
Date: Thu, 19 Dec 2024 22:25:17 +0900
Subject: [PATCH 1/6] add cautious option to RAdamScheduleFree

---
 .../mnist/__pycache__/main.cpython-312.pyc |  Bin 0 -> 8960 bytes
 schedulefree/radam_schedulefree.py         |   68 +++++++++++------
 schedulefree/radam_schedulefree_closure.py |   71 ++++++++++++------
 3 files changed, 90 insertions(+), 49 deletions(-)
 create mode 100644 examples/mnist/__pycache__/main.cpython-312.pyc

diff --git a/examples/mnist/__pycache__/main.cpython-312.pyc b/examples/mnist/__pycache__/main.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2a64ca7cebcd89e1860740b356e5b68fa718455
GIT binary patch
[8960-byte base85 payload omitted: accidentally committed compiled bytecode, deleted again in PATCH 2/6]
diff --git a/schedulefree/radam_schedulefree.py b/schedulefree/radam_schedulefree.py
index 396766d..fd74cb9 100644
--- a/schedulefree/radam_schedulefree.py
+++ b/schedulefree/radam_schedulefree.py
@@ -3,15 +3,10 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any
-from typing_extensions import TypeAlias
+from typing import Tuple, Union, Optional, Callable
 import torch
 import torch.optim
-try:
-    from torch.optim.optimizer import ParamsT
-except ImportError:
-    ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
-import math
+from torch.optim.optimizer import ParamsT
 
 class RAdamScheduleFree(torch.optim.Optimizer):
     r"""
@@ -49,6 +44,12 @@ class RAdamScheduleFree(torch.optim.Optimizer):
             This helps stabilize training by ensuring smoother warmup behavior and more reliable
             calculation of the moving average coefficient (`ckp1`). Recommended to set to True
             (default True).
+        cautious (bool, experimental): If True, use a cautious update strategy, which is proposed
+            in https://arxiv.org/abs/2411.16085 and https://github.com/kyleliang919/C-Optim. If we
+            directly apply cautious operation to z update, it's meaningless since z update doesn't
+            contain momentum elements, so we apply cautious operations to y update, which means the
+            combination of C-Optim and Schedule-Free is not obvious. In a few hands-on experiments,
+            we found that this option can lead to slightly faster convergence.
     """
 
     def __init__(self,
@@ -60,7 +61,8 @@ def __init__(self,
                  r: float = 0.0,
                  weight_lr_power: float = 2.0,
                  foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"),
-                 silent_sgd_phase: bool = True
+                 silent_sgd_phase: bool = True,
+                 cautious: bool = True,
                  ):
 
         defaults = dict(lr=lr,
@@ -75,7 +77,8 @@ def __init__(self,
                         weight_lr_power=weight_lr_power,
                         weight_decay=weight_decay,
                         foreach=foreach,
-                        silent_sgd_phase=silent_sgd_phase)
+                        silent_sgd_phase=silent_sgd_phase,
+                        cautious=cautious)
         super().__init__(params, defaults)
 
     @torch.no_grad()
@@ -129,6 +132,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             beta1, beta2 = group["betas"]
             decay = group["weight_decay"]
             silent_sgd_phase = group["silent_sgd_phase"]
+            cautious = group["cautious"]
             k = group["k"]  # current steps
             step = k + 1
             r = group['r']
@@ -189,11 +193,21 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 # Weight decay calculated at y
                 if decay != 0:
                     torch._foreach_add_(grad, y, alpha=decay)
-
-                # These operations update y in-place,
-                # without computing x explicitly.
- torch._foreach_lerp_(y, z, weight=ckp1) - torch._foreach_add_(y, grad, alpha=adaptive_y_lr) + + if cautious: + u = torch._foreach_sub(y, z) + torch._foreach_mul_(u, ckp1) + torch._foreach_add_(u, grad, alpha=adaptive_y_lr) + mask = torch._foreach_mul(u, grad) + mask = [(m > 0).to(g.dtype) for m, g in zip(mask, grad)] + torch._foreach_mul_(mask, [m.numel() / (m.sum() + 1) for m in mask]) + torch._foreach_mul_(u, mask) + torch._foreach_sub_(y, u) + else: + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_sub_(y, grad, alpha=adaptive_y_lr) # z step torch._foreach_sub_(z, grad, alpha=lr) @@ -214,22 +228,26 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) # Reuse grad buffer for memory efficiency - grad_normalized = grad.div_(denom) - else: - # Fall back to SGD (or nothing) - grad_normalized = grad + grad.div_(denom) # Weight decay calculated at y if decay != 0: - grad_normalized.add_(y, alpha=decay) - - # These operations update y in-place, - # without computing x explicitly. - y.lerp_(end=z, weight=ckp1) - y.add_(grad_normalized, alpha=adaptive_y_lr) + grad.add_(y, alpha=decay) + + if cautious: + u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr) + mask = (u * grad > 0).to(grad.dtype) + mask.mul_(mask.numel() / (mask.sum() + 1)) + u.mul_(mask) + y.sub_(u) + else: + # These operations update y in-place, + # without computing x explicitly. + y.lerp_(end=z, weight=ckp1) + y.sub_(grad, alpha=adaptive_y_lr) # z step - z.sub_(grad_normalized, alpha=lr) + z.sub_(grad, alpha=lr) group["k"] = k + 1 return loss diff --git a/schedulefree/radam_schedulefree_closure.py b/schedulefree/radam_schedulefree_closure.py index c0d6a99..2ea10e6 100644 --- a/schedulefree/radam_schedulefree_closure.py +++ b/schedulefree/radam_schedulefree_closure.py @@ -3,14 +3,10 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any -from typing_extensions import TypeAlias +from typing import Tuple, Union, Optional, Callable import torch.optim -try: - from torch.optim.optimizer import ParamsT -except ImportError: - ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] -import math + +from torch.optim.optimizer import ParamsT class RAdamScheduleFreeClosure(torch.optim.Optimizer): r""" @@ -47,6 +43,12 @@ class RAdamScheduleFreeClosure(torch.optim.Optimizer): This helps stabilize training by ensuring smoother warmup behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True (default True). + cautious (bool, experimental): If True, use a cautious update strategy, which is proposed + in https://arxiv.org/abs/2411.16085 and https://github.com/kyleliang919/C-Optim. If we + directly apply cautious operation to z update, it's meaningless since z update doesn't + contain momentum elements, so we apply cautious operations to y update, which means the + combination of C-Optim and Schedule-Free is not obvious. In a few hands-on experiments, + we found that this option can lead to slightly faster convergence. 
""" def __init__(self, params: ParamsT, @@ -57,7 +59,8 @@ def __init__(self, r: float = 0, weight_lr_power: float = 2.0, foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), - silent_sgd_phase: bool = True + silent_sgd_phase: bool = True, + cautious: bool = True, ): defaults = dict(lr=lr, betas=betas, @@ -70,7 +73,8 @@ def __init__(self, weight_lr_power=weight_lr_power, weight_decay=weight_decay, foreach=foreach, - silent_sgd_phase=silent_sgd_phase) + silent_sgd_phase=silent_sgd_phase, + cautious=cautious) super().__init__(params, defaults) @@ -117,6 +121,7 @@ def step(self, closure: Callable[[], float]) -> Optional[float]: step = k + 1 r = group['r'] silent_sgd_phase = group["silent_sgd_phase"] + cautious = group["cautious"] decay = group['weight_decay'] beta1, beta2 = group['betas'] weight_lr_power = group['weight_lr_power'] @@ -174,13 +179,25 @@ def step(self, closure: Callable[[], float]) -> Optional[float]: if decay != 0: torch._foreach_add_(grad, y, alpha=decay) - # Unextrapolate - torch._foreach_lerp_(y, z, weight=1-1/beta1) + if cautious: + u = torch._foreach_sub(y, z) + torch._foreach_mul_(u, ckp1) + torch._foreach_add_(u, grad, alpha=adaptive_y_lr) + mask = torch._foreach_mul(u, grad) + mask = [(m > 0).to(g.dtype) for m, g in zip(mask, grad)] + torch._foreach_mul_(mask, [m.numel() / (m.sum() + 1) for m in mask]) + torch._foreach_mul_(u, mask) + torch._foreach_sub_(y, u) + else: + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_sub_(y, grad, alpha=adaptive_y_lr) torch._foreach_sub_(z, grad, alpha=lr) - ### Take step - torch._foreach_lerp_(y, z, weight=ckp1) + # Unextrapolate to x + torch._foreach_lerp_(y, z, weight=1-1/beta1) else: for p in active_p: grad = p.grad @@ -198,22 +215,28 @@ def step(self, closure: Callable[[], float]) -> Optional[float]: # Adam step denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) - grad_normalized = grad.div_(denom) - else: - # Fall back to SGD (or nothing) - grad_normalized = grad + grad.div_(denom) # Weight decay calculated at y if decay != 0: - grad_normalized.add_(y, alpha=decay) - - # Unextrapolate - x = y.lerp_(end=z, weight=1-1/beta1) + grad.add_(y, alpha=decay) + + if cautious: + u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr) + mask = (u * grad > 0).to(grad.dtype) + mask.mul_(mask.numel() / (mask.sum() + 1)) + u.mul_(mask) + y.sub_(u) + else: + # These operations update y in-place, + # without computing x explicitly. 
+                        y.lerp_(end=z, weight=ckp1)
+                        y.sub_(grad, alpha=adaptive_y_lr)
 
-                z.sub_(grad_normalized, alpha=lr)
+                    z.sub_(grad, alpha=lr)
 
-                ### Take step
-                x.lerp_(end=z, weight=ckp1)
+                    # Unextrapolate to x
+                    y.lerp_(end=z, weight=1-1/beta1)
 
         group['k'] = k+1
 
         return loss

From b6a3dd0c7ff74c5089ff4899eb0c6c34753e3fab Mon Sep 17 00:00:00 2001
From: nhamanasu
Date: Thu, 19 Dec 2024 22:53:49 +0900
Subject: [PATCH 2/6] remove __pycache__ and add .gitignore

---
 .gitignore                                 |  156 ++++++++++++++++++
 .../mnist/__pycache__/main.cpython-312.pyc |  Bin 8960 -> 0 bytes
 2 files changed, 156 insertions(+)
 create mode 100644 .gitignore
 delete mode 100644 examples/mnist/__pycache__/main.cpython-312.pyc

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..08296fc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,156 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+# *.ipynb
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+.DS_Store

diff --git a/examples/mnist/__pycache__/main.cpython-312.pyc b/examples/mnist/__pycache__/main.cpython-312.pyc
deleted file mode 100644
index c2a64ca7cebcd89e1860740b356e5b68fa718455..0000000000000000000000000000000000000000
GIT binary patch
[8960-byte base85 payload omitted: forward and reverse patches deleting the compiled bytecode]
From 2fc7228588a7678cc1b50ca6dacb498012775d83 Mon Sep 17 00:00:00 2001
From: nhamanasu
Date: Sun, 22 Dec 2024 17:55:19 +0900
Subject: [PATCH 3/6] refine docstring

---
 schedulefree/radam_schedulefree.py         | 26 +++++++++++++---------
 schedulefree/radam_schedulefree_closure.py | 16 ++++++++-----
 2 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/schedulefree/radam_schedulefree.py b/schedulefree/radam_schedulefree.py
index fd74cb9..a39e80d 100644
--- a/schedulefree/radam_schedulefree.py
+++ b/schedulefree/radam_schedulefree.py
@@ -22,7 +22,7 @@ class RAdamScheduleFree(torch.optim.Optimizer):
             Iterable of parameters to optimize or dicts defining
             parameter groups.
         lr (float):
-            Learning rate parameter (default 0.0025)
+            Learning rate parameter (default: 0.0025)
         betas (Tuple[float, float], optional): coefficients used for computing
             running averages of gradient and its square (default: (0.9, 0.999)).
         eps (float):
@@ -31,25 +31,29 @@ class RAdamScheduleFree(torch.optim.Optimizer):
         weight_decay (float):
             Weight decay, i.e. a L2 penalty (default: 0).
         r (float): Use polynomial weighting in the average
-            with power r (default 0).
+            with power r (default: 0).
         weight_lr_power (float): During warmup, the weights in the average will
             be equal to lr raised to this power. Set to 0 for no weighting
-            (default 2.0).
+            (default: 2.0).
         foreach (bool): Use a foreach-backed implementation of the optimizer.
             Should be significantly faster, but will have higher peak memory
-            usage (default True if supported in your PyTorch version).
+            usage (default: True, if supported in your PyTorch version).
         silent_sgd_phase (bool): If True, the optimizer will not use the first SGD phase of RAdam.
             This means that the optimizer will not update model parameters during the early training
             steps (e.g., < 5 when β_2 = 0.999), but just update the momentum values of the optimizer.
             This helps stabilize training by ensuring smoother warmup behavior and more reliable
             calculation of the moving average coefficient (`ckp1`). Recommended to set to True
-            (default True).
-        cautious (bool, experimental): If True, use a cautious update strategy, which is proposed
-            in https://arxiv.org/abs/2411.16085 and https://github.com/kyleliang919/C-Optim. If we
-            directly apply cautious operation to z update, it's meaningless since z update doesn't
-            contain momentum elements, so we apply cautious operations to y update, which means the
-            combination of C-Optim and Schedule-Free is not obvious. In a few hands-on experiments,
-            we found that this option can lead to slightly faster convergence.
+            (default: True).
+ cautious (bool, experimental): If True, applies a cautious update strategy as proposed in + https://arxiv.org/abs/2411.16085 and implemented in https://github.com/kyleliang919/C-Optim. + While the original cautious optimizer aligns momentum updates with gradient directions + for faster convergence, our implementation differs in its combination with Schedule-Free + optimization. Since the z-update in Schedule-Free doesn't contain momentum terms, directly + applying cautious mask to z-update is meaningless. Instead, we apply the cautious operations + to the y-update (after implicit x contraction), as y represents the training parameters + where cautious update is more appropriate. Our preliminary experiments suggest this adaptation + can lead to slightly faster convergence, though the theoretical implications of combining + these approaches remain to be fully understood (default: True). """ def __init__(self, diff --git a/schedulefree/radam_schedulefree_closure.py b/schedulefree/radam_schedulefree_closure.py index 2ea10e6..5988493 100644 --- a/schedulefree/radam_schedulefree_closure.py +++ b/schedulefree/radam_schedulefree_closure.py @@ -43,12 +43,16 @@ class RAdamScheduleFreeClosure(torch.optim.Optimizer): This helps stabilize training by ensuring smoother warmup behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True (default True). - cautious (bool, experimental): If True, use a cautious update strategy, which is proposed - in https://arxiv.org/abs/2411.16085 and https://github.com/kyleliang919/C-Optim. If we - directly apply cautious operation to z update, it's meaningless since z update doesn't - contain momentum elements, so we apply cautious operations to y update, which means the - combination of C-Optim and Schedule-Free is not obvious. In a few hands-on experiments, - we found that this option can lead to slightly faster convergence. + cautious (bool, experimental): If True, applies a cautious update strategy as proposed in + https://arxiv.org/abs/2411.16085 and implemented in https://github.com/kyleliang919/C-Optim. + While the original cautious optimizer aligns momentum updates with gradient directions + for faster convergence, our implementation differs in its combination with Schedule-Free + optimization. Since the z-update in Schedule-Free doesn't contain momentum terms, directly + applying cautious mask to z-update is meaningless. Instead, we apply the cautious operations + to the y-update (after implicit x contraction), as y represents the training parameters + where cautious update is more appropriate. Our preliminary experiments suggest this adaptation + can lead to slightly faster convergence, though the theoretical implications of combining + these approaches remain to be fully understood (default: True). 
""" def __init__(self, params: ParamsT, From 51ee0b99e406efe09cb1d8f38ea4c53da896c37f Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Thu, 26 Dec 2024 20:12:36 +0900 Subject: [PATCH 4/6] set default cautious value of RAdamScheduleFree to False --- schedulefree/radam_schedulefree.py | 4 ++-- schedulefree/radam_schedulefree_closure.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/schedulefree/radam_schedulefree.py b/schedulefree/radam_schedulefree.py index a39e80d..0d4058a 100644 --- a/schedulefree/radam_schedulefree.py +++ b/schedulefree/radam_schedulefree.py @@ -53,7 +53,7 @@ class RAdamScheduleFree(torch.optim.Optimizer): to the y-update (after implicit x contraction), as y represents the training parameters where cautious update is more appropriate. Our preliminary experiments suggest this adaptation can lead to slightly faster convergence, though the theoretical implications of combining - these approaches remain to be fully understood (default: True). + these approaches remain to be fully understood (default: False). """ def __init__(self, @@ -66,7 +66,7 @@ def __init__(self, weight_lr_power: float = 2.0, foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), silent_sgd_phase: bool = True, - cautious: bool = True, + cautious: bool = False, ): defaults = dict(lr=lr, diff --git a/schedulefree/radam_schedulefree_closure.py b/schedulefree/radam_schedulefree_closure.py index 5988493..ccf8247 100644 --- a/schedulefree/radam_schedulefree_closure.py +++ b/schedulefree/radam_schedulefree_closure.py @@ -52,7 +52,7 @@ class RAdamScheduleFreeClosure(torch.optim.Optimizer): to the y-update (after implicit x contraction), as y represents the training parameters where cautious update is more appropriate. Our preliminary experiments suggest this adaptation can lead to slightly faster convergence, though the theoretical implications of combining - these approaches remain to be fully understood (default: True). + these approaches remain to be fully understood (default: False). 
""" def __init__(self, params: ParamsT, @@ -64,7 +64,7 @@ def __init__(self, weight_lr_power: float = 2.0, foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), silent_sgd_phase: bool = True, - cautious: bool = True, + cautious: bool = False, ): defaults = dict(lr=lr, betas=betas, From 7b060d9288f7df5d1aa69d9cc8d4e4957ec83871 Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Fri, 27 Dec 2024 10:37:36 +0900 Subject: [PATCH 5/6] update adaptive_y_lr --- schedulefree/radam_schedulefree.py | 2 +- schedulefree/radam_schedulefree_closure.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/schedulefree/radam_schedulefree.py b/schedulefree/radam_schedulefree.py index 0d4058a..0081328 100644 --- a/schedulefree/radam_schedulefree.py +++ b/schedulefree/radam_schedulefree.py @@ -168,7 +168,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] except ZeroDivisionError: ckp1 = 0 - adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1) + adaptive_y_lr = lr * (1 - beta1 * (1 - ckp1)) active_p = [p for p in group["params"] if p.grad is not None] for p in active_p: diff --git a/schedulefree/radam_schedulefree_closure.py b/schedulefree/radam_schedulefree_closure.py index ccf8247..b9dae5f 100644 --- a/schedulefree/radam_schedulefree_closure.py +++ b/schedulefree/radam_schedulefree_closure.py @@ -156,7 +156,7 @@ def step(self, closure: Callable[[], float]) -> Optional[float]: except ZeroDivisionError: ckp1 = 0 - adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1) + adaptive_y_lr = lr * (1 - beta1 * (1 - ckp1)) active_p = [p for p in group['params'] if p.grad is not None] if group['foreach'] and len(active_p) > 0: From c832e9e9e35ba440f634844c915154dc3e59f948 Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Fri, 27 Dec 2024 10:45:44 +0900 Subject: [PATCH 6/6] update ckp1 calculation w/o try-catch block --- schedulefree/radam_schedulefree.py | 5 +---- schedulefree/radam_schedulefree_closure.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/schedulefree/radam_schedulefree.py b/schedulefree/radam_schedulefree.py index 0081328..c102963 100644 --- a/schedulefree/radam_schedulefree.py +++ b/schedulefree/radam_schedulefree.py @@ -163,10 +163,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] weight = (step**r) * (lr_max**weight_lr_power) weight_sum = group["weight_sum"] = group["weight_sum"] + weight - try: - ckp1 = weight / weight_sum - except ZeroDivisionError: - ckp1 = 0 + ckp1 = weight / weight_sum if weight_sum > 0 else 0 adaptive_y_lr = lr * (1 - beta1 * (1 - ckp1)) active_p = [p for p in group["params"] if p.grad is not None] diff --git a/schedulefree/radam_schedulefree_closure.py b/schedulefree/radam_schedulefree_closure.py index b9dae5f..5362973 100644 --- a/schedulefree/radam_schedulefree_closure.py +++ b/schedulefree/radam_schedulefree_closure.py @@ -151,10 +151,7 @@ def step(self, closure: Callable[[], float]) -> Optional[float]: weight = (step**r) * (lr_max**weight_lr_power) weight_sum = group['weight_sum'] = group['weight_sum'] + weight - try: - ckp1 = weight/weight_sum - except ZeroDivisionError: - ckp1 = 0 + ckp1 = weight / weight_sum if weight_sum > 0 else 0 adaptive_y_lr = lr * (1 - beta1 * (1 - ckp1)) active_p = [p for p in group['params'] if p.grad is not None]