diff --git a/hessian/test_run_stats.py b/hessian/test_run_stats.py index d2a61a67..94fd5beb 100644 --- a/hessian/test_run_stats.py +++ b/hessian/test_run_stats.py @@ -34,7 +34,7 @@ from init2winit.init_lib import initializers from init2winit.model_lib import models from init2winit.trainer_lib import trainer -from jax.config import config as jax_config +from jax import config as jax_config from jax.flatten_util import ravel_pytree import jax.numpy as jnp import jax.random