diff --git a/notebooks/Disentangling_Trainability_and_Generalization.ipynb b/notebooks/Disentangling_Trainability_and_Generalization.ipynb index 1dac80a3..7f25a24e 100644 --- a/notebooks/Disentangling_Trainability_and_Generalization.ipynb +++ b/notebooks/Disentangling_Trainability_and_Generalization.ipynb @@ -66,7 +66,7 @@ "from jax.example_libraries import optimizers\n", "from jax import grad, jit, vmap\n", "from jax import lax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "from functools import partial\n", diff --git a/notebooks/myrtle_kernel_with_neural_tangents.ipynb b/notebooks/myrtle_kernel_with_neural_tangents.ipynb index a80d40fc..060a71a3 100644 --- a/notebooks/myrtle_kernel_with_neural_tangents.ipynb +++ b/notebooks/myrtle_kernel_with_neural_tangents.ipynb @@ -72,7 +72,7 @@ }, "outputs": [], "source": [ - "from jax.config import config\n", + "from jax import config\n", "# Enable float64 for JAX\n", "config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/notebooks/phase_diagram.ipynb b/notebooks/phase_diagram.ipynb index a1409eef..a76ce8ba 100644 --- a/notebooks/phase_diagram.ipynb +++ b/notebooks/phase_diagram.ipynb @@ -94,7 +94,7 @@ "from jax.example_libraries import optimizers\n", "from jax import grad, jit, vmap\n", "from jax import lax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "from functools import partial\n", diff --git a/tests/stax/branching_test.py b/tests/stax/branching_test.py index 05560f83..5ba43df6 100644 --- a/tests/stax/branching_test.py +++ b/tests/stax/branching_test.py @@ -17,9 +17,9 @@ import random as prandom from absl.testing import absltest +from jax import config from jax import default_backend from jax import random -from jax.config import config import jax.numpy as jnp import neural_tangents as nt from neural_tangents import stax diff --git a/tests/stax/combinators_test.py b/tests/stax/combinators_test.py index f2a78759..7b5e035e 100644 --- a/tests/stax/combinators_test.py +++ b/tests/stax/combinators_test.py @@ -17,8 +17,8 @@ import random as prandom from absl.testing import absltest +from jax import config from jax import random -from jax.config import config import jax.numpy as jnp from neural_tangents import stax from tests import test_utils diff --git a/tests/stax/elementwise_test.py b/tests/stax/elementwise_test.py index 1d475a6e..41565be5 100644 --- a/tests/stax/elementwise_test.py +++ b/tests/stax/elementwise_test.py @@ -18,10 +18,10 @@ import random as prandom from absl.testing import absltest +from jax import config from jax import default_backend from jax import grad, jacfwd, jacrev, jit, jvp, value_and_grad from jax import random -from jax.config import config import jax.numpy as jnp import neural_tangents as nt from neural_tangents import stax diff --git a/tests/stax/linear_test.py b/tests/stax/linear_test.py index fcccc444..eab5a046 100644 --- a/tests/stax/linear_test.py +++ b/tests/stax/linear_test.py @@ -25,7 +25,7 @@ from jax import lax from jax import random from jax import vjp -from jax.config import config +from jax import config import jax.numpy as jnp import more_itertools import neural_tangents as nt diff --git a/tests/stax/requirements_test.py b/tests/stax/requirements_test.py index dff41abc..24e60bd2 100644 --- a/tests/stax/requirements_test.py +++ b/tests/stax/requirements_test.py @@ -19,10 +19,10 @@ import random as prandom from absl.testing import absltest +from jax import config from jax import default_backend from jax import jit from jax import random -from jax.config import config import jax.numpy as jnp import neural_tangents as nt from neural_tangents import stax diff --git a/tests/stax/stax_test.py b/tests/stax/stax_test.py index b6f3aaf3..fba80e0d 100644 --- a/tests/stax/stax_test.py +++ b/tests/stax/stax_test.py @@ -19,10 +19,10 @@ import random as prandom from absl.testing import absltest +from jax import config from jax import default_backend from jax import jit from jax import random -from jax.config import config from jax.example_libraries import stax as ostax import jax.numpy as jnp import neural_tangents as nt