Skip to content

Commit

Permalink
Merge branch 'master' into harlowjo/debug/wheels-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy authored Jan 28, 2025
2 parents ba98594 + c844b26 commit 5422dc9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
26 changes: 26 additions & 0 deletions dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,29 @@ def test_clip():
expected = numpy.clip(numpy_array, 3, 7)

assert_array_equal(expected, result)


def test_rmatmul_dpnp_array():
a = dpnp.ones(10)
b = dpnp.ones(10)

class Dummy(dpnp.ndarray):
def __init__(self, x):
self._array_obj = x.get_array()

def __matmul__(self, other):
return NotImplemented

d = Dummy(a)

result = d @ b
expected = a @ b
assert (result == expected).all()


def test_rmatmul_numpy_array():
a = dpnp.ones(10)
b = numpy.ones(10)

with pytest.raises(TypeError):
b @ a
69 changes: 69 additions & 0 deletions dpnp/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import dpctl
import dpctl.tensor as dpt
import numpy
import pytest

import dpnp


class TestIsSupportedArrayOrScalar:
@pytest.mark.parametrize(
"array",
[
dpnp.array([1, 2, 3]),
dpnp.array(1),
dpt.asarray([1, 2, 3]),
],
)
def test_valid_arrays(self, array):
assert dpnp.is_supported_array_or_scalar(array) is True

@pytest.mark.parametrize(
"value",
[
42,
True,
"1",
],
)
def test_valid_scalars(self, value):
assert dpnp.is_supported_array_or_scalar(value) is True

@pytest.mark.parametrize(
"array",
[
[1, 2, 3],
(1, 2, 3),
None,
numpy.array([1, 2, 3]),
],
)
def test_invalid_arrays(self, array):
assert not dpnp.is_supported_array_or_scalar(array) is True


class TestSynchronizeArrayData:
@pytest.mark.parametrize(
"array",
[
dpnp.array([1, 2, 3]),
dpt.asarray([1, 2, 3]),
],
)
def test_synchronize_array_data(self, array):
a_copy = dpnp.copy(array, sycl_queue=array.sycl_queue)
try:
dpnp.synchronize_array_data(a_copy)
except Exception as e:
pytest.fail(f"synchronize_array_data failed: {e}")

@pytest.mark.parametrize(
"input",
[
[1, 2, 3],
numpy.array([1, 2, 3]),
],
)
def test_unsupported_type(self, input):
with pytest.raises(TypeError):
dpnp.synchronize_array_data(input)
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _get_cmdclass():
Topic :: Software Development
Topic :: Scientific/Engineering
Operating System :: Microsoft :: Windows
Operating System :: POSIX :: Linux
Operating System :: POSIX
Operating System :: Unix
"""
Expand Down Expand Up @@ -82,4 +83,6 @@ def _get_cmdclass():
]
},
include_package_data=False,
python_requires=">=3.9,<3.14",
install_requires=["dpctl >= 0.19.0dev0", "numpy"],
)

0 comments on commit 5422dc9

Please sign in to comment.