Skip to content

Commit

Permalink
(chore): migrate to only checking cs{r,c}_matrix instead of spmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Nov 14, 2024
1 parent af6480e commit f2af154
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 66 deletions.
5 changes: 2 additions & 3 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@

import numpy as np
import pandas as pd
from scipy.sparse import spmatrix

from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
from ..compat import AwkArray
from ..compat import AwkArray, SpMatrix
from ..utils import (
axis_len,
convert_to_dict,
Expand All @@ -36,7 +35,7 @@
OneDIdx = Sequence[int] | Sequence[bool] | slice
TwoDIdx = tuple[OneDIdx, OneDIdx]
# TODO: pd.DataFrame only allowed in AxisArrays?
Value = pd.DataFrame | spmatrix | np.ndarray
Value = pd.DataFrame | SpMatrix | np.ndarray

P = TypeVar("P", bound="AlignedMappingBase")
"""Parent mapping an AlignedView is based on."""
Expand Down
8 changes: 4 additions & 4 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):

def __init__(
self,
X: np.ndarray | sparse.spmatrix | pd.DataFrame | None = None,
X: ArrayDataStructureType | pd.DataFrame | None = None,
obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
var: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
uns: Mapping[str, Any] | None = None,
obsm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
layers: Mapping[str, ArrayDataStructureType] | None = None,
raw: Mapping[str, Any] | None = None,
dtype: np.dtype | type | str | None = None,
shape: tuple[int, int] | None = None,
Expand Down Expand Up @@ -573,7 +573,7 @@ def X(self) -> ArrayDataStructureType | None:
# return X

@X.setter
def X(self, value: np.ndarray | sparse.spmatrix | SpArray | None):
def X(self, value: ArrayDataStructureType | None):
if value is None:
if self.isbacked:
raise NotImplementedError(
Expand Down Expand Up @@ -1169,7 +1169,7 @@ def _inplace_subset_obs(self, index: Index1D):
self._init_as_actual(adata_subset)

# TODO: Update, possibly remove
def __setitem__(self, index: Index, val: float | np.ndarray | sparse.spmatrix):
def __setitem__(self, index: Index, val: ArrayDataStructureType):
if self.is_view:
raise ValueError("Object is view and cannot be accessed with `[]`.")
obs, var = self._normalize_indices(index)
Expand Down
12 changes: 6 additions & 6 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import h5py
import numpy as np
import pandas as pd
from scipy.sparse import issparse, spmatrix
from scipy.sparse import issparse

from ..compat import AwkArray, DaskArray, SpArray
from ..compat import AwkArray, DaskArray, SpArray, SpMatrix

if TYPE_CHECKING:
from ..compat import Index, Index1D
Expand Down Expand Up @@ -69,13 +69,13 @@ def name_idx(i):
elif isinstance(indexer, str):
return index.get_loc(indexer) # int
elif isinstance(
indexer, Sequence | np.ndarray | pd.Index | spmatrix | np.matrix | SpArray
indexer, Sequence | np.ndarray | pd.Index | SpMatrix | np.matrix | SpArray
):
if hasattr(indexer, "shape") and (
(indexer.shape == (index.shape[0], 1))
or (indexer.shape == (1, index.shape[0]))
):
if isinstance(indexer, spmatrix | SpArray):
if isinstance(indexer, SpMatrix | SpArray):
indexer = indexer.toarray()
indexer = np.ravel(indexer)
if not isinstance(indexer, np.ndarray | pd.Index):
Expand Down Expand Up @@ -167,9 +167,9 @@ def _subset_dask(a: DaskArray, subset_idx: Index):
return a[subset_idx]


@_subset.register(spmatrix)
@_subset.register(SpMatrix)
@_subset.register(SpArray)
def _subset_sparse(a: spmatrix | SpArray, subset_idx: Index):
def _subset_sparse(a: SpMatrix | SpArray, subset_idx: Index):
# Correcting for indexing behaviour of sparse.spmatrix
if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx):
first_idx = subset_idx[0]
Expand Down
25 changes: 11 additions & 14 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import pandas as pd
from natsort import natsorted
from scipy import sparse
from scipy.sparse import spmatrix

from anndata._warnings import ExperimentalFeatureWarning

Expand All @@ -29,6 +28,7 @@
CupySparseMatrix,
DaskArray,
SpArray,
SpMatrix,
_map_cat_to_str,
)
from ..utils import asarray, axis_len, warn_once
Expand Down Expand Up @@ -135,7 +135,7 @@ def equal_dask_array(a, b) -> bool:
if isinstance(b, DaskArray):
if tokenize(a) == tokenize(b):
return True
if isinstance(a._meta, spmatrix):
if isinstance(a._meta, SpMatrix):
# TODO: Maybe also do this in the other case?
return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
else:
Expand Down Expand Up @@ -165,7 +165,7 @@ def equal_series(a, b) -> bool:
return a.equals(b)


