diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index c8e2a9ee..490ad100 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -26,6 +26,7 @@ from init2winit import checkpoint from init2winit import schedules from init2winit import utils +from init2winit.optimizer_lib import gradient_accumulator from init2winit.optimizer_lib import optimizers from init2winit.trainer_lib import trainer_utils from init2winit.training_metrics_grabber import make_training_metrics @@ -548,8 +549,10 @@ def _eval( ) if self._eval_use_ema: - if isinstance( - self._optimizer_state.base_state.inner_state[0][0], optax.EmaState + if isinstance(self._optimizer_state, optax.InjectHyperparamsState): + eval_params = self._optimizer_state.inner_state[0][0].ema + elif isinstance( + self._optimizer_state, gradient_accumulator.GradientAccumulatorState ): eval_params = self._optimizer_state.base_state.inner_state[0][0].ema else: