Skip to content

Commit

Permalink
Backport PR scverse#3268: Split up PCA tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored and meeseeksmachine committed Sep 30, 2024
1 parent ae926b8 commit c8efab2
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from contextlib import nullcontext
from functools import wraps
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -190,19 +191,27 @@ def test_pca_warnings_sparse():


def test_pca_transform(array_type):
A = array_type(A_list).astype("float32")
adata = AnnData(array_type(A_list).astype("float32"))
A_pca_abs = np.abs(A_pca)
A_svd_abs = np.abs(A_svd)

adata = AnnData(A)

with warnings.catch_warnings(record=True) as record:
sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64")
assert len(record) == 0, record
warnings.filterwarnings("error")
sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64")

assert np.linalg.norm(A_pca_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05

with warnings.catch_warnings(record=True) as record:

def test_pca_transform_randomized(array_type):
adata = AnnData(array_type(A_list).astype("float32"))
A_pca_abs = np.abs(A_pca)

warnings.filterwarnings("error")
with (
pytest.warns(
UserWarning, match="svd_solver 'randomized' does not work with sparse input"
)
if sparse.issparse(adata.X)
else nullcontext()
):
sc.pp.pca(
adata,
n_comps=5,
Expand All @@ -211,21 +220,16 @@ def test_pca_transform(array_type):
dtype="float64",
random_state=14,
)
if sparse.issparse(A):
assert any(
isinstance(r.message, UserWarning)
and "svd_solver 'randomized' does not work with sparse input"
in str(r.message)
for r in record
)
else:
assert len(record) == 0

assert np.linalg.norm(A_pca_abs - np.abs(adata.obsm["X_pca"])) < 2e-05

with warnings.catch_warnings(record=True) as record:
sc.pp.pca(adata, n_comps=4, zero_center=False, dtype="float64", random_state=14)
assert len(record) == 0

def test_pca_transform_no_zero_center(array_type):
adata = AnnData(array_type(A_list).astype("float32"))
A_svd_abs = np.abs(A_svd)

warnings.filterwarnings("error")
sc.pp.pca(adata, n_comps=4, zero_center=False, dtype="float64", random_state=14)

assert np.linalg.norm(A_svd_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05

Expand Down

0 comments on commit c8efab2

Please sign in to comment.