diff --git a/tests/test_core.py b/tests/test_core.py index 5063954..2f04431 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -992,3 +992,19 @@ def test_no_unsafe_cumulative_sum_cast(): ): a = ndx.asarray([1, 2, 3], ndx.int32) ndx.cumulative_sum(a, dtype=ndx.uint64) + + +@pytest.mark.parametrize( + "func, x", + [ + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)), + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + ], +) +def test_argmaxmin(func, x): + np_result = func(x) + ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy() + breakpoint() + np.testing.assert_equal(np_result, ndx_result) diff --git a/xfails.txt b/xfails.txt index 69431ee..9665406 100644 --- a/xfails.txt +++ b/xfails.txt @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit array_api_tests/test_operators_and_elementwise_functions.py::test_sinh array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_tan -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_searching_functions.py::test_where