Skip to content

Commit

Permalink
Merge branch 'main' into pca-dask-sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Sep 30, 2024
2 parents 18e7bab + 2e208a3 commit 90a196f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 36 deletions.
20 changes: 20 additions & 0 deletions src/testing/scanpy/_helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from __future__ import annotations

import warnings
from contextlib import AbstractContextManager
from dataclasses import dataclass
from itertools import permutations
from typing import TYPE_CHECKING

Expand All @@ -14,6 +16,8 @@
import scanpy as sc

if TYPE_CHECKING:
from collections.abc import MutableSequence

from scanpy._compat import DaskArray

# TODO: Report more context on the fields being compared on error
Expand Down Expand Up @@ -138,3 +142,19 @@ def as_sparse_dask_array(*args, **kwargs) -> DaskArray:
from anndata.tests.helpers import as_sparse_dask_array

return as_sparse_dask_array(*args, **kwargs)


@dataclass(init=False)
class MultiContext(AbstractContextManager):
    """Enter several context managers as if they were a single one.

    The managers are entered in the order given and exited in reverse
    order, like nested ``with`` statements. If one of them suppresses an
    exception on exit, the exception is suppressed for the whole block.
    If entering one of them fails, those already entered are exited again
    before the error propagates.
    """

    # The wrapped context managers, in entering order.
    contexts: MutableSequence[AbstractContextManager]

    def __init__(self, *contexts: AbstractContextManager):
        self.contexts = list(contexts)

    def __enter__(self):
        # Local import keeps this block independent of the module's imports.
        from contextlib import ExitStack

        # Entering inside a temporary stack means a failure part-way through
        # unwinds every manager that was already entered.
        with ExitStack() as stack:
            for ctx in self.contexts:
                stack.enter_context(ctx)
            # Everything entered: take ownership so nothing is closed here.
            self._stack = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The stack exits in reverse order, threads exceptions raised by an
        # inner exit through the outer ones, and reports suppression.
        return self._stack.__exit__(exc_type, exc_value, traceback)
76 changes: 40 additions & 36 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from contextlib import nullcontext
from functools import wraps
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -187,21 +188,41 @@ def test_pca_warnings_sparse():


def test_pca_transform(array_type):
    """Check zero-centered PCA against the reference embedding ``A_pca``.

    Any warning emitted by ``sc.pp.pca`` fails the test. Only the first four
    components are compared, and the comparison uses absolute values, so
    sign differences are ignored.
    """
    adata = AnnData(array_type(A_list).astype("float32"))
    A_pca_abs = np.abs(A_pca)

    # Turn warnings into errors for this call only; the previous filters are
    # restored when the block exits.
    with warnings.catch_warnings():
        warnings.filterwarnings("error")
        sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64")

    # Materialize a Dask result before the NumPy comparison.
    if isinstance(adata.obsm["X_pca"], DaskArray):
        adata.obsm["X_pca"] = adata.obsm["X_pca"].compute()
    assert np.linalg.norm(A_pca_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05

def test_pca_transform_randomized(array_type):
adata = AnnData(array_type(A_list).astype("float32"))
A_pca_abs = np.abs(A_pca)

warnings.filterwarnings("error")
if isinstance(adata.X, DaskArray) and issparse(adata.X._meta):
ctx = _helpers.MultiContext(
pytest.warns(
UserWarning,
match=r"random_state is ignored when using a sparse dask array",
),
pytest.warns(
UserWarning,
match=r"svd_solver is ignored when using a sparse dask array",
),
)
elif sparse.issparse(adata.X):
ctx = pytest.warns(
UserWarning, match="svd_solver 'randomized' does not work with sparse input"
)
else:
ctx = nullcontext()

with ctx:
sc.pp.pca(
adata,
n_comps=5,
Expand All @@ -210,38 +231,21 @@ def test_pca_transform(array_type):
dtype="float64",
random_state=14,
)
if sparse.issparse(A):
assert any(
isinstance(r.message, UserWarning)
and "svd_solver 'randomized' does not work with sparse input"
in str(r.message)
for r in record
)
elif isinstance(A, DaskArray) and issparse(A._meta):
assert any(
isinstance(r.message, UserWarning)
and str(r.message)
== "random_state is ignored when using a sparse dask array"
for r in record
)
assert any(
isinstance(r.message, UserWarning)
and str(r.message) == "svd_solver is ignored when using a sparse dask array"
for r in record
)
else:
assert len(record) == 0, [r.message for r in record]

assert np.linalg.norm(A_pca_abs - np.abs(adata.obsm["X_pca"])) < 2e-05

if not (isinstance(A, DaskArray) and issparse(A._meta)):
with warnings.catch_warnings(record=True) as record:
sc.pp.pca(
adata, n_comps=4, zero_center=False, dtype="float64", random_state=14
)
assert len(record) == 0, [r.message for r in record]

assert np.linalg.norm(A_svd_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05
def test_pca_transform_no_zero_center(request: pytest.FixtureRequest, array_type):
    """Check PCA without zero-centering against the reference ``A_svd``.

    The test is marked as an expected failure for sparse Dask input. Any
    warning emitted by ``sc.pp.pca`` is turned into an error. Only the first
    four components are compared, using absolute values.
    """
    adata = AnnData(array_type(A_list).astype("float32"))
    expected = np.abs(A_svd)

    is_sparse_dask = isinstance(adata.X, DaskArray) and issparse(adata.X._meta)
    if is_sparse_dask:
        request.applymarker(
            pytest.mark.xfail(
                reason="TruncatedSVD is not supported for sparse Dask yet"
            )
        )

    warnings.filterwarnings("error")
    sc.pp.pca(
        adata,
        n_comps=4,
        zero_center=False,
        dtype="float64",
        random_state=14,
    )

    observed = np.abs(adata.obsm["X_pca"])
    deviation = np.linalg.norm(expected[:, :4] - observed)
    assert deviation < 2e-05


def test_pca_shapes():
Expand Down

0 comments on commit 90a196f

Please sign in to comment.