diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 56a89f63..6e37724c 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -157,7 +157,7 @@ def is_int_dtype(dtype): return dtype in all_int_dtypes -def is_float_dtype(dtype): +def is_float_dtype(dtype, *, include_complex=True): # None equals NumPy's xp.float64 object, so we specifically check it here. # xp.float64 is in fact an alias of np.dtype('float64'), and its equality # with None is meant to be deprecated at some point. @@ -165,11 +165,10 @@ def is_float_dtype(dtype): if dtype is None: return False valid_dtypes = real_float_dtypes - if api_version > "2021.12": + if api_version > "2021.12" and include_complex: valid_dtypes += complex_dtypes return dtype in valid_dtypes - def get_scalar_type(dtype: DataType) -> ScalarType: if dtype in all_int_dtypes: return int diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 461532e7..c51b14a6 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -464,8 +464,8 @@ def assert_array_elements( f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) - _assert_float_element(at_out.real, at_expected.real, msg) - _assert_float_element(at_out.imag, at_expected.imag, msg) + _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) + _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: assert xp.all( out == expected diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 917b1f26..fa69bbcd 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -162,6 +162,7 @@ def test_finfo(dtype_name): assert isinstance( value, stype ), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}" + assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" # TODO: test values @@ -179,6 +180,7 @@ def test_iinfo(dtype_name): assert isinstance( value, int ), f"type(out.{attr})={type(value)!r}, but should be int {f_func}" + assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" # TODO: test values diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index ab837ca5..7f35e73a 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -49,7 +49,7 @@ def test_take(x, data): f_axis_idx = sh.fmt_idx("x", axis_idx) for i in _indices: f_take_idx = sh.fmt_idx(f_axis_idx, i) - indexed_x = x[axis_idx][i] + indexed_x = x[axis_idx][i, ...] for at_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) ph.assert_0d_equals( diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 92d39739..e31b63cd 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -112,7 +112,7 @@ def test_unique_all(x): if dh.is_float_dtype(out.values.dtype): assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -137,7 +137,7 @@ def test_unique_counts(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -159,7 +159,7 @@ def test_unique_counts(x): vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -188,7 +188,7 @@ def test_unique_inverse(x): nans = 0 for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert ( diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 990ae5c7..b00284c9 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -11,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . import xps +from . import xps, api_version from ._array_module import _UndefinedStub from .typing import DataType @@ -145,11 +145,19 @@ def test_prod(x, data): _dtype = x.dtype else: _dtype = default_dtype - else: + elif dh.is_float_dtype(x.dtype, include_complex=False): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: _dtype = dh.default_float + elif api_version > "2021.12": + # Complex dtype + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: + _dtype = x.dtype + else: + _dtype = dh.default_complex + else: + raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") else: _dtype = dtype if _dtype is None: @@ -253,11 +261,19 @@ def test_sum(x, data): _dtype = x.dtype else: _dtype = default_dtype - else: + elif dh.is_float_dtype(x.dtype, include_complex=False): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: _dtype = dh.default_float + elif api_version > "2021.12": + # Complex dtype + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: + _dtype = x.dtype + else: + _dtype = dh.default_complex + else: + raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") else: _dtype = dtype if _dtype is None: