
Commit

fix tests
flying-sheep committed Sep 30, 2024
1 parent 1674951 commit 18e7bab
Showing 2 changed files with 25 additions and 12 deletions.
8 changes: 2 additions & 6 deletions src/scanpy/preprocessing/_pca/_dask_sparse.py
@@ -55,14 +55,10 @@ def fit(self, x: DaskArray) -> PCASparseFit:
         self.explained_variance_, self.components_ = scipy.linalg.eigh(
             covariance, lower=False
         )
-        # NOTE: We reverse the eigen vector and eigen values here
-        # because cupy provides them in ascending order. Make a copy otherwise
-        # it is not C_CONTIGUOUS anymore and would error when converting to
-        # CumlArray
-        self.explained_variance_ = self.explained_variance_[::-1]
+        # Arrange eigenvectors and eigenvalues in descending order
+        self.explained_variance_ = self.explained_variance_[::-1]
         self.components_ = np.flip(self.components_, axis=1)

         self.components_ = self.components_.T[: self.n_components_, :]

         self.explained_variance_ratio_ = self.explained_variance_ / np.sum(
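Note: the reordering above is needed because scipy.linalg.eigh returns
eigenvalues in ascending order, while PCA wants components sorted by
descending explained variance. A minimal standalone NumPy sketch of the
same reordering (illustrative only: the dense random data and variable
names are assumptions, not the scanpy implementation, which builds the
covariance from a sparse dask array via _cov_sparse_dask):

    import numpy as np
    import scipy.linalg

    rng = np.random.default_rng(0)
    x = rng.standard_normal((100, 5))
    covariance = np.cov(x, rowvar=False)

    # eigh returns eigenvalues in ascending order; the columns of `vecs`
    # are the matching eigenvectors.
    vals, vecs = scipy.linalg.eigh(covariance, lower=False)

    # Arrange eigenvalues and eigenvectors in descending order.
    explained_variance = vals[::-1]
    components = np.flip(vecs, axis=1)

    # Rows of `components` are now principal axes, strongest first.
    n_components = 2
    components = components.T[:n_components, :]

    # The ratio is taken over the total variance, before truncation.
    explained_variance_ratio = explained_variance / np.sum(explained_variance)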
29 changes: 23 additions & 6 deletions tests/test_pca.py
@@ -17,6 +17,7 @@
 from scipy.sparse import issparse

 import scanpy as sc
+from scanpy._compat import DaskArray
 from scanpy.preprocessing._pca._dask_sparse import _cov_sparse_dask
 from testing.scanpy import _helpers
 from testing.scanpy._helpers.data import pbmc3k_normalized
@@ -27,7 +28,6 @@
     from collections.abc import Callable
     from typing import Literal

-    from scanpy._compat import DaskArray

 A_list = np.array(
     [

@@ -197,6 +197,8 @@ def test_pca_transform(array_type):
         sc.pp.pca(adata, n_comps=4, zero_center=True, dtype="float64")
     assert len(record) == 0, record

+    if isinstance(adata.obsm["X_pca"], DaskArray):
+        adata.obsm["X_pca"] = adata.obsm["X_pca"].compute()
     assert np.linalg.norm(A_pca_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05

     with warnings.catch_warnings(record=True) as record:
@@ -215,16 +217,31 @@ def test_pca_transform(array_type):
             in str(r.message)
             for r in record
         )
+    elif isinstance(A, DaskArray) and issparse(A._meta):
+        assert any(
+            isinstance(r.message, UserWarning)
+            and str(r.message)
+            == "random_state is ignored when using a sparse dask array"
+            for r in record
+        )
+        assert any(
+            isinstance(r.message, UserWarning)
+            and str(r.message) == "svd_solver is ignored when using a sparse dask array"
+            for r in record
+        )
     else:
-        assert len(record) == 0
+        assert len(record) == 0, [r.message for r in record]

     assert np.linalg.norm(A_pca_abs - np.abs(adata.obsm["X_pca"])) < 2e-05

-    with warnings.catch_warnings(record=True) as record:
-        sc.pp.pca(adata, n_comps=4, zero_center=False, dtype="float64", random_state=14)
-        assert len(record) == 0
+    if not (isinstance(A, DaskArray) and issparse(A._meta)):
+        with warnings.catch_warnings(record=True) as record:
+            sc.pp.pca(
+                adata, n_comps=4, zero_center=False, dtype="float64", random_state=14
+            )
+            assert len(record) == 0, [r.message for r in record]

-    assert np.linalg.norm(A_svd_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05
+        assert np.linalg.norm(A_svd_abs[:, :4] - np.abs(adata.obsm["X_pca"])) < 2e-05


def test_pca_shapes():
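Note on the first change in this file: with a dask-backed AnnData,
adata.obsm["X_pca"] comes back as a lazy dask array, so the test now
materializes it with .compute() before the NumPy comparison. A minimal
sketch of that pattern outside scanpy (the array contents and chunking
are assumptions for illustration):

    import dask.array as da
    import numpy as np

    x = np.arange(20, dtype="float64").reshape(4, 5)
    x_lazy = da.from_array(x, chunks=(2, 5))  # lazy, chunked view of x

    # Turn the task graph into a concrete np.ndarray before NumPy math.
    if isinstance(x_lazy, da.Array):
        x_concrete = x_lazy.compute()

    assert np.linalg.norm(x - x_concrete) < 1e-12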

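The new elif branch expects scanpy to warn that random_state and
svd_solver are ignored for sparse dask input. The recording idiom it
relies on is the standard-library warnings.catch_warnings(record=True);
a self-contained sketch (the noisy() helper is a hypothetical stand-in,
but the warning text matches the one asserted in the test):

    import warnings

    def noisy() -> None:
        # Stand-in for sc.pp.pca on a sparse dask array.
        warnings.warn(
            "svd_solver is ignored when using a sparse dask array", UserWarning
        )

    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")  # ensure nothing is suppressed
        noisy()

    # Each record entry is a WarningMessage; .message is the instance.
    assert any(
        isinstance(r.message, UserWarning)
        and str(r.message) == "svd_solver is ignored when using a sparse dask array"
        for r in record
    )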