-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
dpnp.isdtype
implementation (#2274)
The PR proposes to implement `dpnp.isdtype` function. The function is mandated according to python array API. The corresponding muted tests are enabled in python array API compliance scope.
- Loading branch information
1 parent
451c2b3
commit d522480
Showing
4 changed files
with
121 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy | ||
import pytest | ||
from numpy.testing import assert_raises_regex | ||
|
||
import dpnp | ||
|
||
from .helper import numpy_version | ||
|
||
if numpy_version() >= "2.0.0": | ||
from numpy._core.numerictypes import sctypes | ||
else: | ||
from numpy.core.numerictypes import sctypes | ||
|
||
|
||
class TestIsDType: | ||
dtype_group = { | ||
"signed integer": sctypes["int"], | ||
"unsigned integer": sctypes["uint"], | ||
"integral": sctypes["int"] + sctypes["uint"], | ||
"real floating": sctypes["float"], | ||
"complex floating": sctypes["complex"], | ||
"numeric": ( | ||
sctypes["int"] | ||
+ sctypes["uint"] | ||
+ sctypes["float"] | ||
+ sctypes["complex"] | ||
), | ||
} | ||
|
||
@pytest.mark.parametrize( | ||
"dt, close_dt", | ||
[ | ||
# TODO: replace with (dpnp.uint64, dpnp.uint32) once available | ||
(dpnp.int64, dpnp.int32), | ||
(numpy.uint64, numpy.uint32), | ||
(dpnp.float64, dpnp.float32), | ||
(dpnp.complex128, dpnp.complex64), | ||
], | ||
) | ||
@pytest.mark.parametrize("dt_group", [None] + list(dtype_group.keys())) | ||
def test_basic(self, dt, close_dt, dt_group): | ||
# First check if same dtypes return "True" and different ones | ||
# give "False" (even if they're close in the dtype hierarchy). | ||
if dt_group is None: | ||
assert dpnp.isdtype(dt, dt) | ||
assert not dpnp.isdtype(dt, close_dt) | ||
assert dpnp.isdtype(dt, (dt, close_dt)) | ||
|
||
# Check that dtype and a dtype group that it belongs to return "True", | ||
# and "False" otherwise. | ||
elif dt in self.dtype_group[dt_group]: | ||
assert dpnp.isdtype(dt, dt_group) | ||
assert dpnp.isdtype(dt, (close_dt, dt_group)) | ||
else: | ||
assert not dpnp.isdtype(dt, dt_group) | ||
|
||
def test_invalid_args(self): | ||
with assert_raises_regex(TypeError, r"Expected instance of.*"): | ||
dpnp.isdtype("int64", dpnp.int64) | ||
|
||
with assert_raises_regex(TypeError, r"Unsupported data type kind:.*"): | ||
dpnp.isdtype(dpnp.int64, 1) | ||
|
||
with assert_raises_regex(ValueError, r"Unrecognized data type kind:.*"): | ||
dpnp.isdtype(dpnp.int64, "int64") |