Commit
Added LinearWarmupCosineAnnealingLR scheduler and linear warmup decay function
valhassan committed Jan 8, 2025
1 parent f8c2784 · commit cc64d40
Showing 1 changed file with 147 additions and 0 deletions.
@@ -0,0 +1,147 @@
import math
import warnings
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


# Adapted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/optimizers/lr_scheduler.py

class LinearWarmupCosineAnnealingLR(_LRScheduler):
"""Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and | ||
base_lr followed by a cosine annealing schedule between base_lr and eta_min. | ||
.. warning:: | ||
It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` | ||
after each iteration as calling it after each epoch will keep the starting lr at | ||
warmup_start_lr for the first epoch which is 0 in most cases. | ||
.. warning:: | ||
passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. | ||
It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of | ||
:func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing | ||
epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling | ||
train and validation methods. | ||
Example: | ||
>>> import torch.nn as nn | ||
>>> from torch.optim import Adam | ||
>>> # | ||
>>> layer = nn.Linear(10, 1) | ||
>>> optimizer = Adam(layer.parameters(), lr=0.02) | ||
>>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) | ||
>>> # the default case | ||
>>> for epoch in range(40): | ||
... # train(...) | ||
... # validate(...) | ||
... scheduler.step() | ||
>>> # passing epoch param case | ||
>>> for epoch in range(40): | ||
... scheduler.step(epoch) | ||
... # train(...) | ||
... # validate(...) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
optimizer: Optimizer, | ||
warmup_epochs: int, | ||
max_epochs: int, | ||
warmup_start_lr: float = 0.0, | ||
eta_min: float = 0.0, | ||
last_epoch: int = -1, | ||
) -> None: | ||
""" | ||
Args: | ||
optimizer (Optimizer): Wrapped optimizer. | ||
warmup_epochs (int): Maximum number of iterations for linear warmup | ||
max_epochs (int): Maximum number of iterations | ||
warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. | ||
eta_min (float): Minimum learning rate. Default: 0. | ||
last_epoch (int): The index of last epoch. Default: -1. | ||
""" | ||
self.warmup_epochs = warmup_epochs | ||
self.max_epochs = max_epochs | ||
self.warmup_start_lr = warmup_start_lr | ||
self.eta_min = eta_min | ||
|
||
super().__init__(optimizer, last_epoch) | ||

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch == 0:
            # Before the first step, every group starts at the warmup floor.
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup: add a constant increment each step until the lr reaches base_lr.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if self.last_epoch == self.warmup_epochs:
            # Warmup just finished: the lr is exactly base_lr.
            return self.base_lrs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            # Step just past max_epochs (and every full cosine period after it): update directly,
            # since the chained ratio below would divide by zero here.
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Chainable cosine annealing: rescale the previous lr's distance from eta_min by the
        # ratio of consecutive cosine factors.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler."""
        if self.last_epoch < self.warmup_epochs:
            # Closed-form linear warmup from warmup_start_lr up to base_lr.
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        # Closed-form cosine annealing from base_lr down to eta_min.
        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]
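
# Illustrative usage sketch (hypothetical names: optimizer, train_loader, num_epochs), stepping the
# scheduler once per iteration as the first warning in the class docstring recommends. When stepping
# per iteration, warmup_epochs and max_epochs are given in iterations rather than epochs:
#
#   scheduler = LinearWarmupCosineAnnealingLR(
#       optimizer,
#       warmup_epochs=10 * len(train_loader),        # warmup length in iterations
#       max_epochs=num_epochs * len(train_loader),   # schedule length in iterations
#   )
#   for epoch in range(num_epochs):
#       for batch in train_loader:
#           ...  # forward / backward / optimizer.step()
#           scheduler.step()  # advance the schedule once per iteration
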
# warmup + decay as a function
def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False):
    """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps."""
    assert not (linear and cosine)

    def fn(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        if not (cosine or linear):
            # no decay
            return 1.0

        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        if cosine:
            # cosine decay
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        # linear decay
        return 1.0 - progress

    return fn
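
For reference, a factory like this is typically paired with torch.optim.lr_scheduler.LambdaLR, which multiplies each parameter group's base lr by the returned factor at every scheduler step. A minimal sketch, assuming an illustrative model, optimizer, and step counts:

import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(10, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

warmup_steps, total_steps = 500, 10_000
scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_decay(warmup_steps, total_steps, cosine=True))

for step in range(total_steps):
    ...  # forward / backward / optimizer.step()
    scheduler.step()  # one LambdaLR step per optimizer step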