diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b590ea432658..5bed858e163f 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -225,7 +225,7 @@ def __init__(self,
         self.num_experts = []
         self.gate_modules = []
         self.moe_layers = []
-        self._step_applied = False
+        self._step_applied = True
         self._global_grad_norm = None
         self.use_ds_comm = False  # False --> Use torch.dist, True --> Use ds.comm backend.
 
@@ -2108,6 +2108,15 @@ def clip_fp32_gradients(self):
         clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
 
     def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
+        if self.lr_scheduler is not None and self._step_applied:
+            try:
+                self.lr_scheduler.step(**(lr_kwargs or {}))
+            except TypeError:
+                # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
+                # We don't currently have a way to specify lr_kwargs from
+                # pipe_engine.train_batch()
+                self.lr_scheduler.step(self.train_batch_size())
+
         if self.gradient_clipping() > 0.0:
             if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
                 self.clip_fp32_gradients()
@@ -2157,14 +2166,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
             self.skipped_steps += 1
         else:
             self.compression_scheduler.step()
-            if self.lr_scheduler is not None:
-                try:
-                    self.lr_scheduler.step(**(lr_kwargs or {}))
-                except TypeError:
-                    # XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
-                    # We don't currently have a way to specify lr_kwargs from
-                    # pipe_engine.train_batch()
-                    self.lr_scheduler.step(self.train_batch_size())
 
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
@@ -2191,8 +2192,6 @@ def step(self, lr_kwargs=None):
 
         report_progress = False
 
-        self._step_applied = False  # assume False, will flip to True
-
         # Update the model when we reach gradient accumulation boundaries
         if self.is_gradient_accumulation_boundary():
            self.gas_boundary_ctr += 1
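
The sketch below is a minimal, self-contained illustration (not DeepSpeed code) of the ordering the diff introduces: the LR scheduler is advanced at the start of the next `_take_model_step()`, and only when the previous optimizer step was actually applied rather than skipped due to gradient overflow. `ToyScheduler`, `ToyEngine`, and `overflow_at` are illustrative stand-ins and do not exist in the real engine.

# Sketch of the deferred, gated scheduler step established by the diff above.
# Names here (ToyScheduler, ToyEngine, overflow_at) are hypothetical.

class ToyScheduler:
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


class ToyEngine:
    def __init__(self, overflow_at=()):
        self.lr_scheduler = ToyScheduler()
        self._step_applied = True  # mirrors the new default set in __init__
        self._overflow_at = set(overflow_at)
        self._iteration = 0

    def _take_model_step(self):
        # Deferred scheduler step: advance the LR schedule for the previous
        # optimizer step, but only if that step was actually applied.
        if self.lr_scheduler is not None and self._step_applied:
            self.lr_scheduler.step()

        # Simulate the optimizer step; an overflow means the update is skipped.
        overflow = self._iteration in self._overflow_at
        self._step_applied = not overflow
        self._iteration += 1


if __name__ == "__main__":
    engine = ToyEngine(overflow_at={1})
    for _ in range(4):
        engine._take_model_step()
    # Scheduler advances at iterations 0 (initial default), 1, and 3;
    # iteration 2 is skipped because iteration 1 overflowed -> 3 total.
    print(engine.lr_scheduler.steps)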