Skip to content

Commit

Permalink
Fix argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
1 parent ed532b5 commit b139d18
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
12 changes: 10 additions & 2 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ def argmax(self, x, axis=None, keepdims=False):
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x])
return _via_i64_f64(
lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims),
[x],
cast_return=False,
)

@validate_core
def argmin(self, x, axis=None, keepdims=False):
Expand All @@ -391,7 +395,11 @@ def argmin(self, x, axis=None, keepdims=False):
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])
return _via_i64_f64(
lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims),
[x],
cast_return=False,
)

@validate_core
def nonzero(self, x) -> tuple[Array, ...]:
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
Expand Down
15 changes: 15 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,3 +992,18 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
],
)
def test_argmaxmin(func, x):
np_result = func(x)
ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy()
assert_array_equal(np_result, ndx_result)
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down

0 comments on commit b139d18

Please sign in to comment.