Skip to content

Commit

Permalink
Fixed where a nonbatchable check was being called.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 9, 2024
1 parent 1ae1d58 commit 825e4e0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion diffrax/_solver/runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def eval_k_jac():
assert implicit_tableau.a_diagonal[0] == 0 # pyright: ignore
assert len(set(implicit_tableau.a_diagonal[1:])) == 1 # pyright: ignore
jac_stage_index = 1
stage_index = eqxi.nonbatchable(stage_index)
stage_index = eqxi.nonbatchable(stage_index)
# These `stop_gradients` are needed to work around the lack of
# symbolic zeros in `custom_vjp`s.
if eval_fs:
Expand Down
2 changes: 1 addition & 1 deletion test/test_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _test_dense_interpolation(solver, key, t1):


@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
def test_dense_interpolation(solver, getkey):
def test_dense_interpolation(solver):
solver = implicit_tol(solver)
key = jr.PRNGKey(5678)
vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1)
Expand Down

0 comments on commit 825e4e0

Please sign in to comment.