diff --git a/tests/utils_test.py b/tests/utils_test.py index 2486a4732..70bc4e23f 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -127,7 +127,7 @@ def fn(x: jnp.ndarray) -> jnp.ndarray: np.testing.assert_array_equal(fn(x), x.sum(1)) - @pytest.mark.limit_memory("10MB") + @pytest.mark.limit_memory("15MB") def test_vmap_max_memory(self, rng: jax.Array): n, m, d = 2 ** 16, 2 ** 11, 3 rng, rng_data = jax.random.split(rng, 2)