diff --git a/anndata/_core/index.py b/anndata/_core/index.py index fe3645b04..f815b565e 100644 --- a/anndata/_core/index.py +++ b/anndata/_core/index.py @@ -123,12 +123,14 @@ def _subset(a: np.ndarray, subset_idx: Index): subset_idx = np.ix_(*subset_idx) return a[subset_idx] + @_subset.register(ZarrArray) def _subset_zarr(a: ZarrArray, subset_idx: Index): if all(isinstance(x, cabc.Iterable) for x in subset_idx): subset_idx = np.ix_(*subset_idx) return a.oindex[subset_idx] + @_subset.register(DaskArray) def _subset_dask(a: DaskArray, subset_idx: Index): if isinstance(subset_idx, slice): diff --git a/anndata/experimental/read_backed/lazy_arrays.py b/anndata/experimental/read_backed/lazy_arrays.py index 931f99933..fe210b7ac 100644 --- a/anndata/experimental/read_backed/lazy_arrays.py +++ b/anndata/experimental/read_backed/lazy_arrays.py @@ -38,7 +38,7 @@ class LazyCategoricalArray(MaskedArrayMixIn): "_categories", "_categories_cache", "group", - "_drop_unused_cats" + "_drop_unused_cats", ) def __init__(self, codes, categories, attrs, _drop_unused_cats, *args, **kwargs): @@ -54,7 +54,7 @@ def __init__(self, codes, categories, attrs, _drop_unused_cats, *args, **kwargs) self._categories = categories self._categories_cache = None self.attrs = dict(attrs) - self._drop_unused_cats = _drop_unused_cats # obsm/varm do not drop, but obs and var do. TODO: Should fix in normal AnnData? + self._drop_unused_cats = _drop_unused_cats # obsm/varm do not drop, but obs and var do. TODO: Should fix in normal AnnData? @property def categories(self): # __slots__ and cached_property are incompatible diff --git a/anndata/experimental/read_backed/read_backed.py b/anndata/experimental/read_backed/read_backed.py index 4aeae17eb..78edab99d 100644 --- a/anndata/experimental/read_backed/read_backed.py +++ b/anndata/experimental/read_backed/read_backed.py @@ -223,7 +223,6 @@ def _normalize_indices(self, index: Optional[Index]) -> Tuple[slice, slice]: ) def to_memory(self, exclude=[]): - # nullable and categoricals need special handling because xarray will convert them to numpy arrays first with dtype object def get_nullable_and_categorical_cols(ds): cols = [] @@ -252,7 +251,7 @@ def to_df(ds, exclude_vars=[]): if len(exclude_vars) == 0: df = df[list(ds.keys())] return df - + # handling for AxisArrays def backed_dict_to_memory(d, prefix): res = {} @@ -534,8 +533,12 @@ def callback(func, elem_name: str, elem, iospec): d_with_xr[k] = v return Dataset2D(d_with_xr) elif iospec.encoding_type == "categorical": - drop_unused_cats = not (elem_name.startswith('/obsm') or elem_name.startswith('/varm')) - return LazyCategoricalArray(elem["codes"], elem["categories"], elem.attrs, drop_unused_cats) + drop_unused_cats = not ( + elem_name.startswith("/obsm") or elem_name.startswith("/varm") + ) + return LazyCategoricalArray( + elem["codes"], elem["categories"], elem.attrs, drop_unused_cats + ) elif "nullable" in iospec.encoding_type: return LazyMaskedArray( elem["values"], diff --git a/anndata/experimental/read_backed/xarray.py b/anndata/experimental/read_backed/xarray.py index 0f869759f..7fe936af3 100644 --- a/anndata/experimental/read_backed/xarray.py +++ b/anndata/experimental/read_backed/xarray.py @@ -2,23 +2,32 @@ from anndata._core.index import Index, _subset from anndata._core.views import as_view + def get_index_dim(ds): - assert len(ds.dims) == 1, f"xarray Dataset should not have more than 1 dims, found {len(ds)}" + assert ( + len(ds.dims) == 1 + ), f"xarray Dataset should not have more than 1 dims, found {len(ds)}" return list(ds.dims.keys())[0] -class Dataset2D(xr.Dataset): +class Dataset2D(xr.Dataset): @property - def shape(self): # aligned mapping classes look for this for DataFrames so this ensures usability with e.g., obsm + def shape( + self, + ): # aligned mapping classes look for this for DataFrames so this ensures usability with e.g., obsm return [self.dims[get_index_dim(self)], len(self)] - + + @_subset.register(Dataset2D) def _(a: xr.DataArray, subset_idx: Index): key = get_index_dim(a) - if isinstance(subset_idx, tuple) and len(subset_idx) == 1: # xarray seems to have some code looking for a second entry in tuples - return a.isel(**{ key:subset_idx[0] }) - return a.isel(**{ key:subset_idx }) + if ( + isinstance(subset_idx, tuple) and len(subset_idx) == 1 + ): # xarray seems to have some code looking for a second entry in tuples + return a.isel(**{key: subset_idx[0]}) + return a.isel(**{key: subset_idx}) + @as_view.register(Dataset2D) def _(a: Dataset2D, view_args): - return a \ No newline at end of file + return a diff --git a/anndata/tests/test_read_backed_experimental.py b/anndata/tests/test_read_backed_experimental.py index 65859e910..39bcdc754 100644 --- a/anndata/tests/test_read_backed_experimental.py +++ b/anndata/tests/test_read_backed_experimental.py @@ -398,8 +398,8 @@ def test_nullable_boolean_array_subset_subset(nullable_boolean_lazy_arr): def test_nullable_boolean_array_no_mask_equality(nullable_boolean_lazy_arr_no_mask): - assert nullable_boolean_lazy_arr_no_mask[0] == True - assert (nullable_boolean_lazy_arr_no_mask[3:5] == False).all() + assert nullable_boolean_lazy_arr_no_mask[0] is True + assert (nullable_boolean_lazy_arr_no_mask[3:5] is False).all() assert (nullable_boolean_lazy_arr_no_mask[5:7] == np.array([True, False])).all()