@equal.register(sparse.spmatrix)
@equal.register(SpMatrix)
@equal.register(SpArray)
@equal.register(CupySparseMatrix)
def equal_sparse(a, b) -> bool:
Expand All @@ -174,7 +174,7 @@ def equal_sparse(a, b) -> bool:

xp = array_api_compat.array_namespace(a.data)

if isinstance(b, CupySparseMatrix | sparse.spmatrix | SpArray):
if isinstance(b, CupySparseMatrix | SpMatrix | SpArray):
if isinstance(a, CupySparseMatrix):
# Comparison broken for CSC matrices
# https://github.com/cupy/cupy/issues/7757
Expand Down Expand Up @@ -206,7 +206,7 @@ def equal_awkward(a, b) -> bool:


def as_sparse(x, use_sparse_array=False):
if not isinstance(x, sparse.spmatrix | SpArray):
if not isinstance(x, SpMatrix | SpArray):
if CAN_USE_SPARSE_ARRAY and use_sparse_array:
return sparse.csr_array(x)
return sparse.csr_matrix(x)
Expand Down Expand Up @@ -536,7 +536,7 @@ def apply(self, el, *, axis, fill_value=None):
return el
if isinstance(el, pd.DataFrame):
return self._apply_to_df(el, axis=axis, fill_value=fill_value)
elif isinstance(el, sparse.spmatrix | SpArray | CupySparseMatrix):
elif isinstance(el, SpMatrix | SpArray | CupySparseMatrix):
return self._apply_to_sparse(el, axis=axis, fill_value=fill_value)
elif isinstance(el, AwkArray):
return self._apply_to_awkward(el, axis=axis, fill_value=fill_value)
Expand Down Expand Up @@ -614,8 +614,8 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
)

