-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement python array API Inspection namespace (#2275)
The PR proposes to implement `dpnp.__array_namespace_info__`, what is a python array API Inspection namespace. It is required to achieve the compliance with python array API. The implementation leverages on appropriate namespace exposed by dpctl.tensor. In addition, this PR makes an `__array_api_version__` attribute available in dpnp. It also borrowed from dpctl. The PR adds a dedication documentation page describing `Array API standard compatibility`, including reference on new `dpnp.__array_namespace_info__`.
- Loading branch information
1 parent
b72f953
commit 83fef36
Showing
8 changed files
with
249 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
.. _array-api-standard-compatibility: | ||
|
||
.. https://numpy.org/doc/stable/reference/array_api.html | ||
******************************** | ||
Array API standard compatibility | ||
******************************** | ||
|
||
DPNP's main namespace as well as the :mod:`dpnp.fft` and :mod:`dpnp.linalg` | ||
namespaces are compatible with the | ||
`2023.12 version <https://data-apis.org/array-api/2023.12/index.html>`__ | ||
of the Python array API standard. | ||
|
||
Inspection | ||
========== | ||
|
||
DPNP implements the `array API inspection utilities | ||
<https://data-apis.org/array-api/latest/API_specification/inspection.html>`__. | ||
These functions can be accessed via the ``__array_namespace_info__()`` | ||
function, which returns a namespace containing the inspection utilities. | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
:nosignatures: | ||
|
||
dpnp.__array_namespace_info__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
.. _routines.fft: | ||
|
||
.. py:module:: dpnp.fft | ||
Discrete Fourier Transform | ||
========================== | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,4 @@ API reference of the Data Parallel Extension for NumPy* | |
dtypes_table | ||
comparison | ||
misc | ||
array_api |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
.. _routines.linalg: | ||
|
||
.. py:module:: dpnp.linalg | ||
Linear algebra | ||
============== | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# -*- coding: utf-8 -*- | ||
# ***************************************************************************** | ||
# Copyright (c) 2025, Intel Corporation | ||
# All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
# - Redistributions of source code must retain the above copyright notice, | ||
# this list of conditions and the following disclaimer. | ||
# - Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
# THE POSSIBILITY OF SUCH DAMAGE. | ||
# ***************************************************************************** | ||
|
||
""" | ||
Array API Inspection namespace | ||
This is the namespace for inspection functions as defined by the array API | ||
standard. See | ||
https://data-apis.org/array-api/latest/API_specification/inspection.html for | ||
more details. | ||
""" | ||
|
||
import dpctl.tensor as dpt | ||
|
||
__all__ = ["__array_namespace_info__"] | ||
|
||
|
||
def __array_namespace_info__(): | ||
""" | ||
Returns a namespace with Array API namespace inspection utilities. | ||
The array API inspection namespace defines the following functions: | ||
- capabilities() | ||
- default_device() | ||
- default_dtypes() | ||
- dtypes() | ||
- devices() | ||
Returns | ||
------- | ||
info : ModuleType | ||
The array API inspection namespace for DPNP. | ||
Examples | ||
-------- | ||
>>> import dpnp as np | ||
>>> info = np.__array_namespace_info__() | ||
>>> info.default_dtypes() # may vary and depends on default device | ||
{'real floating': dtype('float64'), | ||
'complex floating': dtype('complex128'), | ||
'integral': dtype('int64'), | ||
'indexing': dtype('int64')} | ||
""" | ||
|
||
return dpt.__array_namespace_info__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import numpy | ||
import pytest | ||
from dpctl import SyclDeviceCreationError, get_devices, select_default_device | ||
from dpctl.tensor._tensor_impl import default_device_complex_type | ||
|
||
import dpnp | ||
from dpnp.tests.helper import ( | ||
has_support_aspect64, | ||
is_win_platform, | ||
numpy_version, | ||
) | ||
|
||
info = dpnp.__array_namespace_info__() | ||
default_device = select_default_device() | ||
|
||
|
||
def test_capabilities(): | ||
caps = info.capabilities() | ||
assert caps["boolean indexing"] is True | ||
assert caps["data-dependent shapes"] is True | ||
assert caps["max dimensions"] == 64 | ||
|
||
|
||
def test_default_device(): | ||
assert info.default_device() == default_device | ||
|
||
|
||
def test_default_dtypes(): | ||
dtypes = info.default_dtypes() | ||
assert ( | ||
dtypes["real floating"] | ||
== dpnp.default_float_type() | ||
== dpnp.asarray(0.0).dtype | ||
) | ||
# TODO: add dpnp.default_complex_type() function | ||
assert ( | ||
dtypes["complex floating"] | ||
== default_device_complex_type(default_device) | ||
== dpnp.asarray(0.0j).dtype | ||
) | ||
if not is_win_platform() or numpy_version() >= "2.0.0": | ||
# numpy changed default integer on Windows since 2.0 | ||
assert dtypes["integral"] == dpnp.intp == dpnp.asarray(0).dtype | ||
assert ( | ||
dtypes["indexing"] == dpnp.intp == dpnp.argmax(dpnp.zeros(10)).dtype | ||
) | ||
|
||
with pytest.raises( | ||
TypeError, match="Unsupported type for device argument:" | ||
): | ||
info.default_dtypes(device=1) | ||
|
||
|
||
def test_dtypes_all(): | ||
dtypes = info.dtypes() | ||
assert dtypes == ( | ||
{ | ||
"bool": dpnp.bool_, | ||
"int8": numpy.int8, # TODO: replace with dpnp.int8 | ||
"int16": numpy.int16, # TODO: replace with dpnp.int16 | ||
"int32": dpnp.int32, | ||
"int64": dpnp.int64, | ||
"uint8": numpy.uint8, # TODO: replace with dpnp.uint8 | ||
"uint16": numpy.uint16, # TODO: replace with dpnp.uint16 | ||
"uint32": numpy.uint32, # TODO: replace with dpnp.uint32 | ||
"uint64": numpy.uint64, # TODO: replace with dpnp.uint64 | ||
"float32": dpnp.float32, | ||
} | ||
| ({"float64": dpnp.float64} if has_support_aspect64() else {}) | ||
| {"complex64": dpnp.complex64} | ||
| ({"complex128": dpnp.complex128} if has_support_aspect64() else {}) | ||
) | ||
|
||
|
||
dtype_categories = { | ||
"bool": {"bool": dpnp.bool_}, | ||
"signed integer": { | ||
"int8": numpy.int8, # TODO: replace with dpnp.int8 | ||
"int16": numpy.int16, # TODO: replace with dpnp.int16 | ||
"int32": dpnp.int32, | ||
"int64": dpnp.int64, | ||
}, | ||
"unsigned integer": { # TODO: replace with dpnp dtypes once available | ||
"uint8": numpy.uint8, | ||
"uint16": numpy.uint16, | ||
"uint32": numpy.uint32, | ||
"uint64": numpy.uint64, | ||
}, | ||
"integral": ("signed integer", "unsigned integer"), | ||
"real floating": {"float32": dpnp.float32} | ||
| ({"float64": dpnp.float64} if has_support_aspect64() else {}), | ||
"complex floating": {"complex64": dpnp.complex64} | ||
| ({"complex128": dpnp.complex128} if has_support_aspect64() else {}), | ||
"numeric": ("integral", "real floating", "complex floating"), | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("kind", dtype_categories) | ||
def test_dtypes_kind(kind): | ||
expected = dtype_categories[kind] | ||
if isinstance(expected, tuple): | ||
assert info.dtypes(kind=kind) == info.dtypes(kind=expected) | ||
else: | ||
assert info.dtypes(kind=kind) == expected | ||
|
||
|
||
def test_dtypes_tuple(): | ||
dtypes = info.dtypes(kind=("bool", "integral")) | ||
assert dtypes == { | ||
"bool": dpnp.bool_, | ||
"int8": numpy.int8, # TODO: replace with dpnp.int8 | ||
"int16": numpy.int16, # TODO: replace with dpnp.int16 | ||
"int32": dpnp.int32, | ||
"int64": dpnp.int64, | ||
"uint8": numpy.uint8, # TODO: replace with dpnp.uint8 | ||
"uint16": numpy.uint16, # TODO: replace with dpnp.uint16 | ||
"uint32": numpy.uint32, # TODO: replace with dpnp.uint32 | ||
"uint64": numpy.uint64, # TODO: replace with dpnp.uint64 | ||
} | ||
|
||
|
||
def test_dtypes_invalid_kind(): | ||
with pytest.raises(ValueError, match="Unrecognized data type kind"): | ||
info.dtypes(kind="invalid") | ||
|
||
|
||
def test_dtypes_invalid_device(): | ||
with pytest.raises(SyclDeviceCreationError, match="Could not create"): | ||
info.dtypes(device="str") | ||
|
||
|
||
def test_devices(): | ||
assert info.devices() == get_devices() |