diff --git a/lr_scheduler.py b/lr_scheduler.py
index a2122e5d..80ac6ad0 100644
--- a/lr_scheduler.py
+++ b/lr_scheduler.py
@@ -24,7 +24,6 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,
-            t_mul=1.,
             lr_min=config.TRAIN.MIN_LR,
             warmup_lr_init=config.TRAIN.WARMUP_LR,
             warmup_t=warmup_steps,
@@ -118,7 +117,7 @@ def get_update_values(self, num_updates: int):
 class MultiStepLRScheduler(Scheduler):
     def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
         super().__init__(optimizer, param_group_field="lr")
-        
+
         self.milestones = milestones
         self.gamma = gamma
         self.warmup_t = warmup_t
@@ -129,9 +128,9 @@ def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warm
             super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
-        
+
         assert self.warmup_t <= min(self.milestones)
-        
+
     def _get_lr(self, t):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
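Note on the first hunk: newer timm releases renamed CosineLRScheduler's t_mul/decay_rate arguments to cycle_mul/cycle_decay, so passing t_mul fails with a TypeError there; since 1.0 is the default cycle multiplier anyway, the argument can simply be dropped. Below is a minimal sketch of the equivalent construction, assuming a timm version with the renamed arguments; all concrete values are placeholders, not taken from the diff:

    import torch
    from timm.scheduler.cosine_lr import CosineLRScheduler

    model = torch.nn.Linear(8, 2)                       # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=1000,         # total (or post-warmup) step count
        lr_min=5e-6,
        warmup_lr_init=5e-7,
        warmup_t=100,
        cycle_limit=1,          # one cosine cycle; cycle_mul defaults to 1.0
        t_in_epochs=False,      # interpret t as optimizer steps, not epochs
    )

    # step-based timm schedules are advanced with step_update(num_updates)
    for step in range(1000):
        scheduler.step_update(step)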
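Note on the last hunk: during warmup, _get_lr interpolates linearly from warmup_lr_init toward each param group's base lr using one precomputed slope per group, and the assert guarantees warmup finishes before the first milestone decay applies. A standalone sketch of that arithmetic follows; the slope formula and the numbers are illustrative assumptions (only the else-branch fallback appears in the diff):

    warmup_lr_init = 1e-6
    warmup_t = 5
    base_values = [1e-3, 5e-4]       # per-param-group base lrs

    # slope per group, chosen so the lr reaches the base value at t == warmup_t
    warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

    for t in range(warmup_t + 1):
        lrs = [warmup_lr_init + t * s for s in warmup_steps]
        print(t, lrs)                # t=0 -> warmup_lr_init, t=warmup_t -> base lr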