From 825e4e06a84d15abc523c3d61c9af94fdc48de67 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 8 Dec 2024 03:36:31 +0100 Subject: [PATCH] Fixed where a nonbatchable check was being called. --- diffrax/_solver/runge_kutta.py | 2 +- test/test_global_interpolation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index d1e69467..11a9f6c8 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -964,7 +964,7 @@ def eval_k_jac(): assert implicit_tableau.a_diagonal[0] == 0 # pyright: ignore assert len(set(implicit_tableau.a_diagonal[1:])) == 1 # pyright: ignore jac_stage_index = 1 - stage_index = eqxi.nonbatchable(stage_index) + stage_index = eqxi.nonbatchable(stage_index) # These `stop_gradients` are needed to work around the lack of # symbolic zeros in `custom_vjp`s. if eval_fs: diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index 97f1938b..9c1c236c 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -340,7 +340,7 @@ def _test_dense_interpolation(solver, key, t1): @pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers) -def test_dense_interpolation(solver, getkey): +def test_dense_interpolation(solver): solver = implicit_tol(solver) key = jr.PRNGKey(5678) vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1)