Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow recommendation on the interaction with numpy.ndarray in binary ops #2266

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __and__(self, other):
# '__array_prepare__',
# '__array_priority__',
# '__array_struct__',
# '__array_ufunc__',

__array_ufunc__ = None

# '__array_wrap__',

def __array_namespace__(self, /, *, api_version=None):
Expand Down
14 changes: 14 additions & 0 deletions dpnp/tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ def get_all_dtypes(
return dtypes


def get_array(xp, a):
"""
Cast input array `a` to a type supported by `xp` interface.

Implicit conversion of either DPNP or DPCTL array to a NumPy array is not
allowed. Input array has to be explicitly casted with `asnumpy` function.

"""

if xp is numpy and dpnp.is_supported_array_type(a):
return dpnp.asnumpy(a)
return a


def generate_random_numpy_array(
shape,
dtype=None,
Expand Down
3 changes: 2 additions & 1 deletion dpnp/tests/test_arraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .helper import (
assert_dtype_allclose,
get_all_dtypes,
get_array,
)
from .third_party.cupy import testing

Expand Down Expand Up @@ -768,7 +769,7 @@ def test_space_numpy_dtype(func, start_dtype, stop_dtype):
],
)
def test_linspace_arrays(start, stop):
func = lambda xp: xp.linspace(start, stop, 10)
func = lambda xp: xp.linspace(get_array(xp, start), get_array(xp, stop), 10)
assert func(numpy).shape == func(dpnp).shape


Expand Down
15 changes: 7 additions & 8 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a)
dp_rank = dpnp.linalg.matrix_rank(a_dp)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
Expand All @@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize(
"high_tol, low_tol",
Expand Down Expand Up @@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
dp_rank_high_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_high_tol
)
assert np_rank_high_tol == dp_rank_high_tol
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol

np_rank_low_tol = numpy.linalg.matrix_rank(
a, hermitian=True, tol=low_tol
)
dp_rank_low_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_low_tol
)
assert np_rank_low_tol == dp_rank_low_tol
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol

# rtol kwarg was added in numpy 2.0
@testing.with_requires("numpy>=2.0")
Expand Down Expand Up @@ -2807,15 +2807,14 @@ def check_decomposition(
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
dpnp_diag_s[..., i, i] = dp_s[..., i]
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
# TODO: use assert dpnp.allclose() inside check_decomposition()
# when it will support complex dtypes
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)

if compute_vt:
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
np_u[..., :, i] = -np_u[..., :, i]
np_vt[..., i, :] = -np_vt[..., i, :]
for i in range(numpy.count_nonzero(np_s > tol)):
Expand Down
6 changes: 5 additions & 1 deletion dpnp/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .helper import (
assert_dtype_allclose,
get_all_dtypes,
get_array,
get_complex_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
Expand Down Expand Up @@ -1232,7 +1233,10 @@ def test_axes(self):
def test_axes_type(self, axes):
a = numpy.ones((50, 40, 3))
ia = dpnp.array(a)
assert_equal(dpnp.rot90(ia, axes=axes), numpy.rot90(a, axes=axes))
assert_equal(
dpnp.rot90(ia, axes=axes),
numpy.rot90(a, axes=get_array(numpy, axes)),
)

def test_rotation_axes(self):
a = numpy.arange(8).reshape((2, 2, 2))
Expand Down
21 changes: 21 additions & 0 deletions dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ def test_wrong_api_version(self, api_version):
)


class TestArrayUfunc:
def test_add(self):
a = numpy.ones(10)
b = dpnp.ones(10)
msg = "An array must be any of supported type"

with assert_raises_regex(TypeError, msg):
a + b

with assert_raises_regex(TypeError, msg):
b + a

def test_add_inplace(self):
a = numpy.ones(10)
b = dpnp.ones(10)
with assert_raises_regex(
TypeError, "operand 'dpnp_array' does not support ufuncs"
):
a += b


class TestItem:
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
def test_basic(self, args):
Expand Down
Loading