From d2b859d72fb224d890f214d691735fa326ef0b5f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 22 Oct 2024 15:27:06 +0200 Subject: [PATCH] (feat): support ellipsis indexing --- src/anndata/_core/index.py | 9 ++++++--- src/anndata/compat/__init__.py | 3 ++- tests/test_views.py | 8 ++++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 78f446d18..37d6987f1 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Sequence from functools import singledispatch from itertools import repeat +from types import EllipsisType from typing import TYPE_CHECKING import h5py @@ -47,7 +48,8 @@ def _normalize_index( | str | Sequence[bool | int | np.integer] | np.ndarray - | pd.Index, + | pd.Index + | EllipsisType, index: pd.Index, ) -> slice | int | np.ndarray: # ndarray of int or bool if not isinstance(index, pd.RangeIndex): @@ -72,6 +74,8 @@ def name_idx(i): return slice(start, stop, step) elif isinstance(indexer, np.integer | int): return indexer + elif isinstance(indexer, EllipsisType): + return slice(None) elif isinstance(indexer, str): return index.get_loc(indexer) # int elif isinstance( @@ -107,8 +111,7 @@ def name_idx(i): "are not valid obs/ var names or indices." ) return positions # np.ndarray[int] - else: - raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}") + raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}") def _fix_slice_bounds(s: slice, length: int) -> slice: diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 93c86141a..1f6d06b5a 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -10,6 +10,7 @@ from importlib.util import find_spec from inspect import Parameter, signature from pathlib import Path +from types import EllipsisType from typing import TYPE_CHECKING, TypeVar from warnings import warn @@ -46,7 +47,7 @@ class Empty: pass -Index1D = slice | int | str | np.int64 | np.ndarray +Index1D = slice | int | str | np.int64 | np.ndarray | EllipsisType Index = Index1D | tuple[Index1D, Index1D] | scipy.sparse.spmatrix | SpArray H5Group = h5py.Group H5Array = h5py.Dataset diff --git a/tests/test_views.py b/tests/test_views.py index 4e4f4ab75..73a9cf585 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -786,6 +786,14 @@ def test_dataframe_view_index_setting(): assert a2.obs.index.values.tolist() == ["a", "b"] +def test_ellipsis_index(adata, subset_func, matrix_type): + adata = gen_adata((10, 10), X_type=matrix_type, **GEN_ADATA_DASK_ARGS) + subset_obs_names = subset_func(adata.obs_names) + subset_ellipsis = adata[subset_obs_names, ...] + subset = adata[subset_obs_names, :] + assert_equal(subset_ellipsis, subset) + + # @pytest.mark.parametrize("dim", ["obs", "var"]) # @pytest.mark.parametrize( # ("idx", "pat"),