From 2e208a34a0affe8e89a0e5c44984b318622914f4 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 30 Sep 2024 15:53:46 +0200 Subject: [PATCH] Split up PCA tests (#3268) --- tests/test_pca.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/test_pca.py b/tests/test_pca.py index bebb75299..7e49c49cf 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from contextlib import nullcontext from functools import wraps from typing import TYPE_CHECKING @@ -190,19 +191,27 @@ def test_pca_warnings_sparse(): def test_pca_transform(array_type): - A = array_type(A_list).astype("float32") + adata = AnnData(array_type(A_list).astype("float32")) A_pca_abs = np.abs(A_pca) - A_svd_abs = np.abs(A_svd) - - adata = AnnData(A) - with warnings.catch_warnings(record=True) as record: - sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64") - assert len(record) == 0, record + warnings.filterwarnings("error") + sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64") assert np.linalg.norm(A_pca_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05 - with warnings.catch_warnings(record=True) as record: + +def test_pca_transform_randomized(array_type): + adata = AnnData(array_type(A_list).astype("float32")) + A_pca_abs = np.abs(A_pca) + + warnings.filterwarnings("error") + with ( + pytest.warns( + UserWarning, match="svd_solver 'randomized' does not work with sparse input" + ) + if sparse.issparse(adata.X) + else nullcontext() + ): sc.pp.pca( adata, n_comps=5, @@ -211,21 +220,16 @@ def test_pca_transform(array_type): dtype="float64", random_state=14, ) - if sparse.issparse(A): - assert any( - isinstance(r.message, UserWarning) - and "svd_solver 'randomized' does not work with sparse input" - in str(r.message) - for r in record - ) - else: - assert len(record) == 0 assert np.linalg.norm(A_pca_abs - np.abs(adata.obsm["X_pca"])) < 2e-05 - with warnings.catch_warnings(record=True) as record: - sc.pp.pca(adata, n_comps=4, zero_center=False, dtype="float64", random_state=14) - assert len(record) == 0 + +def test_pca_transform_no_zero_center(array_type): + adata = AnnData(array_type(A_list).astype("float32")) + A_svd_abs = np.abs(A_svd) + + warnings.filterwarnings("error") + sc.pp.pca(adata, n_comps=4, zero_center=False, dtype="float64", random_state=14) assert np.linalg.norm(A_svd_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05