add cautious option to RAdamScheduleFree #54
base: main
Conversation
Thanks for this pull request. I will look into merging it in the New Year.
    # These operations update y in-place,
    # without computing x explicitly.
    torch._foreach_lerp_(y, z, weight=ckp1)
    torch._foreach_sub_(y, grad, alpha=adaptive_y_lr)
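For reference, the two _foreach_ calls above are elementwise equivalent to the following single-tensor operations (an illustrative paraphrase with dummy tensors, not code from the repository; the names y, z, grad, ckp1 and adaptive_y_lr follow the snippet):

    import torch

    # Dummy stand-ins for one parameter, its z iterate, and its gradient.
    y, z, grad = torch.randn(3), torch.randn(3), torch.randn(3)
    ckp1, adaptive_y_lr = 0.1, 1e-3

    y.lerp_(z, ckp1)                   # y <- y + ckp1 * (z - y)
    y.sub_(grad, alpha=adaptive_y_lr)  # y <- y - adaptive_y_lr * grad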
Hey @nhamanasu, I might be missing something, but is the subtraction correct here (it also appears in the non-foreach and closure versions)? I'm wondering if this is an error that might have been introduced unintentionally.
Thank you for the comment!
You're exactly right. In my test branch, I reversed the sign of adaptive_y_lr and used sub functions, but I somehow forgot to reflect those changes in this c-radam branch. This may have led to completely opposite results. Thank you for catching this critical issue!
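(For clarity, the two formulations being discussed are interchangeable in PyTorch; a trivial illustrative check, where a stands in for the adaptive_y_lr coefficient:)

    import torch

    y, grad, a = torch.zeros(3), torch.ones(3), 0.01

    # Subtracting with coefficient a is the same as adding with coefficient -a,
    # so flipping the sign of the coefficient and switching add/sub cancel out.
    assert torch.allclose(y.clone().sub_(grad, alpha=a),
                          y.clone().add_(grad, alpha=-a))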
...Sorry, after considering the combination with the cautious update, I'm not sure which sign is correct for this part. Let me rethink this block!
In the end, I concluded your concern was right. Thank you again for the valuable comments!
Based on @LoganBooker's comment, I've fixed the gradient update part. With the default
As the difference is negligible, I set the default
Grams: Gradient Descent with Adaptive Momentum Scaling
Additionally, if we define u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr) ...
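(One way to read that suggestion, for reference: with u defined this way, the fused lerp-then-subtract update above collapses to a single subtraction y <- y - u. A small illustrative check under that interpretation, not code from the branch:)

    import torch

    y, z, grad = torch.randn(4), torch.randn(4), torch.randn(4)
    ckp1, adaptive_y_lr = 0.1, 1e-3

    # Fused update as written in the diff: lerp toward z, then subtract grad.
    y_fused = y.clone()
    y_fused.lerp_(z, ckp1)
    y_fused.sub_(grad, alpha=adaptive_y_lr)

    # Same update expressed through u: y <- y - u.
    u = (y - z).mul_(ckp1).add_(grad, alpha=adaptive_y_lr)
    assert torch.allclose(y_fused, y - u)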
@gesen2egee
Anyway, I think we should split these into separate PRs if we really want to implement these ideas (e.g., for AdEMAMix, we already have a related issue: #46).
By the way, I apologize for raising this after opening the PR myself, but I've recently started to wonder whether adding new experimental features to this repository is the best approach. On the one hand, we could continue expanding the schedule-free library with experimental features (including this cautious option) and leave the choice of using them up to the users. On the other hand, we could keep this repository limited to well-established optimizers like the three we've already implemented: those with theoretical guarantees and proven practical utility.
What does this PR do?
I added a cautious option to RAdamScheduleFree.
The cautious optimizer is proposed in https://arxiv.org/abs/2411.16085 and implemented in https://github.com/kyleliang919/C-Optim .
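For context, the core idea of the cautious optimizer is a per-coordinate mask that zeroes the parts of an update whose sign disagrees with the current gradient. A minimal sketch of that masking step, paraphrasing the idea from the paper/repository above (the reference implementation additionally rescales the mask to preserve the overall update magnitude):

    import torch

    def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Zero out the coordinates of `update` whose sign disagrees with `grad`."""
        mask = (update * grad > 0).to(update.dtype)
        return update * mask

    # Only coordinates where update and grad share a sign survive.
    update = torch.tensor([0.5, -0.2, 0.1])
    grad = torch.tensor([1.0, 1.0, -1.0])
    print(cautious_mask(update, grad))  # tensor([0.5000, 0.0000, 0.0000])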
More details
As I wrote in the docstring, the combination of cautious and schedulefree is non-trivial.
In the cautious optimizer, aligning the momentum update with each gradient direction leads to faster convergence. But in schedulefree, the gradient update term in z doesn't contain momentum, which means the cautious mask would be meaningless there. So I chose to apply the cautious mask to the y update (after contracting x implicitly) instead, though I admit it's a bit tricky. But in some sense, I believe the training parameter y should be aligned in a cautious spirit.
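To make the idea concrete, here is a rough sketch of masking the y update in this spirit (an illustration of the idea only, with made-up tensors and names following the snippet discussed above; it is not the PR's final implementation, which may also rescale the mask):

    import torch

    y, z, grad = torch.randn(4), torch.randn(4), torch.randn(4)
    ckp1, adaptive_y_lr = 0.1, 1e-3

    # Effective step implied by the fused update: y <- y - delta.
    delta = ckp1 * (y - z) + adaptive_y_lr * grad

    # Cautious variant: apply only the components of delta whose sign
    # agrees with the current gradient.
    mask = (delta * grad > 0).to(grad.dtype)
    y.sub_(delta * mask)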
Experimental Results
The toy experiments show fast and promising convergence.
Below is the single-run result of example/mnist/main.py, which I think is superior to both the "Default PyTorch Implementation" and the AdamWScheduleFree results.