Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575345109
  • Loading branch information
sourabh2k15 authored and copybara-github committed Oct 21, 2023
1 parent 328b99d commit 251406e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from init2winit import checkpoint
from init2winit import schedules
from init2winit import utils
from init2winit.optimizer_lib import gradient_accumulator
from init2winit.optimizer_lib import optimizers
from init2winit.trainer_lib import trainer_utils
from init2winit.training_metrics_grabber import make_training_metrics
Expand Down Expand Up @@ -548,8 +549,10 @@ def _eval(
)

if self._eval_use_ema:
if isinstance(
self._optimizer_state.base_state.inner_state[0][0], optax.EmaState
if isinstance(self._optimizer_state, optax.InjectHyperparamsState):
eval_params = self._optimizer_state.inner_state[0][0].ema
elif isinstance(
self._optimizer_state, gradient_accumulator.GradientAccumulatorState
):
eval_params = self._optimizer_state.base_state.inner_state[0][0].ema
else:
Expand Down

0 comments on commit 251406e

Please sign in to comment.