def _apply_to_sparse(
self, el: sparse.spmatrix | SpArray, *, axis, fill_value=None
) -> spmatrix:
self, el: SpMatrix | SpArray, *, axis, fill_value=None
) -> SpMatrix:
if isinstance(el, CupySparseMatrix):
from cupyx.scipy import sparse
else:
Expand Down Expand Up @@ -724,11 +724,8 @@ def default_fill_value(els):
This is largely due to backwards compat, and might not be the ideal solution.
"""
if any(
isinstance(el, sparse.spmatrix | SpArray)
or (
isinstance(el, DaskArray)
and isinstance(el._meta, sparse.spmatrix | SpArray)
)
isinstance(el, SpMatrix | SpArray)
or (isinstance(el, DaskArray) and isinstance(el._meta, SpMatrix | SpArray))
for el in els
):
return 0
Expand Down Expand Up @@ -828,7 +825,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
],
axis=axis,
)
elif any(isinstance(a, sparse.spmatrix | SpArray) for a in arrays):
elif any(isinstance(a, SpMatrix | SpArray) for a in arrays):
sparse_stack = (sparse.vstack, sparse.hstack)[axis]
use_sparse_array = any(issubclass(type(a), SpArray) for a in arrays)
return sparse_stack(
Expand Down
7 changes: 3 additions & 4 deletions src/anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from collections.abc import Mapping, Sequence
from typing import ClassVar

from scipy import sparse

from ..compat import SpMatrix
from .aligned_mapping import AxisArraysView
from .anndata import AnnData
from .sparse_dataset import BaseCompressedSparseDataset
Expand All @@ -31,7 +30,7 @@ class Raw:
def __init__(
self,
adata: AnnData,
X: np.ndarray | sparse.spmatrix | None = None,
X: np.ndarray | SpMatrix | None = None,
var: pd.DataFrame | Mapping[str, Sequence] | None = None,
varm: AxisArrays | Mapping[str, np.ndarray] | None = None,
):
Expand Down Expand Up @@ -66,7 +65,7 @@ def _get_X(self, layer=None):
return self.X

@property
def X(self) -> BaseCompressedSparseDataset | np.ndarray | sparse.spmatrix:
def X(self) -> BaseCompressedSparseDataset | np.ndarray | SpMatrix:
# TODO: Handle unsorted array of integer indices for h5py.Datasets
if not self._adata.isbacked:
return self._X
Expand Down
6 changes: 3 additions & 3 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .. import abc
from .._settings import settings
from ..compat import H5Group, SpArray, ZarrArray, ZarrGroup, _read_attr
from ..compat import H5Group, SpArray, SpMatrix, ZarrArray, ZarrGroup, _read_attr
from .index import _fix_slice_bounds, _subset, unpack_index

if TYPE_CHECKING:
Expand Down Expand Up @@ -312,7 +312,7 @@ def get_memory_class(
if format == fmt:
if use_sparray_in_io and issubclass(memory_class, SpArray):
return memory_class
elif not use_sparray_in_io and issubclass(memory_class, ss.spmatrix):
elif not use_sparray_in_io and issubclass(memory_class, SpMatrix):
return memory_class
raise ValueError(f"Format string {format} is not supported.")

Expand All @@ -324,7 +324,7 @@ def get_backed_class(
if format == fmt:
if use_sparray_in_io and issubclass(backed_class, SpArray):
return backed_class
elif not use_sparray_in_io and issubclass(backed_class, ss.spmatrix):
elif not use_sparray_in_io and issubclass(backed_class, SpMatrix):
return backed_class
raise ValueError(f"Format string {format} is not supported.")

Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import numpy as np
import pandas as pd
from scipy import sparse

from .._warnings import ImplicitModificationWarning
from ..compat import SpMatrix
from ..utils import (
ensure_df_homogeneous,
join_english,
Expand Down Expand Up @@ -39,7 +39,7 @@ def coerce_array(
warnings.warn(msg, ImplicitModificationWarning)
value = value.A
return value
elif isinstance(value, sparse.spmatrix):
elif isinstance(value, SpMatrix):
msg = (
f"AnnData previously had undefined behavior around matrices of type {type(value)}."
"In 0.12, passing in this type will throw an error. Please convert to a supported type."
Expand Down
13 changes: 7 additions & 6 deletions src/anndata/_io/h5ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .._core.file_backing import filename
from .._core.sparse_dataset import BaseCompressedSparseDataset
from ..compat import (
SpMatrix,
_clean_uns,
_decode_structured_array,
_from_fixed_length_strings,
Expand Down Expand Up @@ -82,14 +83,14 @@ def write_h5ad(
f.attrs.setdefault("encoding-version", "0.1.0")

if "X" in as_dense and isinstance(
adata.X, sparse.spmatrix | BaseCompressedSparseDataset
adata.X, SpMatrix | BaseCompressedSparseDataset
):
write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
elif not (adata.isbacked and Path(adata.filename) == Path(filepath)):
# If adata.isbacked, X should already be up to date
write_elem(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
if "raw/X" in as_dense and isinstance(
adata.raw.X, sparse.spmatrix | BaseCompressedSparseDataset
adata.raw.X, SpMatrix | BaseCompressedSparseDataset
):
write_sparse_as_dense(
f, "raw/X", adata.raw.X, dataset_kwargs=dataset_kwargs
Expand All @@ -115,7 +116,7 @@ def write_h5ad(
def write_sparse_as_dense(
f: h5py.Group,
key: str,
value: sparse.spmatrix | BaseCompressedSparseDataset,
value: SpMatrix | BaseCompressedSparseDataset,
*,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
Expand Down Expand Up @@ -172,7 +173,7 @@ def read_h5ad(
backed: Literal["r", "r+"] | bool | None = None,
*,
as_sparse: Sequence[str] = (),
as_sparse_fmt: type[sparse.spmatrix] = sparse.csr_matrix,
as_sparse_fmt: type[SpMatrix] = sparse.csr_matrix,
chunk_size: int = 6000, # TODO, probably make this 2d chunks
) -> AnnData:
"""\
Expand Down Expand Up @@ -275,7 +276,7 @@ def callback(func, elem_name: str, elem, iospec):
def _read_raw(
f: h5py.File | AnnDataFileManager,
as_sparse: Collection[str] = (),
rdasp: Callable[[h5py.Dataset], sparse.spmatrix] | None = None,
rdasp: Callable[[h5py.Dataset], SpMatrix] | None = None,
*,
attrs: Collection[str] = ("X", "var", "varm"),
) -> dict:
Expand Down Expand Up @@ -348,7 +349,7 @@ def read_dataset(dataset: h5py.Dataset):

