diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py
index c9680b04..ca3b9e3e 100644
--- a/init2winit/trainer_lib/trainer.py
+++ b/init2winit/trainer_lib/trainer.py
@@ -202,6 +202,9 @@ def train(self):
     self._time_at_prev_eval_end = start_time
     self._prev_eval_step = self._global_step
 
+    if self._global_step in self._checkpoint_steps and jax.process_index() == 0:
+      self._save(self._checkpoint_dir, max_to_keep=None)
+
     for _ in range(start_step, self._num_train_steps):
       with jax.profiler.StepTraceAnnotation('train',
                                             step_num=self._global_step):
@@ -210,9 +213,6 @@ def train(self):
         # directly in the top-level for loop).
         batch = next(train_iter)
 
-        if (self._global_step in self._checkpoint_steps
-            and jax.process_index() == 0):
-          self._save(self._checkpoint_dir, max_to_keep=None)
         lr = self._lr_fn(self._global_step)
         # It looks like we are reusing an rng key, but we aren't.
         # TODO(gdahl): Make it more obvious that passing rng is safe.
@@ -225,13 +225,26 @@ def train(self):
             self._metrics_state, batch, self._global_step, lr, rng,
             self._local_device_indices, self._sum_train_cost)
         self._global_step += 1
+
+        if (
+            self._global_step in self._checkpoint_steps
+            and jax.process_index() == 0
+        ):
+          self._save(self._checkpoint_dir, max_to_keep=None)
         lr = self._optimizer_state.hyperparams['learning_rate'][0]
         # TODO(gdahl, gilmer): consider moving this test up.
         # NB: Since this test is after we increment self._global_step, having 0
         # in eval_steps does nothing.
         if trainer_utils.should_eval(
             self._global_step, self._eval_frequency, self._eval_steps):
-          report = self._eval(lr, start_step, start_time)
+          try:
+            report = self._eval(lr, start_step, start_time)
+          except utils.TrainingDivergedError as e:
+            # In case of NaN during evals, make sure to save the last checkpoint.
+            checkpoint.wait_for_checkpoint_save()
+            raise utils.TrainingDivergedError(
+                f'divergence at step {self._global_step}'
+            ) from e
           yield report
           if self._check_early_stopping(report):
             return
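
The sketch below is a minimal, self-contained illustration of the control flow this patch establishes: the in-loop checkpoint save moves to after the global-step increment (with an extra save before the loop for the case where the starting step is itself a requested checkpoint step), and a divergence raised during eval waits for any in-flight checkpoint save before re-raising. It is not the init2winit Trainer; every name in it (train_loop, run_step, run_eval, save_checkpoint, wait_for_pending_save, should_eval) is a hypothetical stand-in used only to show the ordering.

"""Sketch of the checkpoint/eval ordering in the patch above (hypothetical names)."""


class TrainingDivergedError(Exception):
  """Raised when a loss or metric becomes NaN or infinite."""


def save_checkpoint(step):
  print(f'saving checkpoint at step {step}')  # placeholder for self._save


def wait_for_pending_save():
  print('waiting for in-flight checkpoint save')  # placeholder


def run_step(step):
  del step  # placeholder for one pmapped update step


def run_eval(step):
  return {'step': step}  # placeholder eval report


def should_eval(step, eval_frequency=10):
  return step % eval_frequency == 0


def train_loop(global_step, num_train_steps, checkpoint_steps):
  # Save once before the loop if the starting step is itself a requested
  # checkpoint step; the in-loop save below only runs after an increment, so
  # it would never cover the starting step.
  if global_step in checkpoint_steps:
    save_checkpoint(global_step)

  while global_step < num_train_steps:
    run_step(global_step)
    global_step += 1

    # Checkpoint after incrementing, so the saved state corresponds to the
    # step that just finished.
    if global_step in checkpoint_steps:
      save_checkpoint(global_step)

    if should_eval(global_step):
      try:
        report = run_eval(global_step)
      except TrainingDivergedError as e:
        # If eval hits NaNs, let any asynchronous checkpoint save finish
        # before re-raising, so the last checkpoint stays intact.
        wait_for_pending_save()
        raise TrainingDivergedError(
            f'divergence at step {global_step}') from e
      yield report


if __name__ == '__main__':
  for report in train_loop(0, 30, checkpoint_steps={0, 10, 20}):
    print(report)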