diff --git a/tests/batching_test.py b/tests/batching_test.py index e1fa533b..52ed82ea 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -18,8 +18,8 @@ from absl.testing import absltest from jax import jit +from jax import config from jax import random -from jax.config import config import jax.numpy as jnp from jax.tree_util import tree_map import neural_tangents as nt diff --git a/tests/elementwise_numerical_test.py b/tests/elementwise_numerical_test.py index 0a166a30..bb141a66 100644 --- a/tests/elementwise_numerical_test.py +++ b/tests/elementwise_numerical_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import elementwise_numerical from tests import test_utils diff --git a/tests/elementwise_test.py b/tests/elementwise_test.py index 82424b41..9856fdf1 100644 --- a/tests/elementwise_test.py +++ b/tests/elementwise_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import elementwise from tests import test_utils diff --git a/tests/empirical_ntk_test.py b/tests/empirical_ntk_test.py index 81973cc1..19af638c 100644 --- a/tests/empirical_ntk_test.py +++ b/tests/empirical_ntk_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import empirical_ntk from tests import test_utils diff --git a/tests/empirical_test.py b/tests/empirical_test.py index f610bb11..f4475156 100644 --- a/tests/empirical_test.py +++ b/tests/empirical_test.py @@ -28,7 +28,7 @@ from jax import random from jax import remat from jax import tree_map -from jax.config import config +from jax import config import jax.numpy as jnp from jax.tree_util import tree_reduce import neural_tangents as nt diff --git a/tests/function_space_test.py b/tests/function_space_test.py index f60b0920..b951c841 100644 --- a/tests/function_space_test.py +++ b/tests/function_space_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import function_space from tests import test_utils diff --git a/tests/imdb_test.py b/tests/imdb_test.py index b93eea14..e19016b9 100644 --- a/tests/imdb_test.py +++ b/tests/imdb_test.py @@ -16,7 +16,7 @@ """Tests for `examples/imdb.py`.""" from absl.testing import absltest -from jax.config import config +from jax import config from examples import imdb from tests import test_utils diff --git a/tests/infinite_fcn_test.py b/tests/infinite_fcn_test.py index 1d5f6f12..c6a70376 100644 --- a/tests/infinite_fcn_test.py +++ b/tests/infinite_fcn_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import infinite_fcn from tests import test_utils diff --git a/tests/monte_carlo_test.py b/tests/monte_carlo_test.py index 5208c23c..3ba3385c 100644 --- a/tests/monte_carlo_test.py +++ b/tests/monte_carlo_test.py @@ -16,8 +16,8 @@ from absl.testing import absltest import jax +from jax import config from jax import random -from jax.config import config import jax.numpy as jnp import neural_tangents as nt from neural_tangents import stax diff --git a/tests/predict_test.py b/tests/predict_test.py index 5269da8e..950d16ad 100644 --- a/tests/predict_test.py +++ b/tests/predict_test.py @@ -21,7 +21,7 @@ from jax import jit from jax import random from jax import vmap -from jax.config import config +from jax import config from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp diff --git a/tests/rules_test.py b/tests/rules_test.py index 72b28067..8192c207 100644 --- a/tests/rules_test.py +++ b/tests/rules_test.py @@ -23,7 +23,7 @@ from absl.testing import absltest import jax from jax import lax -from jax.config import config +from jax import config from jax.core import Primitive from jax.core import ShapedArray from jax.interpreters import ad diff --git a/tests/weight_space_test.py b/tests/weight_space_test.py index a2dee130..f4778733 100644 --- a/tests/weight_space_test.py +++ b/tests/weight_space_test.py @@ -16,7 +16,7 @@ from absl.testing import absltest -from jax.config import config +from jax import config from examples import weight_space from tests import test_utils