diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py
index bf56a1c86f5..7f41c4376be 100644
--- a/ludwig/trainers/trainer.py
+++ b/ludwig/trainers/trainer.py
@@ -1098,7 +1098,8 @@ def _train_loop(
         """Completes up to one epoch through the data."""
         self.distributed.zero_grad(self.optimizer)
         batch_idx = 0
-        while not batcher.last_batch() and progress_tracker.steps < self.total_steps:
+        should_break = False
+        while not batcher.last_batch() and progress_tracker.steps < self.total_steps and not should_break:
             progress_tracker.learning_rate = self.optimizer.param_groups[0]["lr"]
             self.callback(lambda c: c.on_batch_start(self, progress_tracker, save_path))

@@ -1149,13 +1150,6 @@ def _train_loop(
             # batch duration measurements when using timer callbacks.
             self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path, sync_step=should_step))

-            if batcher.last_batch():
-                # We have completed an epoch, so we need to increment the epoch counter. It's important to do this here
-                # instead of outside of the train loop since it's possible the train loop will exit early due to
-                # early stopping, or step-based training.
-                progress_tracker.epoch += 1
-                self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))
-
             if progress_tracker.steps % final_steps_per_checkpoint == 0:
                 if not self.skip_all_evaluation:
                     # Publishes metrics to MLFLow if there are any MLFlow callbacks.
@@ -1189,10 +1183,12 @@ def _train_loop(
                 if self.is_coordinator():
                     progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

-            if should_break:
-                return should_break
+            # If this was the last batch, then increment the epoch counter and invoke the `on_epoch_end` callback.
+            if batcher.last_batch():
+                progress_tracker.epoch += 1
+                self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))

-        return False
+        return should_break

     def train_online(self, dataset):
         self.dist_model.train()  # Sets model training mode.
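
A minimal sketch of the reworked control flow, using hypothetical stand-ins (ProgressTracker, should_stop_early) rather than Ludwig's actual classes and callbacks: the early-exit flag is folded into the `while` condition instead of triggering an immediate `return`, and the epoch bookkeeping runs after the checkpoint/evaluation block, so the epoch counter still advances when the final batch of an epoch also requests an early stop.

```python
from dataclasses import dataclass


@dataclass
class ProgressTracker:
    # Simplified stand-in for Ludwig's progress tracker (assumption).
    steps: int = 0
    epoch: int = 0


def train_loop(batches, total_steps, steps_per_checkpoint, should_stop_early):
    """Simplified sketch of the restructured loop; not Ludwig's real API."""
    tracker = ProgressTracker()
    it = iter(batches)
    batch = next(it, None)
    should_break = False

    # Early-exit flag is part of the loop condition rather than an inner `return`.
    while batch is not None and tracker.steps < total_steps and not should_break:
        nxt = next(it, None)
        last_batch = nxt is None

        # ... forward/backward pass and optimizer step would go here ...
        tracker.steps += 1

        if tracker.steps % steps_per_checkpoint == 0:
            # Evaluation may request early stopping (hypothetical hook).
            should_break = should_stop_early(tracker)

        # Epoch bookkeeping now happens *after* checkpoint evaluation, so the
        # epoch counter (and, in Ludwig, the on_epoch_end callback) still runs
        # when the last batch of the epoch also sets should_break.
        if last_batch:
            tracker.epoch += 1
            print(f"epoch {tracker.epoch} complete at step {tracker.steps}")

        batch = nxt

    return should_break


if __name__ == "__main__":
    # Example: ten batches, no early stopping requested.
    stopped = train_loop(range(10), total_steps=100, steps_per_checkpoint=4,
                         should_stop_early=lambda t: False)
    print("early stop:", stopped)
```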