Roll-forward with fixes: Fix interaction between scheduler.step() and gradient accumulation steps, refactor schedulers to use `LambdaLR`, and add cosine annealing LR scheduler as a decay method. (#3555)
justinxzhao authored Aug 29, 2023
1 parent f34c272 commit fed9b82
Showing 5 changed files with 241 additions and 51 deletions.
109 changes: 78 additions & 31 deletions ludwig/modules/lr_scheduler.py
@@ -1,16 +1,18 @@
import logging
import math
from typing import Any, Dict
from typing import Any, Callable, Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, SequentialLR

from ludwig.constants import MINIMIZE, TRAINING, VALIDATION
from ludwig.modules.metric_registry import get_metric_objective
from ludwig.schema.lr_scheduler import LRSchedulerConfig
from ludwig.utils.metric_utils import TrainerMetric
from ludwig.utils.trainer_utils import ProgressTracker

logger = logging.getLogger(__name__)


class ReduceLROnPLateauCappedDecreases(ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, mode: str, reduce_limit: int, factor: float, patience: int):
@@ -29,11 +31,12 @@ def step(self, metrics):
def num_reduce_lr(self) -> int:
return self._num_reduce_lr

def _reduce_lr(self, epoch):
def _reduce_lr(self, epoch=None):
"""Overrides the base ReduceLROnPlateau implementation."""
self._num_reduce_lr += 1
self.apply_lr(epoch)
self.apply_lr()

def apply_lr(self, epoch=None):
def apply_lr(self):
if self._num_reduce_lr == 0:
return

@@ -43,24 +46,23 @@ def apply_lr(self, epoch=None):
if old_lr - new_lr > self.eps:
param_group["lr"] = new_lr
if self.verbose:
epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
print("Epoch {}: reducing learning rate" " of group {} to {:.4e}.".format(epoch_str, i, new_lr))
logger.info(f"From ReduceLROnPLateauCappedDecreases, reducing learning rate to {new_lr}")


class LRScheduler:
def __init__(
self,
config: LRSchedulerConfig,
optimizer: Optimizer,
steps_per_checkpoint: int = 1000,
total_steps: int = 10000,
steps_per_checkpoint: int,
total_steps: int,
):
self.config = config
self.optimizer = optimizer

# Scheduler updated each training step
self.step_info = StepInfo(steps_per_checkpoint, total_steps, self.config)
self._train_scheduler = get_schedule_with_warmup(self.config, self.optimizer, self.step_info)
self._train_scheduler = get_schedule_with_warmup_and_decay(self.config, self.optimizer, self.step_info)

# Scheduler updated each eval step
self._eval_scheduler = None
@@ -74,13 +76,8 @@ def __init__(
patience=self.config.reduce_on_plateau_patience,
)

self.reset(steps_per_checkpoint, total_steps)

def reset(self, steps_per_checkpoint: int, total_steps: int):
# Retain state but update number of steps for training
self.step_info.reset(steps_per_checkpoint, total_steps)

def step(self):
"""Called every step of training."""
self._train_scheduler.step()

if self._eval_scheduler is not None:
@@ -90,6 +87,7 @@ def step(self):
self._eval_scheduler.apply_lr()

def eval_step(self, progress_tracker: ProgressTracker, validation_field: str):
"""Called every checkpoint evaluation step."""
if self._eval_scheduler is None:
# No reduce on plateau
return
@@ -140,14 +138,11 @@ class StepInfo:

def __init__(self, steps_per_checkpoint: int, total_steps: int, config: LRSchedulerConfig):
self.config = config
self.reset(steps_per_checkpoint, total_steps)

def reset(self, steps_per_checkpoint: int, total_steps: int):
self.steps_per_checkpoint = steps_per_checkpoint
self.num_training_steps = total_steps

if self.config.warmup_fraction > 0 and self.config.warmup_evaluations > 0:
logging.info(
logger.info(
"Both `learning_rate_scheduler.warmup_fraction` and `learning_rate_scheduler.warmup_evaluations` "
"provided. The larger of the two (as a function of the total training steps) will be used."
)
@@ -160,28 +155,46 @@ def reset(self, steps_per_checkpoint: int, total_steps: int):
self.num_warmup_steps = num_warmup_steps


def get_schedule_with_warmup(
def get_schedule_with_warmup_and_decay(
config: LRSchedulerConfig,
optimizer: Optimizer,
step_info: StepInfo,
) -> LambdaLR:
"""Creates a learning rate scheduler that updates each training step."""
decay_fn = decay_registry[config.decay]
schedulers = []

def lr_lambda(current_step: int):
if current_step < step_info.num_warmup_steps:
return float(current_step) / float(max(1, step_info.num_warmup_steps))
return decay_fn(current_step, step_info.num_training_steps, step_info.num_warmup_steps, config)
# Warmup scheduler.
if step_info.num_warmup_steps > 0:
warmup_scheduler = LambdaLR(
optimizer,
lambda current_step: float(current_step) / float(max(1, step_info.num_warmup_steps)),
)
schedulers.append(warmup_scheduler)

return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
# Decay scheduler.
decay = config.decay
decay_scheduler = decay_registry[decay](config, optimizer, step_info)
schedulers.append(decay_scheduler)

if len(schedulers) == 1:
# Only one scheduler, so no need to wrap in a SequentialLR.
return schedulers[0]

# Return a SequentialLR that applies the warmup and decay schedulers in order
# with the warmup scheduler only applied for the first num_warmup_steps steps.
return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps])
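
For illustration, here is a minimal standalone sketch of the same composition pattern in plain PyTorch (hypothetical step counts, not taken from this commit), showing how SequentialLR hands off from the warmup LambdaLR to the decay scheduler at the milestone:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR, SequentialLR

# Hypothetical values for illustration only.
num_warmup_steps, num_training_steps = 5, 20
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

warmup = LambdaLR(optimizer, lambda step: step / max(1, num_warmup_steps))
decay = LambdaLR(
    optimizer,
    lambda step: max(0.0, (num_training_steps - num_warmup_steps - step) / (num_training_steps - num_warmup_steps)),
)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[num_warmup_steps])

for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    # The LR ramps up toward 0.1 over the first 5 steps, then decays roughly linearly toward 0.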


def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
return 1.0


def linear_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return max(
0.0,
float(num_training_steps - num_warmup_steps - current_step)
/ float(max(1, num_training_steps - num_warmup_steps)),
)
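
As a sanity check on the corrected formula (made-up numbers): because SequentialLR effectively restarts the decay scheduler's step count at the warmup milestone, the multiplier should be 1.0 on the first decay step and reach 0.0 on the last training step, which subtracting num_warmup_steps from the numerator now guarantees.

# Hypothetical sanity check, not part of the diff.
def mult(step, num_training_steps=100, num_warmup_steps=10):
    return max(0.0, (num_training_steps - num_warmup_steps - step) / max(1, num_training_steps - num_warmup_steps))

assert mult(0) == 1.0   # first step of the decay phase, right after warmup
assert mult(90) == 0.0  # last step: 100 total steps minus 10 warmup steps = 90 decay steps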


def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
@@ -194,8 +207,42 @@ def exponential_decay(current_step: int, num_training_steps: int, num_warmup_ste
return math.pow(decay_rate, exponent)


def wrap_decay_fn(decay_fn: Callable) -> Callable:
def init_fn(config: LRSchedulerConfig, optimizer: Optimizer, step_info: StepInfo) -> LambdaLR:
return LambdaLR(
optimizer,
lambda current_step: decay_fn(
current_step, step_info.num_training_steps, step_info.num_warmup_steps, config
),
)

return init_fn


def init_cosine_decay(
config: LRSchedulerConfig,
optimizer: Optimizer,
step_info: StepInfo,
) -> CosineAnnealingWarmRestarts:
t_0 = config.t_0
if not t_0:
t_0 = step_info.steps_per_checkpoint
if not t_0:
# A scheduler may be initialized with dummy values, e.g., at the start of training.
# Ensure that t_0 != 0, since a value of 0 raises an error.
t_0 = 1

return CosineAnnealingWarmRestarts(
optimizer,
T_0=t_0,
T_mult=config.t_mult or 1,
eta_min=config.eta_min or 0,
)
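
For intuition, here is a small standalone sketch (hypothetical values) of how CosineAnnealingWarmRestarts behaves: within each cycle the LR follows a cosine curve from the base LR down to eta_min, then jumps back to the base LR, and with T_mult=2 each cycle lasts twice as long as the previous one.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Hypothetical values for illustration only.
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.001)

for step in range(70):
    optimizer.step()
    scheduler.step()
    # Cycles of 10, 20, and 40 steps: the LR anneals from 0.1 toward 0.001
    # within each cycle and resets to 0.1 at every restart.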


decay_registry = {
None: no_decay,
"linear": linear_decay,
"exponential": exponential_decay,
None: wrap_decay_fn(no_decay),
"linear": wrap_decay_fn(linear_decay),
"exponential": wrap_decay_fn(exponential_decay),
"cosine": init_cosine_decay,
}
28 changes: 27 additions & 1 deletion ludwig/schema/lr_scheduler.py
@@ -17,7 +17,7 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
"""Configuration for learning rate scheduler parameters."""

decay: str = schema_utils.StringOptions(
options=["linear", "exponential"],
options=["linear", "exponential", "cosine"],
default=None,
allow_none=True,
description="Turn on decay of the learning rate.",
@@ -99,6 +99,32 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["reduce_eval_split"],
)

# Parameters for CosineAnnealingWarmRestarts scheduler

t_0: int = schema_utils.PositiveInteger(
default=None,
allow_none=True,
description="Number of steps before the first restart for cosine annealing decay. If not specified, it"
" will be set to `steps_per_checkpoint`.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_0"],
)

t_mult: int = schema_utils.PositiveInteger(
default=1,
description="Period multiplier after each restart for cosine annealing decay. Defaults to 1, i.e.,"
" restart every `t_0` steps. If set to a larger value, the period between restarts increases by that"
" multiplier. For example, if t_mult is 2, then the periods would be: t_0, 2*t_0, 2^2*t_0, 2^3*t_0, etc.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_mult"],
)

eta_min: float = schema_utils.FloatRange(
default=0,
min=0,
max=1,
description="Minimum learning rate allowed for cosine annealing decay. Default: 0.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["eta_min"],
)
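
Taken together, these fields suggest a trainer configuration along the following lines. This is a hedged sketch: the placement under trainer.learning_rate_scheduler follows Ludwig's usual config layout, and the feature definitions and numeric values are illustrative assumptions, not content from this diff.

# Hypothetical Ludwig config (as a Python dict) enabling the new cosine decay.
config = {
    "input_features": [{"name": "text", "type": "text"}],        # assumed example features
    "output_features": [{"name": "label", "type": "category"}],
    "trainer": {
        "learning_rate": 0.001,
        "learning_rate_scheduler": {
            "decay": "cosine",
            "t_0": 200,       # steps before the first restart
            "t_mult": 2,      # each cycle twice as long as the previous one
            "eta_min": 1e-6,  # floor for the annealed learning rate
        },
    },
}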


# TODO(travis): too much boilerplate here, we should find a way to abstract all this and only require specifying the
# minimal amount needed for the new config object.
14 changes: 13 additions & 1 deletion ludwig/schema/metadata/configs/trainer.yaml
@@ -520,7 +520,10 @@ ecd:
suggested_values_reasoning:
Starting with exponential decay is a safe choice, as it is a "softer" decrease in the learning
rate over time compared with linear decay, which is steeper after the initial drop. Linear decay is
most useful when the risk of catastrophic forgetting is very high (e.g., for fine-tuning pretrained models).
most useful when the risk of catastrophic forgetting is very high (e.g., for fine-tuning pretrained
models). Cosine annealing is a learning rate schedule that starts with a large learning rate, decreases
it relatively quickly to a minimum value, and then increases it rapidly again. Resetting the learning
rate acts like a simulated restart of the learning process.
If you observe your loss curves shooting up (even on the training set) in later epochs, increasing the
decay rate may help mitigate this effect.
ui_display_name: Decay
@@ -600,6 +603,15 @@ ecd:
reduce_eval_split:
expected_impact: 1
ui_display_name: Reduce Eval Split
t_0:
expected_impact: 1
ui_display_name: T_0
t_mult:
expected_impact: 1
ui_display_name: T_mult
eta_min:
expected_impact: 1
ui_display_name: Eta Min
gbm:
learning_rate:
commonly_used: true
20 changes: 13 additions & 7 deletions ludwig/trainers/trainer.py
@@ -208,15 +208,17 @@ def prepare(self):
base_learning_rate = self.config.learning_rate
if self.distributed:
lr_scale_fn = learning_rate_scale_fns[self.config.learning_rate_scaling]
base_learning_rate *= lr_scale_fn(self.distributed.size() * self.gradient_accumulation_steps)
base_learning_rate *= lr_scale_fn(self.distributed.size())
self.base_learning_rate = base_learning_rate

self.dist_model, self.optimizer = self.distributed.prepare(
self.compiled_model,
self.config,
self.base_learning_rate,
)
self.scheduler = LRScheduler(self.config.learning_rate_scheduler, self.optimizer)

# NOTE: This is a partially configured LRScheduler. It will be updated in the first call to train_step.
self.scheduler = LRScheduler(self.config.learning_rate_scheduler, self.optimizer, 0, 0)

def train_step(
self, inputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], should_step: bool = True
@@ -762,8 +764,13 @@ def train(
final_steps_per_checkpoint = min(final_steps_per_checkpoint, self.total_steps)
early_stopping_steps = final_steps_per_checkpoint * self.early_stop

# Update learning rate scheduler which depends on number of steps
self.scheduler.reset(final_steps_per_checkpoint, self.total_steps)
# Initialize the learning rate scheduler.
self.scheduler = LRScheduler(
self.config.learning_rate_scheduler,
self.optimizer,
steps_per_checkpoint=final_steps_per_checkpoint,
total_steps=self.total_steps,
)

if self.is_coordinator():
logger.info(
@@ -944,9 +951,8 @@ def _train_loop(
loss, all_losses = self.train_step(inputs, targets, should_step=should_step)
logger.info(f"Train loss for step {progress_tracker.steps}: {loss:.3f}")

if should_step:
# Update LR scheduler here instead of in the train loop to avoid updating during batch size tuning, etc.
self.scheduler.step()
# Update LR scheduler here instead of in the train loop to avoid updating during batch size tuning, etc.
self.scheduler.step()

if self.is_coordinator() and not self.skip_save_log:
self.write_step_summary(
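
A minimal sketch of the interaction this commit adjusts, using a dummy model and toy data rather than Ludwig's trainer: the optimizer advances once per accumulation window, while the LR scheduler is now stepped on every batch, matching the unconditional scheduler.step() call above.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

# Hypothetical illustration only; none of these objects come from Ludwig.
model = torch.nn.Linear(4, 1)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lambda step: max(0.0, 1.0 - step / 16))  # toy linear decay
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(16)]

gradient_accumulation_steps = 4
for step, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / gradient_accumulation_steps).backward()
    should_step = (step + 1) % gradient_accumulation_steps == 0
    if should_step:
        optimizer.step()       # optimizer advances once per accumulation window
        optimizer.zero_grad()
    scheduler.step()           # LR scheduler advances every batch, as in this commit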