Skip to content

Commit

Permalink
Updated jax.config import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574906153
  • Loading branch information
superbobry authored and romanngg committed Nov 21, 2023
1 parent d816c8f commit dce935b
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"from jax.example_libraries import optimizers\n",
"from jax import grad, jit, vmap\n",
"from jax import lax\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update('jax_enable_x64', True)\n",
"\n",
"from functools import partial\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/myrtle_kernel_with_neural_tangents.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
},
"outputs": [],
"source": [
"from jax.config import config\n",
"from jax import config\n",
"# Enable float64 for JAX\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/phase_diagram.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"from jax.example_libraries import optimizers\n",
"from jax import grad, jit, vmap\n",
"from jax import lax\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update('jax_enable_x64', True)\n",
"\n",
"from functools import partial\n",
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/branching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import random as prandom

from absl.testing import absltest
from jax import config
from jax import default_backend
from jax import random
from jax.config import config
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/combinators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import random as prandom

from absl.testing import absltest
from jax import config
from jax import random
from jax.config import config
import jax.numpy as jnp
from neural_tangents import stax
from tests import test_utils
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/elementwise_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import random as prandom

from absl.testing import absltest
from jax import config
from jax import default_backend
from jax import grad, jacfwd, jacrev, jit, jvp, value_and_grad
from jax import random
from jax.config import config
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax import lax
from jax import random
from jax import vjp
from jax.config import config
from jax import config
import jax.numpy as jnp
import more_itertools
import neural_tangents as nt
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/requirements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import random as prandom

from absl.testing import absltest
from jax import config
from jax import default_backend
from jax import jit
from jax import random
from jax.config import config
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import random as prandom

from absl.testing import absltest
from jax import config
from jax import default_backend
from jax import jit
from jax import random
from jax.config import config
from jax.example_libraries import stax as ostax
import jax.numpy as jnp
import neural_tangents as nt
Expand Down

0 comments on commit dce935b

Please sign in to comment.