From 6dc6ecd542a2bb817829f75415c50b2446be6d97 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 16:36:29 -0600 Subject: [PATCH 1/6] Fix a test helper It should have been using xp.real(x) instead of x.real. I'm not sure if float(x) or complex(x) would be better instead. --- array_api_tests/pytest_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 461532e7..c51b14a6 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -464,8 +464,8 @@ def assert_array_elements( f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) - _assert_float_element(at_out.real, at_expected.real, msg) - _assert_float_element(at_out.imag, at_expected.imag, msg) + _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) + _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: assert xp.all( out == expected From 0bceae1ae0e80443d66859b4e65df63c2a8526f9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 17:32:16 -0600 Subject: [PATCH 2/6] Update test_sum and test_prod to support complex dtypes --- array_api_tests/dtype_helpers.py | 5 ++--- array_api_tests/test_statistical_functions.py | 22 ++++++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 56a89f63..2847a8dc 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -157,7 +157,7 @@ def is_int_dtype(dtype): return dtype in all_int_dtypes -def is_float_dtype(dtype): +def is_float_dtype(dtype, real=False): # None equals NumPy's xp.float64 object, so we specifically check it here. # xp.float64 is in fact an alias of np.dtype('float64'), and its equality # with None is meant to be deprecated at some point. @@ -165,11 +165,10 @@ def is_float_dtype(dtype): if dtype is None: return False valid_dtypes = real_float_dtypes - if api_version > "2021.12": + if api_version > "2021.12" and not real: valid_dtypes += complex_dtypes return dtype in valid_dtypes - def get_scalar_type(dtype: DataType) -> ScalarType: if dtype in all_int_dtypes: return int diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 990ae5c7..a7306119 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -11,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . import xps +from . import xps, api_version from ._array_module import _UndefinedStub from .typing import DataType @@ -145,11 +145,19 @@ def test_prod(x, data): _dtype = x.dtype else: _dtype = default_dtype - else: + elif dh.is_float_dtype(x.dtype, real=True): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: _dtype = dh.default_float + elif api_version > "2021.12": + # Complex dtype + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: + _dtype = x.dtype + else: + _dtype = dh.default_complex + else: + raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") else: _dtype = dtype if _dtype is None: @@ -253,11 +261,19 @@ def test_sum(x, data): _dtype = x.dtype else: _dtype = default_dtype - else: + elif dh.is_float_dtype(x.dtype, real=True): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: _dtype = dh.default_float + elif api_version > "2021.12": + # Complex dtype + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: + _dtype = x.dtype + else: + _dtype = dh.default_complex + else: + raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") else: _dtype = dtype if _dtype is None: From 6fb104a9808eff73f15385056cd4dac7b72d5ff3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:13:47 -0600 Subject: [PATCH 3/6] Use portable indexing in test_take --- array_api_tests/test_indexing_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index ab837ca5..7f35e73a 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -49,7 +49,7 @@ def test_take(x, data): f_axis_idx = sh.fmt_idx("x", axis_idx) for i in _indices: f_take_idx = sh.fmt_idx(f_axis_idx, i) - indexed_x = x[axis_idx][i] + indexed_x = x[axis_idx][i, ...] for at_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) ph.assert_0d_equals( From 5738b59845019935a2cec27dae20e28909db346d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:16:00 -0600 Subject: [PATCH 4/6] Use cmath.isnan instead of math.isnan in test_set_functions.py --- array_api_tests/test_set_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 92d39739..e31b63cd 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -112,7 +112,7 @@ def test_unique_all(x): if dh.is_float_dtype(out.values.dtype): assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -137,7 +137,7 @@ def test_unique_counts(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -159,7 +159,7 @@ def test_unique_counts(x): vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @@ -188,7 +188,7 @@ def test_unique_inverse(x): nans = 0 for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert ( From 00c56c961b95ec690beab7eae9697937ca2b6993 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 19:01:01 -0600 Subject: [PATCH 5/6] Add basic check that the dtype attribute exists in test_iinfo and test_finfo --- array_api_tests/test_data_type_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 917b1f26..fa69bbcd 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -162,6 +162,7 @@ def test_finfo(dtype_name): assert isinstance( value, stype ), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}" + assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" # TODO: test values @@ -179,6 +180,7 @@ def test_iinfo(dtype_name): assert isinstance( value, int ), f"type(out.{attr})={type(value)!r}, but should be int {f_func}" + assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}" # TODO: test values From 9064d5d34fcd1c47f8655493c1965c4b1625279a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 16 Jun 2023 16:21:29 -0600 Subject: [PATCH 6/6] Use better keyword argument name --- array_api_tests/dtype_helpers.py | 4 ++-- array_api_tests/test_statistical_functions.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 2847a8dc..6e37724c 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -157,7 +157,7 @@ def is_int_dtype(dtype): return dtype in all_int_dtypes -def is_float_dtype(dtype, real=False): +def is_float_dtype(dtype, *, include_complex=True): # None equals NumPy's xp.float64 object, so we specifically check it here. # xp.float64 is in fact an alias of np.dtype('float64'), and its equality # with None is meant to be deprecated at some point. @@ -165,7 +165,7 @@ def is_float_dtype(dtype, real=False): if dtype is None: return False valid_dtypes = real_float_dtypes - if api_version > "2021.12" and not real: + if api_version > "2021.12" and include_complex: valid_dtypes += complex_dtypes return dtype in valid_dtypes diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index a7306119..b00284c9 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -145,7 +145,7 @@ def test_prod(x, data): _dtype = x.dtype else: _dtype = default_dtype - elif dh.is_float_dtype(x.dtype, real=True): + elif dh.is_float_dtype(x.dtype, include_complex=False): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: @@ -261,7 +261,7 @@ def test_sum(x, data): _dtype = x.dtype else: _dtype = default_dtype - elif dh.is_float_dtype(x.dtype, real=True): + elif dh.is_float_dtype(x.dtype, include_complex=False): if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype else: