Skip to content

Commit

Permalink
Move on_epoch_end and epoch increment to after run_evaluation loop. (#…
Browse files: browse the repository at this point in the history
  • Loading branch information
justinxzhao authored Oct 9, 2023
1 parent c0946be commit f84ee5c
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions ludwig/trainers/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1098,7 +1098,8 @@ def _train_loop(
"""Completes up to one epoch through the data."""
self.distributed.zero_grad(self.optimizer)
batch_idx = 0
- while not batcher.last_batch() and progress_tracker.steps < self.total_steps:
+ should_break = False
+ while not batcher.last_batch() and progress_tracker.steps < self.total_steps and not should_break:
progress_tracker.learning_rate = self.optimizer.param_groups[0]["lr"]
self.callback(lambda c: c.on_batch_start(self, progress_tracker, save_path))

Expand Down Expand Up @@ -1149,13 +1150,6 @@ def _train_loop(
# batch duration measurements when using timer callbacks.
self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path, sync_step=should_step))

- if batcher.last_batch():
- # We have completed an epoch, so we need to increment the epoch counter. It's important to do this here
- # instead of outside of the train loop since it's possible the train loop will exit early due to
- # early stopping, or step-based training.
- progress_tracker.epoch += 1
- self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))

if progress_tracker.steps % final_steps_per_checkpoint == 0:
if not self.skip_all_evaluation:
# Publishes metrics to MLFLow if there are any MLFlow callbacks.
Expand Down Expand Up @@ -1189,10 +1183,12 @@ def _train_loop(
if self.is_coordinator():
progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

- if should_break:
- return should_break
+ # If this was the last batch, then increment the epoch counter and invoke the `on_epoch_end` callback.
+ if batcher.last_batch():
+ progress_tracker.epoch += 1
+ self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))

- return False
+ return should_break

def train_online(self, dataset):
self.dist_model.train() # Sets model training mode.
Expand Down

0 comments on commit f84ee5c

Please sign in to comment.