Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __array_namespace__ method #2252

Merged
merged 6 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
offset=offset,
order=order,
buffer_ctor_kwargs={"queue": sycl_queue_normalized},
array_namespace=dpnp,
)

@property
Expand Down Expand Up @@ -201,6 +202,31 @@ def __and__(self, other):
# '__array_ufunc__',
# '__array_wrap__',

def __array_namespace__(self, /, *, api_version=None):
"""
Returns array namespace, member functions of which implement data API.

Parameters
----------
api_version : str, optional
Request namespace compliant with given version of array API. If
``None``, namespace for the most recent supported version is
returned.
Default: ``None``.

Returns
-------
out : any
An object representing the array API namespace. It should have
every top-level function defined in the specification as
an attribute. It may contain other public names as well, but it is
recommended to only include those names that are part of the
specification.

"""

return self._array_obj.__array_namespace__(api_version=api_version)

def __bool__(self):
"""``True`` if self else ``False``."""
return self._array_obj.__bool__()
Expand Down Expand Up @@ -327,15 +353,7 @@ def __getitem__(self, key):
key = _get_unwrapped_index_key(key)

item = self._array_obj.__getitem__(key)
if not isinstance(item, dpt.usm_ndarray):
raise RuntimeError(
"Expected dpctl.tensor.usm_ndarray, got {}"
"".format(type(item))
)

res = self.__new__(dpnp_array)
res._array_obj = item
return res
return dpnp_array._create_from_usm_ndarray(item)

# '__getstate__',

Expand Down Expand Up @@ -606,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
)
res = dpnp_array.__new__(dpnp_array)
res._array_obj = usm_ary
res._array_obj._set_namespace(dpnp)
return res

def all(self, axis=None, out=None, keepdims=False, *, where=True):
Expand Down Expand Up @@ -1749,17 +1768,16 @@ def transpose(self, *axes):
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
axes = axes[0]

res = self.__new__(dpnp_array)
if ndim == 2 and axes_len == 0:
res._array_obj = self._array_obj.T
usm_res = self._array_obj.T
else:
if len(axes) == 0 or axes[0] is None:
# self.transpose().shape == self.shape[::-1]
# self.transpose(None).shape == self.shape[::-1]
axes = tuple((ndim - x - 1) for x in range(ndim))

res._array_obj = dpt.permute_dims(self._array_obj, axes)
return res
usm_res = dpt.permute_dims(self._array_obj, axes)
return dpnp_array._create_from_usm_ndarray(usm_res)

def var(
self,
Expand Down
10 changes: 2 additions & 8 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,14 +622,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
out_strides = a_straides[:-2] + (1,)
out_offset = a_element_offset

return dpnp_array._create_from_usm_ndarray(
dpt.usm_ndarray(
out_shape,
dtype=a.dtype,
buffer=a.get_array(),
strides=out_strides,
offset=out_offset,
)
return dpnp_array(
out_shape, buffer=a, strides=out_strides, offset=out_offset
)


Expand Down
48 changes: 47 additions & 1 deletion dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import dpctl.tensor as dpt
import numpy
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from numpy.testing import (
assert_allclose,
assert_array_equal,
assert_raises_regex,
)

import dpnp

Expand Down Expand Up @@ -104,6 +108,48 @@ def test_flags_writable():
assert not a.imag.flags.writable


class TestArrayNamespace:
def test_basic(self):
a = dpnp.arange(2)
xp = a.__array_namespace__()
assert xp is dpnp

@pytest.mark.parametrize("api_version", [None, "2023.12"])
def test_api_version(self, api_version):
a = dpnp.arange(2)
xp = a.__array_namespace__(api_version=api_version)
assert xp is dpnp

@pytest.mark.parametrize("api_version", ["2021.12", "2022.12", "2024.12"])
def test_unsupported_api_version(self, api_version):
a = dpnp.arange(2)
assert_raises_regex(
ValueError,
"Only 2023.12 is supported",
a.__array_namespace__,
api_version=api_version,
)

@pytest.mark.parametrize(
"api_version",
[
2023,
(2022,),
[
2021,
],
],
)
def test_wrong_api_version(self, api_version):
a = dpnp.arange(2)
assert_raises_regex(
TypeError,
"Expected type str",
a.__array_namespace__,
api_version=api_version,
)


class TestItem:
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
def test_basic(self, args):
Expand Down
Loading