Skip to content

Commit

Permalink
Add dpnp.isdtype implementation (#2274)
Browse files Browse the repository at this point in the history
The PR proposes to implement `dpnp.isdtype` function.

The function is mandated according to python array API. The
corresponding muted tests are enabled in python array API compliance
scope.
  • Loading branch information
antonwolfy authored Jan 23, 2025
1 parent 451c2b3 commit d522480
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 6 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
# no 'uint8' dtype
array_api_tests/test_array_object.py::test_getitem_masking

# no 'isdtype' function
array_api_tests/test_data_type_functions.py::test_isdtype
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
array_api_tests/test_signatures.py::test_func_signature[isdtype]

# missing unique-like functions
array_api_tests/test_has_names.py::test_has_names[set-unique_all]
array_api_tests/test_has_names.py::test_has_names[set-unique_counts]
Expand Down
1 change: 0 additions & 1 deletion doc/reference/dtype.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Data type routines
dpnp.min_scalar_type
dpnp.result_type
dpnp.common_type
dpnp.obj2sctype

Creating data types
-------------------
Expand Down
56 changes: 56 additions & 0 deletions dpnp/dpnp_iface_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"integer",
"intc",
"intp",
"isdtype",
"issubdtype",
"is_type_supported",
"nan",
Expand Down Expand Up @@ -194,11 +195,66 @@ def iinfo(dtype):
smallest representable number.
"""

if isinstance(dtype, dpnp_array):
dtype = dtype.dtype
return dpt.iinfo(dtype)


def isdtype(dtype, kind):
"""
Returns a boolean indicating whether a provided `dtype` is
of a specified data type `kind`.
Parameters
----------
dtype : dtype
The input dtype.
kind : {dtype, str, tuple of dtypes or strs}
The input dtype or dtype kind. Allowed dtype kinds are:
* ``'bool'`` : boolean kind
* ``'signed integer'`` : signed integer data types
* ``'unsigned integer'`` : unsigned integer data types
* ``'integral'`` : integer data types
* ``'real floating'`` : real-valued floating-point data types
* ``'complex floating'`` : complex floating-point data types
* ``'numeric'`` : numeric data types
Returns
-------
out : bool
A boolean indicating whether a provided `dtype` is of a specified data
type `kind`.
See Also
--------
:obj:`dpnp.issubdtype` : Test if the first argument is a type code
lower/equal in type hierarchy.
Examples
--------
>>> import dpnp as np
>>> np.isdtype(np.float32, np.float64)
False
>>> np.isdtype(np.float32, "real floating")
True
>>> np.isdtype(np.complex128, ("real floating", "complex floating"))
True
"""

if isinstance(dtype, type):
dtype = dpt.dtype(dtype)

if isinstance(kind, type):
kind = dpt.dtype(kind)
elif isinstance(kind, tuple):
kind = tuple(dpt.dtype(k) if isinstance(k, type) else k for k in kind)

return dpt.isdtype(dtype, kind)


def issubdtype(arg1, arg2):
"""
Returns ``True`` if the first argument is a type code lower/equal
Expand Down
65 changes: 65 additions & 0 deletions dpnp/tests/test_dtype_routines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy
import pytest
from numpy.testing import assert_raises_regex

import dpnp

from .helper import numpy_version

if numpy_version() >= "2.0.0":
from numpy._core.numerictypes import sctypes
else:
from numpy.core.numerictypes import sctypes


class TestIsDType:
dtype_group = {
"signed integer": sctypes["int"],
"unsigned integer": sctypes["uint"],
"integral": sctypes["int"] + sctypes["uint"],
"real floating": sctypes["float"],
"complex floating": sctypes["complex"],
"numeric": (
sctypes["int"]
+ sctypes["uint"]
+ sctypes["float"]
+ sctypes["complex"]
),
}

@pytest.mark.parametrize(
"dt, close_dt",
[
# TODO: replace with (dpnp.uint64, dpnp.uint32) once available
(dpnp.int64, dpnp.int32),
(numpy.uint64, numpy.uint32),
(dpnp.float64, dpnp.float32),
(dpnp.complex128, dpnp.complex64),
],
)
@pytest.mark.parametrize("dt_group", [None] + list(dtype_group.keys()))
def test_basic(self, dt, close_dt, dt_group):
# First check if same dtypes return "True" and different ones
# give "False" (even if they're close in the dtype hierarchy).
if dt_group is None:
assert dpnp.isdtype(dt, dt)
assert not dpnp.isdtype(dt, close_dt)
assert dpnp.isdtype(dt, (dt, close_dt))

# Check that dtype and a dtype group that it belongs to return "True",
# and "False" otherwise.
elif dt in self.dtype_group[dt_group]:
assert dpnp.isdtype(dt, dt_group)
assert dpnp.isdtype(dt, (close_dt, dt_group))
else:
assert not dpnp.isdtype(dt, dt_group)

def test_invalid_args(self):
with assert_raises_regex(TypeError, r"Expected instance of.*"):
dpnp.isdtype("int64", dpnp.int64)

with assert_raises_regex(TypeError, r"Unsupported data type kind:.*"):
dpnp.isdtype(dpnp.int64, 1)

with assert_raises_regex(ValueError, r"Unrecognized data type kind:.*"):
dpnp.isdtype(dpnp.int64, "int64")

0 comments on commit d522480

Please sign in to comment.