@report_read_key_on_error
def read_dense_as_sparse(
dataset: h5py.Dataset, sparse_format: sparse.spmatrix, axis_chunk: int
dataset: h5py.Dataset, sparse_format: SpMatrix, axis_chunk: int
):
if sparse_format == sparse.csr_matrix:
return read_dense_as_csr(dataset, axis_chunk)
Expand Down
15 changes: 8 additions & 7 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@
from numpy.typing import NDArray

from anndata._types import ArrayStorageType, GroupStorageType
from anndata.compat import SpArray
from anndata.compat import (
SpArray,
SpMatrix,
)
from anndata.typing import AxisStorable, InMemoryArrayOrScalarType

from .registry import Reader, Writer
Expand Down Expand Up @@ -127,7 +130,7 @@ def wrapper(
@_REGISTRY.register_read(H5Array, IOSpec("", ""))
def read_basic(
elem: H5File | H5Group | H5Array, *, _reader: Reader
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | sparse.spmatrix | SpArray:
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | SpMatrix | SpArray:
from anndata._io import h5ad

warn(
Expand All @@ -149,7 +152,7 @@ def read_basic(
@_REGISTRY.register_read(ZarrArray, IOSpec("", ""))
def read_basic_zarr(
elem: ZarrGroup | ZarrArray, *, _reader: Reader
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | sparse.spmatrix | SpArray:
) -> dict[str, InMemoryArrayOrScalarType] | npt.NDArray | SpMatrix | SpArray:
from anndata._io import zarr

warn(
Expand Down Expand Up @@ -588,7 +591,7 @@ def write_recarray_zarr(
def write_sparse_compressed(
f: GroupStorageType,
key: str,
value: sparse.spmatrix | SpArray,
value: SpMatrix | SpArray,
*,
_writer: Writer,
fmt: Literal["csr", "csc"],
Expand Down Expand Up @@ -755,9 +758,7 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
@_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
def read_sparse(
elem: GroupStorageType, *, _reader: Reader
) -> sparse.spmatrix | SpArray:
def read_sparse(elem: GroupStorageType, *, _reader: Reader) -> SpMatrix | SpArray:
return sparse_dataset(elem).to_memory()


Expand Down
Loading

0 comments on commit f2af154

Please sign in to comment.