diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 0a394c29158..dd790c8e56a 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -460,3 +460,29 @@ def test_clip(): expected = numpy.clip(numpy_array, 3, 7) assert_array_equal(expected, result) + + +def test_rmatmul_dpnp_array(): + a = dpnp.ones(10) + b = dpnp.ones(10) + + class Dummy(dpnp.ndarray): + def __init__(self, x): + self._array_obj = x.get_array() + + def __matmul__(self, other): + return NotImplemented + + d = Dummy(a) + + result = d @ b + expected = a @ b + assert (result == expected).all() + + +def test_rmatmul_numpy_array(): + a = dpnp.ones(10) + b = numpy.ones(10) + + with pytest.raises(TypeError): + b @ a diff --git a/dpnp/tests/test_utils.py b/dpnp/tests/test_utils.py new file mode 100644 index 00000000000..89e97b75d5e --- /dev/null +++ b/dpnp/tests/test_utils.py @@ -0,0 +1,69 @@ +import dpctl +import dpctl.tensor as dpt +import numpy +import pytest + +import dpnp + + +class TestIsSupportedArrayOrScalar: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpnp.array(1), + dpt.asarray([1, 2, 3]), + ], + ) + def test_valid_arrays(self, array): + assert dpnp.is_supported_array_or_scalar(array) is True + + @pytest.mark.parametrize( + "value", + [ + 42, + True, + "1", + ], + ) + def test_valid_scalars(self, value): + assert dpnp.is_supported_array_or_scalar(value) is True + + @pytest.mark.parametrize( + "array", + [ + [1, 2, 3], + (1, 2, 3), + None, + numpy.array([1, 2, 3]), + ], + ) + def test_invalid_arrays(self, array): + assert not dpnp.is_supported_array_or_scalar(array) is True + + +class TestSynchronizeArrayData: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpt.asarray([1, 2, 3]), + ], + ) + def test_synchronize_array_data(self, array): + a_copy = dpnp.copy(array, sycl_queue=array.sycl_queue) + try: + dpnp.synchronize_array_data(a_copy) + except Exception as e: + pytest.fail(f"synchronize_array_data failed: {e}") + + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3], + numpy.array([1, 2, 3]), + ], + ) + def test_unsupported_type(self, input): + with pytest.raises(TypeError): + dpnp.synchronize_array_data(input) diff --git a/setup.py b/setup.py index 37c85fd2647..be32243d34a 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def _get_cmdclass(): Topic :: Software Development Topic :: Scientific/Engineering Operating System :: Microsoft :: Windows +Operating System :: POSIX :: Linux Operating System :: POSIX Operating System :: Unix """ @@ -82,4 +83,6 @@ def _get_cmdclass(): ] }, include_package_data=False, + python_requires=">=3.9,<3.14", + install_requires=["dpctl >= 0.19.0dev0", "numpy"], )