From 150d8aab2e798ffa64a7843738c0e51cec7068c2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 11 Nov 2024 10:18:10 +0100 Subject: [PATCH] (fix): test for float32 and float64 --- src/anndata/_core/index.py | 4 +++- tests/test_views.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 9186a19db..53434186a 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -82,7 +82,9 @@ def name_idx(i): indexer = np.array(indexer) if len(indexer) == 0: indexer = indexer.astype(int) - if isinstance(indexer, np.ndarray) and indexer.dtype == float: + if isinstance(indexer, np.ndarray) and np.issubdtype( + indexer.dtype, np.floating + ): indexer_int = indexer.astype(int) if np.all((indexer - indexer_int) != 0): raise IndexError(f"Indexer {indexer!r} has floating point values.") diff --git a/tests/test_views.py b/tests/test_views.py index fe6b66bea..fb6794dfd 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -818,8 +818,11 @@ def test_index_3d_errors(index: tuple[int | EllipsisType, ...], expected_error: "index", [ pytest.param(sparse.csr_matrix(np.random.random((1, 10))), id="sparse"), - pytest.param(np.array([1.2, 2.3]), id="ndarray"), pytest.param([1.2, 3.4], id="list"), + *( + pytest.param(np.array([1.2, 2.3], dtype=dtype), id=f"ndarray-{dtype}") + for dtype in [np.float32, np.float64] + ), ], ) def test_index_float_sequence_raises_error(index):