From 251406e8c2f9d850dc1ee15f5798976cab47b6e8 Mon Sep 17 00:00:00 2001 From: Sourabh Medapati Date: Fri, 20 Oct 2023 16:28:18 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 575345109 --- init2winit/trainer_lib/base_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index c8e2a9ee..490ad100 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -26,6 +26,7 @@ from init2winit import checkpoint from init2winit import schedules from init2winit import utils +from init2winit.optimizer_lib import gradient_accumulator from init2winit.optimizer_lib import optimizers from init2winit.trainer_lib import trainer_utils from init2winit.training_metrics_grabber import make_training_metrics @@ -548,8 +549,10 @@ def _eval( ) if self._eval_use_ema: - if isinstance( - self._optimizer_state.base_state.inner_state[0][0], optax.EmaState + if isinstance(self._optimizer_state, optax.InjectHyperparamsState): + eval_params = self._optimizer_state.inner_state[0][0].ema + elif isinstance( + self._optimizer_state, gradient_accumulator.GradientAccumulatorState ): eval_params = self._optimizer_state.base_state.inner_state[0][0].ema else: