Skip to content

Commit

Permalink
(feat): support ellipsis indexing (#1729)
Browse files Browse the repository at this point in the history
* (feat): support ellipsis indexing

* (chore): release note

* (chore): add tests for all axis combinations

* (fix): index typing

* (feat): add support for 3d ellipsis/error on >1 ellipsis

* (fix): `unpack_index` type

* (fix): ignore ellipsis in docs

* (fix): import

* Update docs/conf.py

Co-authored-by: Philipp A. <[email protected]>

* (chore): `is` instead of `isinstance` checks

* (fix): try making `unpack_index` more general-purpose

* Apply suggestions from code review

Co-authored-by: Philipp A. <[email protected]>

* fix test_backed_ellipsis_indexing

* (refactor): simplify 3-elem index handling

* (refactor): simplify tests

* (fix): use correct base-comparison

* Update src/anndata/_core/index.py

Co-authored-by: Philipp A. <[email protected]>

* (fix): use `id` per param

* (chore): split ellipsis index

* fix types

---------

Co-authored-by: Philipp A. <[email protected]>
  • Loading branch information
ilan-gold and flying-sheep authored Oct 30, 2024
1 parent 437dbc8 commit 0024b82
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@
("py:obj", "numpy._typing._array_like._ScalarType_co"),
# https://github.com/sphinx-doc/sphinx/issues/10974
("py:class", "numpy.int64"),
# https://github.com/tox-dev/sphinx-autodoc-typehints/issues/498
("py:class", "types.EllipsisType"),
]


Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/1729.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for ellipsis indexing of the {class}`~anndata.AnnData` object {user}`ilan-gold`
3 changes: 2 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@
from os import PathLike
from typing import Any, Literal

from ..compat import Index1D
from ..typing import ArrayDataStructureType
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
from .index import Index, Index1D
from .index import Index


# for backwards compat
Expand Down
36 changes: 22 additions & 14 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,8 @@ def _normalize_indices(
if isinstance(index, pd.Series):
index: Index = index.values
if isinstance(index, tuple):
if len(index) > 2:
raise ValueError("AnnData can only be sliced in rows and columns.")
# deal with pd.Series
# TODO: The series should probably be aligned first
if isinstance(index[1], pd.Series):
index = index[0], index[1].values
if isinstance(index[0], pd.Series):
index = index[0].values, index[1]
index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
ax0, ax1 = unpack_index(index)
ax0 = _normalize_index(ax0, names0)
ax1 = _normalize_index(ax1, names1)
Expand Down Expand Up @@ -107,8 +101,7 @@ def name_idx(i):
"are not valid obs/ var names or indices."
)
return positions # np.ndarray[int]
else:
raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}")
raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}")


def _fix_slice_bounds(s: slice, length: int) -> slice:
Expand All @@ -132,13 +125,28 @@ def _fix_slice_bounds(s: slice, length: int) -> slice:

def unpack_index(index: Index) -> tuple[Index1D, Index1D]:
if not isinstance(index, tuple):
if index is Ellipsis:
index = slice(None)
return index, slice(None)
elif len(index) == 2:
num_ellipsis = sum(i is Ellipsis for i in index)
if num_ellipsis > 1:
raise IndexError("an index can only have a single ellipsis ('...')")
# If index has Ellipsis, filter it out (and if not, error)
if len(index) > 2:
if not num_ellipsis:
raise IndexError("Received a length 3 index without an ellipsis")
index = tuple(i for i in index if i is not Ellipsis)
return index
elif len(index) == 1:
return index[0], slice(None)
else:
raise IndexError("invalid number of indices")
# If index has Ellipsis, replace it with slice
if len(index) == 2:
index = tuple(slice(None) if i is Ellipsis else i for i in index)
return index
if len(index) == 1:
index = index[0]
if index is Ellipsis:
index = slice(None)
return index, slice(None)
raise IndexError("invalid number of indices")


@singledispatch
Expand Down
13 changes: 12 additions & 1 deletion src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from importlib.util import find_spec
from inspect import Parameter, signature
from pathlib import Path
from types import EllipsisType
from typing import TYPE_CHECKING, TypeVar
from warnings import warn

Expand Down Expand Up @@ -47,7 +48,17 @@ class Empty:


