Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dpnp.isdtype implementation #2274

Merged
merged 4 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
# no 'uint8' dtype
array_api_tests/test_array_object.py::test_getitem_masking

# no 'isdtype' function
array_api_tests/test_data_type_functions.py::test_isdtype
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
array_api_tests/test_signatures.py::test_func_signature[isdtype]

# missing unique-like functions
array_api_tests/test_has_names.py::test_has_names[set-unique_all]
array_api_tests/test_has_names.py::test_has_names[set-unique_counts]
Expand Down
1 change: 0 additions & 1 deletion doc/reference/dtype.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Data type routines
dpnp.min_scalar_type
dpnp.result_type
dpnp.common_type
dpnp.obj2sctype

Creating data types
-------------------
Expand Down
56 changes: 56 additions & 0 deletions dpnp/dpnp_iface_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"integer",
"intc",
"intp",
"isdtype",
"issubdtype",
"is_type_supported",
"nan",
Expand Down Expand Up @@ -194,11 +195,66 @@ def iinfo(dtype):
smallest representable number.

"""

if isinstance(dtype, dpnp_array):
dtype = dtype.dtype
return dpt.iinfo(dtype)


def isdtype(dtype, kind):
"""
Returns a boolean indicating whether a provided `dtype` is
of a specified data type `kind`.

Parameters
----------
dtype : dtype
The input dtype.
kind : {dtype, str, tuple of dtypes or strs}
The input dtype or dtype kind. Allowed dtype kinds are:

* ``'bool'`` : boolean kind
* ``'signed integer'`` : signed integer data types
* ``'unsigned integer'`` : unsigned integer data types
* ``'integral'`` : integer data types
* ``'real floating'`` : real-valued floating-point data types
* ``'complex floating'`` : complex floating-point data types
* ``'numeric'`` : numeric data types

Returns
-------
out : bool
A boolean indicating whether a provided `dtype` is of a specified data
type `kind`.

See Also
--------
:obj:`dpnp.issubdtype` : Test if the first argument is a type code
lower/equal in type hierarchy.

Examples
--------
>>> import dpnp as np
>>> np.isdtype(np.float32, np.float64)
False
>>> np.isdtype(np.float32, "real floating")
True
>>> np.isdtype(np.complex128, ("real floating", "complex floating"))
True

"""

if isinstance(dtype, type):
dtype = dpt.dtype(dtype)

if isinstance(kind, type):
kind = dpt.dtype(kind)
elif isinstance(kind, tuple):
kind = tuple(dpt.dtype(k) if isinstance(k, type) else k for k in kind)

return dpt.isdtype(dtype, kind)


def issubdtype(arg1, arg2):
"""
Returns ``True`` if the first argument is a type code lower/equal
Expand Down
65 changes: 65 additions & 0 deletions dpnp/tests/test_dtype_routines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy
import pytest
from numpy.testing import assert_raises_regex

import dpnp

from .helper import numpy_version

if numpy_version() >= "2.0.0":
from numpy._core.numerictypes import sctypes
else:
from numpy.core.numerictypes import sctypes


class TestIsDType:
dtype_group = {
"signed integer": sctypes["int"],
"unsigned integer": sctypes["uint"],
"integral": sctypes["int"] + sctypes["uint"],
"real floating": sctypes["float"],
"complex floating": sctypes["complex"],
"numeric": (
sctypes["int"]
+ sctypes["uint"]
+ sctypes["float"]
+ sctypes["complex"]
),
}

@pytest.mark.parametrize(
"dt, close_dt",
[
# TODO: replace with (dpnp.uint64, dpnp.uint32) once available
(dpnp.int64, dpnp.int32),
(numpy.uint64, numpy.uint32),
(dpnp.float64, dpnp.float32),
(dpnp.complex128, dpnp.complex64),
],
)
@pytest.mark.parametrize("dt_group", [None] + list(dtype_group.keys()))
def test_basic(self, dt, close_dt, dt_group):
# First check if same dtypes return "True" and different ones
# give "False" (even if they're close in the dtype hierarchy).
if dt_group is None:
assert dpnp.isdtype(dt, dt)
assert not dpnp.isdtype(dt, close_dt)
assert dpnp.isdtype(dt, (dt, close_dt))

# Check that dtype and a dtype group that it belongs to return "True",
# and "False" otherwise.
elif dt in self.dtype_group[dt_group]:
assert dpnp.isdtype(dt, dt_group)
assert dpnp.isdtype(dt, (close_dt, dt_group))
else:
assert not dpnp.isdtype(dt, dt_group)

def test_invalid_args(self):
with assert_raises_regex(TypeError, r"Expected instance of.*"):
dpnp.isdtype("int64", dpnp.int64)

with assert_raises_regex(TypeError, r"Unsupported data type kind:.*"):
dpnp.isdtype(dpnp.int64, 1)

with assert_raises_regex(ValueError, r"Unrecognized data type kind:.*"):
dpnp.isdtype(dpnp.int64, "int64")
Loading