diff --git a/docs/conf.py b/docs/conf.py index fc8a414ea..f98fe5ba7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -112,6 +112,8 @@ ("py:obj", "numpy._typing._array_like._ScalarType_co"), # https://github.com/sphinx-doc/sphinx/issues/10974 ("py:class", "numpy.int64"), + # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/498 + ("py:class", "types.EllipsisType"), ] diff --git a/docs/release-notes/1729.feature.md b/docs/release-notes/1729.feature.md new file mode 100644 index 000000000..a7f55361b --- /dev/null +++ b/docs/release-notes/1729.feature.md @@ -0,0 +1 @@ +Add support for ellipsis indexing of the {class}`~anndata.AnnData` object {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index e972ba23c..7ef9f8ac4 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -52,9 +52,10 @@ from os import PathLike from typing import Any, Literal + from ..compat import Index1D from ..typing import ArrayDataStructureType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView - from .index import Index, Index1D + from .index import Index # for backwards compat diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 78f446d18..f1d72ce0d 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -26,14 +26,8 @@ def _normalize_indices( if isinstance(index, pd.Series): index: Index = index.values if isinstance(index, tuple): - if len(index) > 2: - raise ValueError("AnnData can only be sliced in rows and columns.") - # deal with pd.Series # TODO: The series should probably be aligned first - if isinstance(index[1], pd.Series): - index = index[0], index[1].values - if isinstance(index[0], pd.Series): - index = index[0].values, index[1] + index = tuple(i.values if isinstance(i, pd.Series) else i for i in index) ax0, ax1 = unpack_index(index) ax0 = _normalize_index(ax0, names0) ax1 = _normalize_index(ax1, names1) @@ -107,8 +101,7 @@ def name_idx(i): "are not valid obs/ var names or indices." ) return positions # np.ndarray[int] - else: - raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}") + raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}") def _fix_slice_bounds(s: slice, length: int) -> slice: @@ -132,13 +125,28 @@ def _fix_slice_bounds(s: slice, length: int) -> slice: def unpack_index(index: Index) -> tuple[Index1D, Index1D]: if not isinstance(index, tuple): + if index is Ellipsis: + index = slice(None) return index, slice(None) - elif len(index) == 2: + num_ellipsis = sum(i is Ellipsis for i in index) + if num_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # If index has Ellipsis, filter it out (and if not, error) + if len(index) > 2: + if not num_ellipsis: + raise IndexError("Received a length 3 index without an ellipsis") + index = tuple(i for i in index if i is not Ellipsis) return index - elif len(index) == 1: - return index[0], slice(None) - else: - raise IndexError("invalid number of indices") + # If index has Ellipsis, replace it with slice + if len(index) == 2: + index = tuple(slice(None) if i is Ellipsis else i for i in index) + return index + if len(index) == 1: + index = index[0] + if index is Ellipsis: + index = slice(None) + return index, slice(None) + raise IndexError("invalid number of indices") @singledispatch diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 93c86141a..b5ec3d415 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -10,6 +10,7 @@ from importlib.util import find_spec from inspect import Parameter, signature from pathlib import Path +from types import EllipsisType from typing import TYPE_CHECKING, TypeVar from warnings import warn @@ -47,7 +48,17 @@ class Empty: Index1D = slice | int | str | np.int64 | np.ndarray -Index = Index1D | tuple[Index1D, Index1D] | scipy.sparse.spmatrix | SpArray +IndexRest = Index1D | EllipsisType +Index = ( + IndexRest + | tuple[Index1D, IndexRest] + | tuple[IndexRest, Index1D] + | tuple[Index1D, Index1D, EllipsisType] + | tuple[EllipsisType, Index1D, Index1D] + | tuple[Index1D, EllipsisType, Index1D] + | scipy.sparse.spmatrix + | SpArray +) H5Group = h5py.Group H5Array = h5py.Dataset H5File = h5py.File diff --git a/tests/conftest.py b/tests/conftest.py index 65eff92b1..13fabdb93 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import partial +from typing import TYPE_CHECKING import joblib import pytest @@ -10,12 +11,63 @@ import anndata as ad from anndata.tests.helpers import subset_func # noqa: F401 +if TYPE_CHECKING: + from types import EllipsisType + @pytest.fixture def backing_h5ad(tmp_path): return tmp_path / "test.h5ad" +@pytest.fixture( + params=[ + pytest.param((..., (slice(None), slice(None))), id="ellipsis"), + pytest.param(((...,), (slice(None), slice(None))), id="ellipsis_tuple"), + pytest.param( + ((..., slice(0, 10)), (slice(None), slice(0, 10))), id="obs-ellipsis" + ), + pytest.param( + ((slice(0, 10), ...), (slice(0, 10), slice(None))), id="var-ellipsis" + ), + pytest.param( + ((slice(0, 10), slice(0, 10), ...), (slice(0, 10), slice(0, 10))), + id="obs-var-ellipsis", + ), + pytest.param( + ((..., slice(0, 10), slice(0, 10)), (slice(0, 10), slice(0, 10))), + id="ellipsis-obs-var", + ), + pytest.param( + ((slice(0, 10), ..., slice(0, 10)), (slice(0, 10), slice(0, 10))), + id="obs-ellipsis-var", + ), + ] +) +def ellipsis_index_with_equivalent( + request, +) -> tuple[tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]]: + return request.param + + +@pytest.fixture +def ellipsis_index( + ellipsis_index_with_equivalent: tuple[ + tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice] + ], +) -> tuple[EllipsisType | slice, ...] | EllipsisType: + return ellipsis_index_with_equivalent[0] + + +@pytest.fixture +def equivalent_ellipsis_index( + ellipsis_index_with_equivalent: tuple[ + tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice] + ], +) -> tuple[slice, slice]: + return ellipsis_index_with_equivalent[1] + + ##################### # Dask tokenization # ##################### diff --git a/tests/test_backed_sparse.py b/tests/test_backed_sparse.py index 5f78ddb52..2778c76bb 100644 --- a/tests/test_backed_sparse.py +++ b/tests/test_backed_sparse.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence from pathlib import Path + from types import EllipsisType from _pytest.mark import ParameterSet from numpy.typing import ArrayLike, NDArray @@ -127,6 +128,17 @@ def test_backed_indexing( assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X) +def test_backed_ellipsis_indexing( + ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData], + ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType, + equivalent_ellipsis_index: tuple[slice, slice], +): + csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata + + assert_equal(csr_mem.X[equivalent_ellipsis_index], csr_disk.X[ellipsis_index]) + assert_equal(csr_mem.X[equivalent_ellipsis_index], csc_disk.X[ellipsis_index]) + + def make_randomized_mask(size: int) -> np.ndarray: randomized_mask = np.zeros(size, dtype=bool) inds = np.random.choice(size, 20, replace=False) diff --git a/tests/test_views.py b/tests/test_views.py index 4e4f4ab75..6e57e08c7 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,6 +3,7 @@ from contextlib import ExitStack from copy import deepcopy from operator import mul +from typing import TYPE_CHECKING import joblib import numpy as np @@ -35,6 +36,9 @@ ) from anndata.utils import asarray +if TYPE_CHECKING: + from types import EllipsisType + IGNORE_SPARSE_EFFICIENCY_WARNING = pytest.mark.filterwarnings( "ignore:Changing the sparsity structure:scipy.sparse.SparseEfficiencyWarning" ) @@ -786,6 +790,30 @@ def test_dataframe_view_index_setting(): assert a2.obs.index.values.tolist() == ["a", "b"] +def test_ellipsis_index( + ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType, + equivalent_ellipsis_index: tuple[slice, slice], + matrix_type, +): + adata = gen_adata((10, 10), X_type=matrix_type, **GEN_ADATA_DASK_ARGS) + subset_ellipsis = adata[ellipsis_index] + subset = adata[equivalent_ellipsis_index] + assert_equal(subset_ellipsis, subset) + + +@pytest.mark.parametrize( + ("index", "expected_error"), + [ + ((..., 0, ...), r"only have a single ellipsis"), + ((0, 0, 0), r"Received a length 3 index"), + ], + ids=["ellipsis-int-ellipsis", "int-int-int"], +) +def test_index_3d_errors(index: tuple[int | EllipsisType, ...], expected_error: str): + with pytest.raises(IndexError, match=expected_error): + gen_adata((10, 10))[index] + + # @pytest.mark.parametrize("dim", ["obs", "var"]) # @pytest.mark.parametrize( # ("idx", "pat"),