Index1D = slice | int | str | np.int64 | np.ndarray
Index = Index1D | tuple[Index1D, Index1D] | scipy.sparse.spmatrix | SpArray
IndexRest = Index1D | EllipsisType
Index = (
IndexRest
| tuple[Index1D, IndexRest]
| tuple[IndexRest, Index1D]
| tuple[Index1D, Index1D, EllipsisType]
| tuple[EllipsisType, Index1D, Index1D]
| tuple[Index1D, EllipsisType, Index1D]
| scipy.sparse.spmatrix
| SpArray
)
H5Group = h5py.Group
H5Array = h5py.Dataset
H5File = h5py.File
Expand Down
52 changes: 52 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import joblib
import pytest
Expand All @@ -10,12 +11,63 @@
import anndata as ad
from anndata.tests.helpers import subset_func # noqa: F401

if TYPE_CHECKING:
from types import EllipsisType


@pytest.fixture
def backing_h5ad(tmp_path):
return tmp_path / "test.h5ad"


@pytest.fixture(
params=[
pytest.param((..., (slice(None), slice(None))), id="ellipsis"),
pytest.param(((...,), (slice(None), slice(None))), id="ellipsis_tuple"),
pytest.param(
((..., slice(0, 10)), (slice(None), slice(0, 10))), id="obs-ellipsis"
),
pytest.param(
((slice(0, 10), ...), (slice(0, 10), slice(None))), id="var-ellipsis"
),
pytest.param(
((slice(0, 10), slice(0, 10), ...), (slice(0, 10), slice(0, 10))),
id="obs-var-ellipsis",
),
pytest.param(
((..., slice(0, 10), slice(0, 10)), (slice(0, 10), slice(0, 10))),
id="ellipsis-obs-var",
),
pytest.param(
((slice(0, 10), ..., slice(0, 10)), (slice(0, 10), slice(0, 10))),
id="obs-ellipsis-var",
),
]
)
def ellipsis_index_with_equivalent(
request,
) -> tuple[tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]]:
return request.param


@pytest.fixture
def ellipsis_index(
ellipsis_index_with_equivalent: tuple[
tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]
],
) -> tuple[EllipsisType | slice, ...] | EllipsisType:
return ellipsis_index_with_equivalent[0]


@pytest.fixture
def equivalent_ellipsis_index(
ellipsis_index_with_equivalent: tuple[
tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]
],
) -> tuple[slice, slice]:
return ellipsis_index_with_equivalent[1]


#####################
# Dask tokenization #
#####################
Expand Down
12 changes: 12 additions & 0 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Generator, Sequence
from pathlib import Path
from types import EllipsisType

from _pytest.mark import ParameterSet
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -127,6 +128,17 @@ def test_backed_indexing(
assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X)


def test_backed_ellipsis_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType,
equivalent_ellipsis_index: tuple[slice, slice],
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert_equal(csr_mem.X[equivalent_ellipsis_index], csr_disk.X[ellipsis_index])
assert_equal(csr_mem.X[equivalent_ellipsis_index], csc_disk.X[ellipsis_index])


def make_randomized_mask(size: int) -> np.ndarray:
randomized_mask = np.zeros(size, dtype=bool)
inds = np.random.choice(size, 20, replace=False)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import ExitStack
from copy import deepcopy
from operator import mul
from typing import TYPE_CHECKING

import joblib
import numpy as np
Expand Down Expand Up @@ -35,6 +36,9 @@
)
from anndata.utils import asarray

if TYPE_CHECKING:
from types import EllipsisType

IGNORE_SPARSE_EFFICIENCY_WARNING = pytest.mark.filterwarnings(
"ignore:Changing the sparsity structure:scipy.sparse.SparseEfficiencyWarning"
)
Expand Down Expand Up @@ -786,6 +790,30 @@ def test_dataframe_view_index_setting():
assert a2.obs.index.values.tolist() == ["a", "b"]


def test_ellipsis_index(
ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType,
equivalent_ellipsis_index: tuple[slice, slice],
matrix_type,
):
adata = gen_adata((10, 10), X_type=matrix_type, **GEN_ADATA_DASK_ARGS)
subset_ellipsis = adata[ellipsis_index]
subset = adata[equivalent_ellipsis_index]
assert_equal(subset_ellipsis, subset)


@pytest.mark.parametrize(
("index", "expected_error"),
[
((..., 0, ...), r"only have a single ellipsis"),
((0, 0, 0), r"Received a length 3 index"),
],
ids=["ellipsis-int-ellipsis", "int-int-int"],
)
def test_index_3d_errors(index: tuple[int | EllipsisType, ...], expected_error: str):
with pytest.raises(IndexError, match=expected_error):
gen_adata((10, 10))[index]


# @pytest.mark.parametrize("dim", ["obs", "var"])
# @pytest.mark.parametrize(
# ("idx", "pat"),
Expand Down

0 comments on commit 0024b82

Please sign in to comment.