diff --git a/tests/tools/sliced_test.py b/tests/tools/sliced_test.py index 60d6361a3..d04b5901d 100644 --- a/tests/tools/sliced_test.py +++ b/tests/tools/sliced_test.py @@ -61,7 +61,7 @@ def test_random_projs( n, m, dim, n_proj = 12, 17, 5, 13 rng1, rng2 = jax.random.split(rng, 2) a, x, b, y = gen_data(rng1, n, m, dim) - weights = jax.random.uniform(rng2, n_proj) + weights = jax.random.uniform(rng2, (n_proj,)) # Test non-negative and returns output as needed. cost, out = sliced.sliced_wasserstein(