diff --git a/docs/api/tools.md b/docs/api/tools.md index 5165777b9f..1d51559e5a 100644 --- a/docs/api/tools.md +++ b/docs/api/tools.md @@ -17,7 +17,7 @@ Any transformation of the data matrix that is not *preprocessing*. In contrast t :nosignatures: :toctree: ../generated/ - tl.pca + pp.pca tl.tsne tl.umap tl.draw_graph diff --git a/docs/conf.py b/docs/conf.py index 41dd3f0aa1..2e164b9c4c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,11 +1,11 @@ -import os +from __future__ import annotations + import sys -from pathlib import Path from datetime import datetime -from typing import Any +from pathlib import Path +from typing import TYPE_CHECKING import matplotlib # noqa -from sphinx.application import Sphinx from packaging.version import parse as parse_version # Don’t use tkinter agg when importing scanpy → … → matplotlib @@ -15,6 +15,9 @@ sys.path[:0] = [str(HERE.parent), str(HERE / "extensions")] import scanpy # noqa +if TYPE_CHECKING: + from sphinx.application import Sphinx + # -- General configuration ------------------------------------------------ diff --git a/docs/extensions/cite.py b/docs/extensions/cite.py index 616cbf4e91..5db46edc47 100644 --- a/docs/extensions/cite.py +++ b/docs/extensions/cite.py @@ -8,11 +8,11 @@ from docutils import nodes, utils if TYPE_CHECKING: - from typing import Any from collections.abc import Mapping, Sequence + from typing import Any - from sphinx.application import Sphinx from docutils.parsers.rst.states import Inliner + from sphinx.application import Sphinx def cite_role( diff --git a/docs/extensions/debug_docstrings.py b/docs/extensions/debug_docstrings.py index fba4dbaaba..87bc210cef 100644 --- a/docs/extensions/debug_docstrings.py +++ b/docs/extensions/debug_docstrings.py @@ -1,10 +1,14 @@ # Just do the following to see the rst of a function: -# rm -f _build/doctrees/api/scanpy..doctree; DEBUG=1 make html +# rm ./_build/doctrees/api/generated/scanpy..doctree; DEBUG=1 make html +from __future__ import annotations + import os +from typing import TYPE_CHECKING -from sphinx.application import Sphinx import sphinx.ext.napoleon +if TYPE_CHECKING: + from sphinx.application import Sphinx _pd_orig = sphinx.ext.napoleon._process_docstring diff --git a/docs/extensions/function_images.py b/docs/extensions/function_images.py index 688f26831f..7042daf1ee 100644 --- a/docs/extensions/function_images.py +++ b/docs/extensions/function_images.py @@ -1,13 +1,16 @@ """Images for plot functions""" +from __future__ import annotations + from pathlib import Path -from typing import List, Any +from typing import TYPE_CHECKING, Any -from sphinx.application import Sphinx -from sphinx.ext.autodoc import Options +if TYPE_CHECKING: + from sphinx.application import Sphinx + from sphinx.ext.autodoc import Options def insert_function_images( - app: Sphinx, what: str, name: str, obj: Any, options: Options, lines: List[str] + app: Sphinx, what: str, name: str, obj: Any, options: Options, lines: list[str] ): path = app.config.api_dir / f"{name}.png" if what != "function" or not path.is_file(): diff --git a/docs/extensions/git_ref.py b/docs/extensions/git_ref.py index 40a3644a7e..1f57de7a25 100644 --- a/docs/extensions/git_ref.py +++ b/docs/extensions/git_ref.py @@ -2,12 +2,14 @@ from __future__ import annotations -from functools import lru_cache - import re import subprocess -from sphinx.application import Sphinx -from sphinx.config import Config +from functools import lru_cache +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sphinx.application import Sphinx + from sphinx.config import Config def git(*args: str) -> str: diff --git a/docs/extensions/has_attr_test.py b/docs/extensions/has_attr_test.py index f49b4bb887..d570498366 100644 --- a/docs/extensions/has_attr_test.py +++ b/docs/extensions/has_attr_test.py @@ -1,8 +1,13 @@ +from __future__ import annotations + from inspect import get_annotations +from typing import TYPE_CHECKING from jinja2.defaults import DEFAULT_NAMESPACE from jinja2.utils import import_string -from sphinx.application import Sphinx + +if TYPE_CHECKING: + from sphinx.application import Sphinx def has_member(obj_path: str, attr: str) -> bool: diff --git a/docs/extensions/param_police.py b/docs/extensions/param_police.py index 9b4b6b41f2..37942d3687 100644 --- a/docs/extensions/param_police.py +++ b/docs/extensions/param_police.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING -from sphinx.application import Sphinx from sphinx.ext.napoleon import NumpyDocstring +if TYPE_CHECKING: + from sphinx.application import Sphinx _format_docutils_params_orig = NumpyDocstring._format_docutils_params param_warnings = {} diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py index 7bcbe71b4c..ed93338a2d 100644 --- a/docs/extensions/typed_returns.py +++ b/docs/extensions/typed_returns.py @@ -1,8 +1,13 @@ +from __future__ import annotations + import re +from typing import TYPE_CHECKING -from sphinx.application import Sphinx from sphinx.ext.napoleon import NumpyDocstring +if TYPE_CHECKING: + from sphinx.application import Sphinx + def process_return(lines): for line in lines: diff --git a/docs/release-notes/1.4.5.md b/docs/release-notes/1.4.5.md index 31fc87f85d..a70ab739ee 100644 --- a/docs/release-notes/1.4.5.md +++ b/docs/release-notes/1.4.5.md @@ -20,7 +20,7 @@ Please install `scanpy==1.4.5.post3` instead of `scanpy==1.4.5`. - webpage overhaul, ecosystem page, release notes, tutorials overhaul {pr}`960` {pr}`966` {smaller}`A Wolf` ```{warning} -- changed default `solver` in {func}`~scanpy.tl.pca` from `auto` to `arpack` +- changed default `solver` in {func}`~scanpy.pp.pca` from `auto` to `arpack` - changed default `use_raw` in {func}`~scanpy.tl.score_genes` from `False` to `None` ``` diff --git a/pyproject.toml b/pyproject.toml index 02adcfb843..36a3fbbd60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,16 +106,16 @@ test-full = [ "scanpy[dask-ml]", ] doc = [ - "sphinx>=5", + "sphinx>=6", "sphinx-book-theme>=1.0.1", "scanpydoc>=0.9.5", - "sphinx-autodoc-typehints", - "myst-parser", - "myst-nb", + "sphinx-autodoc-typehints>=1.23.0", + "myst-parser>=2", + "myst-nb>=1", "sphinx-design", "sphinxext-opengraph", # for nice cards when sharing on social "sphinx-copybutton", - "nbsphinx", + "nbsphinx>=0.9", "ipython>=7.20", # for nbsphinx code highlighting "matplotlib!=3.6.1", # TODO: remove necessity for being able to import doc-linked classes @@ -177,14 +177,17 @@ source = [".", "**/site-packages"] [tool.ruff] select = [ - "F", # Pyflakes - "E", # Pycodestyle errors - "W", # Pycodestyle warnings + "E", # Error detected by Pycodestyle + "F", # Errors detected by Pyflakes + "W", # Warning detected by Pycodestyle + "UP", # pyupgrade + "I", # isort + "TCH", # manage type checking blocks "TID251", # Banned imports + "ICN", # Follow import conventions + "PTH", # Pathlib instead of os.path ] ignore = [ - # module imported but unused -> required for Scanpys API - "F401", # line too long -> we accept long comment lines; black gets rid of long code lines "E501", # module level import not at top of file -> required to circumvent circular imports for Scanpys API @@ -197,6 +200,9 @@ ignore = [ [tool.ruff.per-file-ignores] # Do not assign a lambda expression, use a def "scanpy/tools/_rank_genes_groups.py" = ["E731"] +[tool.ruff.isort] +known-first-party = ["scanpy"] +required-imports = ["from __future__ import annotations"] [tool.ruff.flake8-tidy-imports.banned-api] "pytest.importorskip".msg = "Use the “@needs” decorator/mark instead" "pandas.api.types.is_categorical_dtype".msg = "Use isinstance(s.dtype, CategoricalDtype) instead" diff --git a/scanpy/__init__.py b/scanpy/__init__.py index ac997b45b4..d6a2e46770 100644 --- a/scanpy/__init__.py +++ b/scanpy/__init__.py @@ -1,4 +1,5 @@ """Single-Cell Analysis in Python.""" +from __future__ import annotations try: # See https://github.com/maresb/hatch-vcs-footgun-example from setuptools_scm import get_version @@ -20,25 +21,26 @@ # the actual API # (start with settings as several tools are using it) -from ._settings import settings, Verbosity -from . import tools as tl -from . import preprocessing as pp -from . import plotting as pl -from . import datasets, logging, queries, external, get, metrics, experimental - -from anndata import AnnData, concat from anndata import ( - read_h5ad, + AnnData, + concat, read_csv, read_excel, + read_h5ad, read_hdf, read_loom, read_mtx, read_text, read_umi_tools, ) -from .readwrite import read, read_10x_h5, read_10x_mtx, write, read_visium + +from . import datasets, experimental, external, get, logging, metrics, queries +from . import plotting as pl +from . import preprocessing as pp +from . import tools as tl +from ._settings import Verbosity, settings from .neighbors import Neighbors +from .readwrite import read, read_10x_h5, read_10x_mtx, read_visium, write set_figure_params = settings.set_figure_params @@ -50,3 +52,36 @@ annotate_doc_types(sys.modules[__name__], "scanpy") del sys, annotate_doc_types + +__all__ = [ + "__version__", + "AnnData", + "concat", + "read_csv", + "read_excel", + "read_h5ad", + "read_hdf", + "read_loom", + "read_mtx", + "read_text", + "read_umi_tools", + "read", + "read_10x_h5", + "read_10x_mtx", + "read_visium", + "write", + "datasets", + "experimental", + "external", + "get", + "logging", + "metrics", + "queries", + "pl", + "pp", + "tl", + "Verbosity", + "settings", + "Neighbors", + "set_figure_params", +] diff --git a/scanpy/__main__.py b/scanpy/__main__.py index cdc3b29399..f0ec49f773 100644 --- a/scanpy/__main__.py +++ b/scanpy/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .cli import console_main console_main() diff --git a/scanpy/_compat.py b/scanpy/_compat.py index 2a5af79688..ab254e2b28 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from pathlib import Path + from packaging import version try: diff --git a/scanpy/_settings.py b/scanpy/_settings.py index 3e71e07668..9bcf30760a 100644 --- a/scanpy/_settings.py +++ b/scanpy/_settings.py @@ -4,14 +4,16 @@ import sys from contextlib import contextmanager from enum import IntEnum +from logging import getLevelName from pathlib import Path from time import time -from logging import getLevelName -from typing import Any, Union, Optional, Iterable, TextIO, Literal -from typing import Tuple, List, ContextManager +from typing import TYPE_CHECKING, Any, Literal, TextIO, Union from . import logging -from .logging import _set_log_level, _set_log_file, _RootLogger +from .logging import _RootLogger, _set_log_file, _set_log_level + +if TYPE_CHECKING: + from collections.abc import Generator, Iterable _VERBOSITY_TO_LOGLEVEL = { "error": "ERROR", @@ -55,7 +57,9 @@ def level(self) -> int: return getLevelName(_VERBOSITY_TO_LOGLEVEL[self.name]) @contextmanager - def override(self, verbosity: Verbosity | str | int) -> ContextManager[Verbosity]: + def override( + self, verbosity: Verbosity | str | int + ) -> Generator[Verbosity, None, None]: """\ Temporarily override verbosity """ @@ -68,7 +72,7 @@ def override(self, verbosity: Verbosity | str | int) -> ContextManager[Verbosity Verbosity.warn = Verbosity.warning -def _type_check(var: Any, varname: str, types: Union[type, Tuple[type, ...]]): +def _type_check(var: Any, varname: str, types: type | tuple[type, ...]): if isinstance(var, types): return if isinstance(types, type): @@ -98,14 +102,14 @@ def __init__( file_format_figs: str = "pdf", autosave: bool = False, autoshow: bool = True, - writedir: Union[str, Path] = "./write/", - cachedir: Union[str, Path] = "./cache/", - datasetdir: Union[str, Path] = "./data/", - figdir: Union[str, Path] = "./figures/", - cache_compression: Union[str, None] = "lzf", + writedir: str | Path = "./write/", + cachedir: str | Path = "./cache/", + datasetdir: str | Path = "./data/", + figdir: str | Path = "./figures/", + cache_compression: str | None = "lzf", max_memory=15, n_jobs=1, - logfile: Union[str, Path, None] = None, + logfile: str | Path | None = None, categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"), _frameon: bool = True, _vector_friendly: bool = False, @@ -164,7 +168,7 @@ def verbosity(self) -> Verbosity: return self._verbosity @verbosity.setter - def verbosity(self, verbosity: Union[Verbosity, int, str]): + def verbosity(self, verbosity: Verbosity | int | str): verbosity_str_options = [ v for v in _VERBOSITY_TO_LOGLEVEL if isinstance(v, str) ] @@ -265,7 +269,7 @@ def writedir(self) -> Path: return self._writedir @writedir.setter - def writedir(self, writedir: Union[str, Path]): + def writedir(self, writedir: str | Path): _type_check(writedir, "writedir", (str, Path)) self._writedir = Path(writedir) @@ -277,7 +281,7 @@ def cachedir(self) -> Path: return self._cachedir @cachedir.setter - def cachedir(self, cachedir: Union[str, Path]): + def cachedir(self, cachedir: str | Path): _type_check(cachedir, "cachedir", (str, Path)) self._cachedir = Path(cachedir) @@ -289,7 +293,7 @@ def datasetdir(self) -> Path: return self._datasetdir @datasetdir.setter - def datasetdir(self, datasetdir: Union[str, Path]): + def datasetdir(self, datasetdir: str | Path): _type_check(datasetdir, "datasetdir", (str, Path)) self._datasetdir = Path(datasetdir).resolve() @@ -301,12 +305,12 @@ def figdir(self) -> Path: return self._figdir @figdir.setter - def figdir(self, figdir: Union[str, Path]): + def figdir(self, figdir: str | Path): _type_check(figdir, "figdir", (str, Path)) self._figdir = Path(figdir) @property - def cache_compression(self) -> Optional[str]: + def cache_compression(self) -> str | None: """\ Compression for `sc.read(..., cache=True)` (default `'lzf'`). @@ -315,7 +319,7 @@ def cache_compression(self) -> Optional[str]: return self._cache_compression @cache_compression.setter - def cache_compression(self, cache_compression: Optional[str]): + def cache_compression(self, cache_compression: str | None): if cache_compression not in {"lzf", "gzip", None}: raise ValueError( f"`cache_compression` ({cache_compression}) " @@ -324,7 +328,7 @@ def cache_compression(self, cache_compression: Optional[str]): self._cache_compression = cache_compression @property - def max_memory(self) -> Union[int, float]: + def max_memory(self) -> int | float: """\ Maximum memory usage in Gigabyte. @@ -333,7 +337,7 @@ def max_memory(self) -> Union[int, float]: return self._max_memory @max_memory.setter - def max_memory(self, max_memory: Union[int, float]): + def max_memory(self, max_memory: int | float): _type_check(max_memory, "max_memory", (int, float)) self._max_memory = max_memory @@ -354,14 +358,14 @@ def n_jobs(self, n_jobs: int): self._n_jobs = n_jobs @property - def logpath(self) -> Optional[Path]: + def logpath(self) -> Path | None: """\ The file path `logfile` was set to. """ return self._logpath @logpath.setter - def logpath(self, logpath: Union[str, Path, None]): + def logpath(self, logpath: str | Path | None): _type_check(logpath, "logfile", (str, Path)) # set via “file object” branch of logfile.setter self.logfile = Path(logpath).open("a") @@ -381,7 +385,7 @@ def logfile(self) -> TextIO: return self._logfile @logfile.setter - def logfile(self, logfile: Union[str, Path, TextIO, None]): + def logfile(self, logfile: str | Path | TextIO | None): if not hasattr(logfile, "write") and logfile: self.logpath = logfile else: # file object @@ -392,7 +396,7 @@ def logfile(self, logfile: Union[str, Path, TextIO, None]): _set_log_file(self) @property - def categories_to_ignore(self) -> List[str]: + def categories_to_ignore(self) -> list[str]: """\ Categories that are omitted in plotting etc. """ @@ -417,10 +421,10 @@ def set_figure_params( frameon: bool = True, vector_friendly: bool = True, fontsize: int = 14, - figsize: Optional[int] = None, - color_map: Optional[str] = None, + figsize: int | None = None, + color_map: str | None = None, format: _Format = "pdf", - facecolor: Optional[str] = None, + facecolor: str | None = None, transparent: bool = False, ipython_format: str = "png2x", ): diff --git a/scanpy/_utils/__init__.py b/scanpy/_utils/__init__.py index a3f9e771db..20fac13363 100644 --- a/scanpy/_utils/__init__.py +++ b/scanpy/_utils/__init__.py @@ -5,31 +5,35 @@ """ from __future__ import annotations -import sys +import importlib.util import inspect +import sys import warnings -import importlib.util -from enum import Enum -from pathlib import Path -from weakref import WeakSet from collections import namedtuple +from enum import Enum from functools import partial, singledispatch, wraps -from types import ModuleType, MethodType -from typing import Union, Callable, Optional, Mapping, Any, Dict, Tuple, Literal +from textwrap import dedent +from types import MethodType, ModuleType +from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from weakref import WeakSet import numpy as np +from anndata import AnnData +from anndata import __version__ as anndata_version from numpy import random from numpy.typing import NDArray -from scipy import sparse -from anndata import AnnData, __version__ as anndata_version -from textwrap import dedent from packaging import version +from scipy import sparse -from .._settings import settings from .. import logging as logg from .._compat import DaskArray +from .._settings import settings from .compute.is_constant import is_constant # noqa: F401 +if TYPE_CHECKING: + from collections.abc import Mapping + from pathlib import Path + class Empty(Enum): token = 0 @@ -69,7 +73,7 @@ def check_versions(): ) -def getdoc(c_or_f: Union[Callable, type]) -> Optional[str]: +def getdoc(c_or_f: Callable | type) -> str | None: if getattr(c_or_f, "__doc__", None) is None: return None doc = inspect.getdoc(c_or_f) @@ -165,7 +169,7 @@ def annotate_doc_types(mod: ModuleType, root: str): def _doc_params(**kwds): """\ - Docstrings should start with "\" in the first line for proper formatting. + Docstrings should start with ``\\`` in the first line for proper formatting. """ def dec(obj): @@ -189,7 +193,7 @@ def _check_array_function_arguments(**kwargs): ) -def _check_use_raw(adata: AnnData, use_raw: Union[None, bool]) -> bool: +def _check_use_raw(adata: AnnData, use_raw: None | bool) -> bool: """ Normalize checking `use_raw`. @@ -262,7 +266,7 @@ def compute_association_matrix_of_groups( reference: str, normalization: Literal["prediction", "reference"] = "prediction", threshold: float = 0.01, - max_n_names: Optional[int] = 2, + max_n_names: int | None = 2, ): """Compute overlaps between groups. @@ -449,7 +453,7 @@ def update_params( old_params: Mapping[str, Any], new_params: Mapping[str, Any], check=False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """\ Update old_params with new_params. @@ -613,7 +617,7 @@ def subsample( X: np.ndarray, subsample: int = 1, seed: int = 0, -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """\ Subsample a fraction of 1/subsample samples from the rows of X. @@ -653,7 +657,7 @@ def subsample( def subsample_n( X: np.ndarray, n: int = 0, seed: int = 0 -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """Subsample n samples from rows of array. Parameters diff --git a/scanpy/_utils/compute/is_constant.py b/scanpy/_utils/compute/is_constant.py index 6785f16d41..0488fa8428 100644 --- a/scanpy/_utils/compute/is_constant.py +++ b/scanpy/_utils/compute/is_constant.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import Literal, TypeVar, overload -from functools import partial, wraps, singledispatch -from numbers import Integral from collections.abc import Callable +from functools import partial, singledispatch, wraps +from numbers import Integral +from typing import TYPE_CHECKING, Literal, TypeVar, overload import numpy as np -from numpy.typing import NDArray from numba import njit from scipy import sparse from ..._compat import DaskArray +if TYPE_CHECKING: + from numpy.typing import NDArray C = TypeVar("C", bound=Callable) diff --git a/scanpy/cli.py b/scanpy/cli.py index 08ce2e2199..37174ffa2a 100644 --- a/scanpy/cli.py +++ b/scanpy/cli.py @@ -1,28 +1,23 @@ +from __future__ import annotations + +import collections.abc as cabc import os import sys -import collections.abc as cabc -from argparse import ArgumentParser, Namespace, _SubParsersAction, ArgumentError +from argparse import ArgumentParser, Namespace, _SubParsersAction from functools import lru_cache, partial from pathlib import Path from shutil import which -from subprocess import run, CompletedProcess -from typing import ( - Optional, - Generator, - FrozenSet, - Sequence, - List, - Tuple, - Dict, - Any, - Mapping, -) +from subprocess import CompletedProcess, run +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Generator, Mapping, Sequence class _DelegatingSubparsersAction(_SubParsersAction): """Like a normal subcommand action, but uses a delegator for more choices""" - def __init__(self, *args, _command: str, _runargs: Dict[str, Any], **kwargs): + def __init__(self, *args, _command: str, _runargs: dict[str, Any], **kwargs): super().__init__(*args, **kwargs) self.command = _command self._name_parser_map = self.choices = _CommandDelegator( @@ -86,8 +81,8 @@ def __eq__(self, other: Mapping[str, ArgumentParser]): return self.parser_map == other @property - @lru_cache() - def commands(self) -> FrozenSet[str]: + @lru_cache + def commands(self) -> frozenset[str]: return frozenset( binary.name[len(self.command) + 1 :] for bin_dir in os.environ["PATH"].split(os.pathsep) @@ -106,9 +101,9 @@ def __init__(self, cd: _CommandDelegator, subcmd: str): def parse_known_args( self, - args: Optional[Sequence[str]] = None, - namespace: Optional[Namespace] = None, - ) -> Tuple[Namespace, List[str]]: + args: Sequence[str] | None = None, + namespace: Namespace | None = None, + ) -> tuple[Namespace, list[str]]: assert ( args is not None and namespace is None ), "Only use DelegatingParser as subparser" @@ -122,8 +117,8 @@ def _cmd_settings() -> None: def main( - argv: Optional[Sequence[str]] = None, *, check: bool = True, **runargs -) -> Optional[CompletedProcess]: + argv: Sequence[str] | None = None, *, check: bool = True, **runargs +) -> CompletedProcess | None: """\ Run a builtin scanpy command or a scanpy-* subcommand. diff --git a/scanpy/datasets/__init__.py b/scanpy/datasets/__init__.py index a38288ed2a..cff764effe 100644 --- a/scanpy/datasets/__init__.py +++ b/scanpy/datasets/__init__.py @@ -1,15 +1,31 @@ """Builtin Datasets. """ +from __future__ import annotations + from ._datasets import ( blobs, burczynski06, krumsiek11, moignard15, paul15, - toggleswitch, - pbmc68k_reduced, pbmc3k, pbmc3k_processed, + pbmc68k_reduced, + toggleswitch, visium_sge, ) from ._ebi_expression_atlas import ebi_expression_atlas + +__all__ = [ + "blobs", + "burczynski06", + "krumsiek11", + "moignard15", + "paul15", + "pbmc3k", + "pbmc3k_processed", + "pbmc68k_reduced", + "toggleswitch", + "visium_sge", + "ebi_expression_atlas", +] diff --git a/scanpy/datasets/_datasets.py b/scanpy/datasets/_datasets.py index eefe9f1de4..f190a0eb7f 100644 --- a/scanpy/datasets/_datasets.py +++ b/scanpy/datasets/_datasets.py @@ -1,16 +1,21 @@ -from pathlib import Path -from typing import Optional, Literal +from __future__ import annotations + import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Literal +import anndata as ad import numpy as np import pandas as pd -import anndata as ad -from .. import logging as logg, _utils +from .. import _utils +from .. import logging as logg from .._settings import settings from ..readwrite import read, read_visium from ._utils import check_datasetdir_exists, filter_oldformatwarning -from .._utils import AnyRandom + +if TYPE_CHECKING: + from .._utils import AnyRandom HERE = Path(__file__).parent @@ -314,7 +319,7 @@ def pbmc3k_processed() -> ad.AnnData: def _download_visium_dataset( sample_id: str, spaceranger_version: str, - base_dir: Optional[Path] = None, + base_dir: Path | None = None, download_image: bool = False, ): """ diff --git a/scanpy/datasets/_ebi_expression_atlas.py b/scanpy/datasets/_ebi_expression_atlas.py index f1c7e20eff..7482835dcc 100644 --- a/scanpy/datasets/_ebi_expression_atlas.py +++ b/scanpy/datasets/_ebi_expression_atlas.py @@ -1,16 +1,18 @@ -from urllib.request import urlopen +from __future__ import annotations + +from typing import BinaryIO from urllib.error import HTTPError +from urllib.request import urlopen from zipfile import ZipFile -from typing import BinaryIO import anndata -import pandas as pd import numpy as np +import pandas as pd from scipy import sparse -from ..readwrite import _download -from .._settings import settings from .. import logging as logg +from .._settings import settings +from ..readwrite import _download from ._utils import check_datasetdir_exists diff --git a/scanpy/datasets/_utils.py b/scanpy/datasets/_utils.py index 302e23fcc4..fbd37de98c 100644 --- a/scanpy/datasets/_utils.py +++ b/scanpy/datasets/_utils.py @@ -1,9 +1,10 @@ -from functools import wraps -import warnings +from __future__ import annotations -from packaging import version +import warnings +from functools import wraps import anndata as ad +from packaging import version from .._settings import settings diff --git a/scanpy/experimental/__init__.py b/scanpy/experimental/__init__.py index 8a00c90df0..1ad2751169 100644 --- a/scanpy/experimental/__init__.py +++ b/scanpy/experimental/__init__.py @@ -1 +1,5 @@ +from __future__ import annotations + from . import pp + +__all__ = ["pp"] diff --git a/scanpy/experimental/_docs.py b/scanpy/experimental/_docs.py index a1408adf01..62610d6320 100644 --- a/scanpy/experimental/_docs.py +++ b/scanpy/experimental/_docs.py @@ -1,5 +1,6 @@ """Shared docstrings for experimental function parameters. """ +from __future__ import annotations doc_adata = """\ adata diff --git a/scanpy/experimental/pp/__init__.py b/scanpy/experimental/pp/__init__.py index a5eaf9d9c2..135840e2f2 100644 --- a/scanpy/experimental/pp/__init__.py +++ b/scanpy/experimental/pp/__init__.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from scanpy.experimental.pp._highly_variable_genes import highly_variable_genes from scanpy.experimental.pp._normalization import ( normalize_pearson_residuals, normalize_pearson_residuals_pca, ) - -from scanpy.experimental.pp._highly_variable_genes import highly_variable_genes - from scanpy.experimental.pp._recipes import recipe_pearson_residuals + +__all__ = [ + "highly_variable_genes", + "normalize_pearson_residuals", + "normalize_pearson_residuals_pca", + "recipe_pearson_residuals", +] diff --git a/scanpy/experimental/pp/_highly_variable_genes.py b/scanpy/experimental/pp/_highly_variable_genes.py index 723ea2222a..69854c965d 100644 --- a/scanpy/experimental/pp/_highly_variable_genes.py +++ b/scanpy/experimental/pp/_highly_variable_genes.py @@ -1,30 +1,33 @@ -from functools import partial +from __future__ import annotations + import warnings -from typing import Optional, Literal +from functools import partial +from math import sqrt +from typing import TYPE_CHECKING, Literal -import numpy as np import numba as nb +import numpy as np import pandas as pd import scipy.sparse as sp_sparse from anndata import AnnData -from math import sqrt -from numpy.typing import NDArray from scanpy import logging as logg -from scanpy._settings import settings, Verbosity -from scanpy._utils import check_nonnegative_integers, view_to_actual -from scanpy.get import _get_obs_rep -from scanpy._utils import _doc_params -from scanpy.preprocessing._utils import _get_mean_var -from scanpy.preprocessing._distributed import materialize_as_ndarray +from scanpy._settings import Verbosity, settings +from scanpy._utils import _doc_params, check_nonnegative_integers, view_to_actual from scanpy.experimental._docs import ( doc_adata, + doc_check_values, doc_dist_params, doc_genes_batch_chunk, - doc_check_values, - doc_layer, doc_inplace, + doc_layer, ) +from scanpy.get import _get_obs_rep +from scanpy.preprocessing._distributed import materialize_as_ndarray +from scanpy.preprocessing._utils import _get_mean_var + +if TYPE_CHECKING: + from numpy.typing import NDArray @nb.njit(parallel=True) @@ -129,15 +132,15 @@ def clac_clipped_res_dense(gene: int, cell: int) -> np.float64: def _highly_variable_pearson_residuals( adata: AnnData, theta: float = 100, - clip: Optional[float] = None, + clip: float | None = None, n_top_genes: int = 1000, - batch_key: Optional[str] = None, + batch_key: str | None = None, chunksize: int = 1000, check_values: bool = True, - layer: Optional[str] = None, + layer: str | None = None, subset: bool = False, inplace: bool = True, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: view_to_actual(adata) X = _get_obs_rep(adata, layer=layer) computed_on = layer if layer else "adata.X" @@ -306,16 +309,16 @@ def highly_variable_genes( adata: AnnData, *, theta: float = 100, - clip: Optional[float] = None, - n_top_genes: Optional[int] = None, - batch_key: Optional[str] = None, + clip: float | None = None, + n_top_genes: int | None = None, + batch_key: str | None = None, chunksize: int = 1000, flavor: Literal["pearson_residuals"] = "pearson_residuals", check_values: bool = True, - layer: Optional[str] = None, + layer: str | None = None, subset: bool = False, inplace: bool = True, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """\ Select highly variable genes using analytic Pearson residuals [Lause21]_. diff --git a/scanpy/experimental/pp/_normalization.py b/scanpy/experimental/pp/_normalization.py index 1d67d71986..0287c3d8f6 100644 --- a/scanpy/experimental/pp/_normalization.py +++ b/scanpy/experimental/pp/_normalization.py @@ -1,8 +1,7 @@ from __future__ import annotations from types import MappingProxyType -from typing import Any -from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from warnings import warn import numpy as np @@ -10,25 +9,28 @@ from scipy.sparse import issparse from ... import logging as logg -from ...get import _get_obs_rep, _set_obs_rep from ..._utils import ( - view_to_actual, - check_nonnegative_integers, Empty, - _empty, _doc_params, + _empty, + check_nonnegative_integers, + view_to_actual, ) -from ...preprocessing._pca import pca, _handle_mask_param -from ...preprocessing._docs import doc_mask_hvg from ...experimental._docs import ( doc_adata, - doc_dist_params, - doc_layer, doc_check_values, doc_copy, + doc_dist_params, doc_inplace, + doc_layer, doc_pca_chunk, ) +from ...get import _get_obs_rep, _set_obs_rep +from ...preprocessing._docs import doc_mask_hvg +from ...preprocessing._pca import _handle_mask_param, pca + +if TYPE_CHECKING: + from collections.abc import Mapping def _pearson_residuals(X, theta, clip, check_values, copy=False): diff --git a/scanpy/experimental/pp/_recipes.py b/scanpy/experimental/pp/_recipes.py index 7a99e5b3d9..995e839188 100644 --- a/scanpy/experimental/pp/_recipes.py +++ b/scanpy/experimental/pp/_recipes.py @@ -1,19 +1,24 @@ -from typing import Optional, Tuple -from anndata import AnnData -import pandas as pd +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np + from scanpy import experimental -from scanpy.preprocessing import pca +from scanpy._utils import _doc_params from scanpy.experimental._docs import ( doc_adata, + doc_check_values, doc_dist_params, doc_genes_batch_chunk, - doc_pca_chunk, - doc_layer, - doc_check_values, doc_inplace, + doc_pca_chunk, ) -from scanpy._utils import _doc_params +from scanpy.preprocessing import pca + +if TYPE_CHECKING: + import pandas as pd + from anndata import AnnData @_doc_params( @@ -28,16 +33,16 @@ def recipe_pearson_residuals( adata: AnnData, *, theta: float = 100, - clip: Optional[float] = None, + clip: float | None = None, n_top_genes: int = 1000, - batch_key: Optional[str] = None, + batch_key: str | None = None, chunksize: int = 1000, - n_comps: Optional[int] = 50, - random_state: Optional[float] = 0, + n_comps: int | None = 50, + random_state: float | None = 0, kwargs_pca: dict = {}, check_values: bool = True, inplace: bool = True, -) -> Optional[Tuple[AnnData, pd.DataFrame]]: +) -> tuple[AnnData, pd.DataFrame] | None: """\ Full pipeline for HVG selection and normalization by analytic Pearson residuals ([Lause21]_). diff --git a/scanpy/external/__init__.py b/scanpy/external/__init__.py index 9afc45dadf..6adb8cb9fc 100644 --- a/scanpy/external/__init__.py +++ b/scanpy/external/__init__.py @@ -1,10 +1,11 @@ -from . import tl -from . import pl -from . import pp -from . import exporting +from __future__ import annotations import sys + from .. import _utils +from . import exporting, pl, pp, tl _utils.annotate_doc_types(sys.modules[__name__], "scanpy") del sys, _utils + +__all__ = ["exporting", "pl", "pp", "tl"] diff --git a/scanpy/external/exporting.py b/scanpy/external/exporting.py index f32d051f66..7f7158bd4f 100644 --- a/scanpy/external/exporting.py +++ b/scanpy/external/exporting.py @@ -1,31 +1,37 @@ """\ Exporting to formats for other software. """ +from __future__ import annotations + import json import logging as logg from pathlib import Path -from typing import Union, Optional, Iterable, Mapping +from typing import TYPE_CHECKING -import numpy as np -import scipy.sparse import h5py import matplotlib.pyplot as plt -from anndata import AnnData +import numpy as np +import scipy.sparse from pandas.api.types import CategoricalDtype -from ..preprocessing._utils import _get_mean_var from .._utils import NeighborsView +from ..preprocessing._utils import _get_mean_var + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + from anndata import AnnData def spring_project( adata: AnnData, - project_dir: Union[Path, str], + project_dir: Path | str, embedding_method: str, - subplot_name: Optional[str] = None, - cell_groupings: Union[str, Iterable[str], None] = None, - custom_color_tracks: Union[str, Iterable[str], None] = None, + subplot_name: str | None = None, + cell_groupings: str | Iterable[str] | None = None, + custom_color_tracks: str | Iterable[str] | None = None, total_counts_key: str = "n_counts", - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, overwrite: bool = False, ): """\ @@ -307,11 +313,11 @@ def _write_graph(filename, n_nodes, edges): nodes = [{"name": int(i), "number": int(i)} for i in range(n_nodes)] edges = [{"source": int(i), "target": int(j), "distance": 0} for i, j in edges] out = {"nodes": nodes, "links": edges} - open(filename, "w").write(json.dumps(out, indent=4, separators=(",", ": "))) + Path(filename).write_text(json.dumps(out, indent=4, separators=(",", ": "))) def _write_edges(filename, edges): - with open(filename, "w") as f: + with Path(filename).open("w") as f: for e in edges: f.write("%i;%i\n" % (e[0], e[1])) @@ -322,12 +328,12 @@ def _write_color_tracks(ctracks, fname): line = name + "," + ",".join(["%.3f" % x for x in score]) out += [line] out = sorted(out, key=lambda x: x.split(",")[0]) - open(fname, "w").write("\n".join(out)) + Path(fname).write_text("\n".join(out)) def _frac_to_hex(frac): rgb = tuple(np.array(np.array(plt.cm.jet(frac)[:3]) * 255, dtype=int)) - return "#%02x%02x%02x" % rgb + return "#{:02x}{:02x}{:02x}".format(*rgb) def _get_color_stats_genes(color_stats, E, gene_list): @@ -366,8 +372,7 @@ def _get_color_stats_custom(color_stats, custom_colors): def _write_color_stats(filename, color_stats): - with open(filename, "w") as f: - f.write(json.dumps(color_stats, indent=4, sort_keys=True)) # .decode('utf-8')) + Path(filename).write_text(json.dumps(color_stats, indent=4, sort_keys=True)) def _build_categ_colors(categorical_coloring_data, cell_groupings): @@ -384,10 +389,9 @@ def _build_categ_colors(categorical_coloring_data, cell_groupings): def _write_cell_groupings(filename, categorical_coloring_data): - with open(filename, "w") as f: - f.write( - json.dumps(categorical_coloring_data, indent=4, sort_keys=True) - ) # .decode('utf-8')) + Path(filename).write_text( + json.dumps(categorical_coloring_data, indent=4, sort_keys=True) + ) def _export_PAGA_to_SPRING(adata, paga_coords, outpath): @@ -463,17 +467,17 @@ def _export_PAGA_to_SPRING(adata, paga_coords, outpath): import json - json.dump(PAGA_data, open(outpath, "w"), indent=4) + Path(outpath).write_text(json.dumps(PAGA_data, indent=4)) return None def cellbrowser( adata: AnnData, - data_dir: Union[Path, str], + data_dir: Path | str, data_name: str, - embedding_keys: Union[Iterable[str], Mapping[str, str], str, None] = None, - annot_keys: Union[Iterable[str], Mapping[str, str], None] = ( + embedding_keys: Iterable[str] | Mapping[str, str] | str | None = None, + annot_keys: Iterable[str] | Mapping[str, str] | None = ( "louvain", "percent_mito", "n_genes", @@ -482,8 +486,8 @@ def cellbrowser( cluster_field: str = "louvain", nb_marker: int = 50, skip_matrix: bool = False, - html_dir: Union[Path, str, None] = None, - port: Optional[int] = None, + html_dir: Path | str | None = None, + port: int | None = None, do_debug: bool = False, ): """\ diff --git a/scanpy/external/pl.py b/scanpy/external/pl.py index 78a3ee8788..021d9d96c9 100644 --- a/scanpy/external/pl.py +++ b/scanpy/external/pl.py @@ -1,14 +1,14 @@ -from typing import Union, List, Optional, Any, Tuple, Collection +from __future__ import annotations +from typing import TYPE_CHECKING, Any + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt -from anndata import AnnData -from matplotlib.axes import Axes +from matplotlib.axes import Axes # noqa: TCH002 -from ..testing._doctests import doctest_needs from .._utils import _doc_params -from ..plotting import embedding +from ..plotting import _utils, embedding from ..plotting._docs import ( doc_adata_color_etc, doc_edges_arrows, @@ -16,9 +16,14 @@ doc_show_save_ax, ) from ..plotting._tools.scatterplots import _wraps_plot_scatter -from ..plotting import _utils +from ..testing._doctests import doctest_needs from .tl._wishbone import _anndata_to_wishbone +if TYPE_CHECKING: + from collections.abc import Collection + + from anndata import AnnData + @doctest_needs("phate") @_wraps_plot_scatter @@ -28,7 +33,7 @@ scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def phate(adata, **kwargs) -> Union[List[Axes], None]: +def phate(adata, **kwargs) -> list[Axes] | None: """\ Scatter plot in PHATE basis. @@ -78,7 +83,7 @@ def phate(adata, **kwargs) -> Union[List[Axes], None]: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def trimap(adata, **kwargs) -> Union[Axes, List[Axes], None]: +def trimap(adata, **kwargs) -> Axes | list[Axes] | None: """\ Scatter plot in TriMap basis. @@ -105,7 +110,7 @@ def trimap(adata, **kwargs) -> Union[Axes, List[Axes], None]: ) def harmony_timeseries( adata, *, show: bool = True, return_fig: bool = False, **kwargs -) -> Union[Axes, List[Axes], None]: +) -> Axes | list[Axes] | None: """\ Scatter plot in Harmony force-directed layout basis. @@ -146,12 +151,12 @@ def harmony_timeseries( def sam( adata: AnnData, - projection: Union[str, np.ndarray] = "X_umap", - c: Optional[Union[str, np.ndarray]] = None, + projection: str | np.ndarray = "X_umap", + c: str | np.ndarray | None = None, cmap: str = "Spectral_r", linewidth: float = 0.0, edgecolor: str = "k", - axes: Optional[Axes] = None, + axes: Axes | None = None, colorbar: bool = True, s: float = 10.0, **kwargs: Any, @@ -249,11 +254,11 @@ def wishbone_marker_trajectory( smoothing_factor: int = 1, min_delta: float = 0.1, show_variance: bool = False, - figsize: Optional[Tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, return_fig: bool = False, show: bool = True, - save: Optional[Union[str, bool]] = None, - ax: Optional[Axes] = None, + save: str | bool | None = None, + ax: Axes | None = None, ): """\ Plot marker trends along trajectory, and return trajectory branches for further @@ -332,10 +337,10 @@ def scrublet_score_distribution( adata, scale_hist_obs: str = "log", scale_hist_sim: str = "linear", - figsize: Optional[Tuple[float, float]] = (8, 3), + figsize: tuple[float, float] | None = (8, 3), return_fig: bool = False, show: bool = True, - save: Optional[Union[str, bool]] = None, + save: str | bool | None = None, ): """\ Plot histogram of doublet scores for observed transcriptomes and simulated doublets. diff --git a/scanpy/external/pp/__init__.py b/scanpy/external/pp/__init__.py index 50b07c002f..8a144c8595 100644 --- a/scanpy/external/pp/__init__.py +++ b/scanpy/external/pp/__init__.py @@ -1,8 +1,22 @@ -from ._mnn_correct import mnn_correct +from __future__ import annotations + from ._bbknn import bbknn from ._dca import dca from ._harmony_integrate import harmony_integrate +from ._hashsolo import hashsolo from ._magic import magic +from ._mnn_correct import mnn_correct from ._scanorama_integrate import scanorama_integrate -from ._hashsolo import hashsolo from ._scrublet import scrublet, scrublet_simulate_doublets + +__all__ = [ + "bbknn", + "dca", + "harmony_integrate", + "hashsolo", + "magic", + "mnn_correct", + "scanorama_integrate", + "scrublet", + "scrublet_simulate_doublets", +] diff --git a/scanpy/external/pp/_bbknn.py b/scanpy/external/pp/_bbknn.py index 55e3c3a34a..5d35a5ccbe 100644 --- a/scanpy/external/pp/_bbknn.py +++ b/scanpy/external/pp/_bbknn.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Callable - -from anndata import AnnData +from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: + from anndata import AnnData from sklearn.metrics import DistanceMetric from ...testing._doctests import doctest_needs @@ -17,12 +16,12 @@ def bbknn( use_rep: str = "X_pca", approx: bool = True, use_annoy: bool = True, - metric: Union[str, Callable, DistanceMetric] = "euclidean", + metric: str | Callable | DistanceMetric = "euclidean", copy: bool = False, *, neighbors_within_batch: int = 3, n_pcs: int = 50, - trim: Optional[int] = None, + trim: int | None = None, annoy_n_trees: int = 10, pynndescent_n_neighbors: int = 30, pynndescent_random_state: int = 0, diff --git a/scanpy/external/pp/_dca.py b/scanpy/external/pp/_dca.py index c4cd783944..30e2735a4a 100644 --- a/scanpy/external/pp/_dca.py +++ b/scanpy/external/pp/_dca.py @@ -1,10 +1,14 @@ +from __future__ import annotations + from types import MappingProxyType -from typing import Optional, Sequence, Union, Mapping, Any, Literal +from typing import TYPE_CHECKING, Any, Literal -from anndata import AnnData +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence -from ..._utils import AnyRandom + from anndata import AnnData + from ..._utils import AnyRandom _AEType = Literal["zinb-conddisp", "zinb", "nb-conddisp", "nb"] @@ -18,7 +22,7 @@ def dca( log1p: bool = True, # network args hidden_size: Sequence[int] = (64, 32, 64), - hidden_dropout: Union[float, Sequence[float]] = 0.0, + hidden_dropout: float | Sequence[float] = 0.0, batchnorm: bool = True, activation: str = "relu", init: str = "glorot_uniform", @@ -30,14 +34,14 @@ def dca( batch_size: int = 32, optimizer: str = "RMSprop", random_state: AnyRandom = 0, - threads: Optional[int] = None, - learning_rate: Optional[float] = None, + threads: int | None = None, + learning_rate: float | None = None, verbose: bool = False, training_kwds: Mapping[str, Any] = MappingProxyType({}), return_model: bool = False, return_info: bool = False, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Deep count autoencoder [Eraslan18]_. diff --git a/scanpy/external/pp/_harmony_integrate.py b/scanpy/external/pp/_harmony_integrate.py index 74e1c9081d..63847f9ee1 100644 --- a/scanpy/external/pp/_harmony_integrate.py +++ b/scanpy/external/pp/_harmony_integrate.py @@ -1,12 +1,17 @@ """ Use harmony to integrate cells from different experiments. """ +from __future__ import annotations + +from typing import TYPE_CHECKING -from anndata import AnnData import numpy as np from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + @doctest_needs("harmonypy") def harmony_integrate( @@ -37,7 +42,7 @@ def harmony_integrate( basis The name of the field in ``adata.obsm`` where the PCA table is stored. Defaults to ``'X_pca'``, which is the default for - ``sc.tl.pca()``. + ``sc.pp.pca()``. adjusted_basis The name of the field in ``adata.obsm`` where the adjusted PCA table will be stored after running this function. Defaults to @@ -60,7 +65,7 @@ def harmony_integrate( >>> import scanpy.external as sce >>> adata = sc.datasets.pbmc3k() >>> sc.pp.recipe_zheng17(adata) - >>> sc.tl.pca(adata) + >>> sc.pp.pca(adata) We now arbitrarily assign a batch metadata variable to each cell for the sake of example, but during real usage there would already diff --git a/scanpy/external/pp/_hashsolo.py b/scanpy/external/pp/_hashsolo.py index 801ff22d56..85ad9602df 100644 --- a/scanpy/external/pp/_hashsolo.py +++ b/scanpy/external/pp/_hashsolo.py @@ -14,10 +14,11 @@ barcodes should come from noise distributions. We test each of these hypotheses in a bayesian fashion, and select the most probable hypothesis. """ +from __future__ import annotations from itertools import product +from typing import TYPE_CHECKING -import anndata import numpy as np import pandas as pd from scipy.stats import norm @@ -25,6 +26,9 @@ from ..._utils import check_nonnegative_integers from ...testing._doctests import doctest_skip +if TYPE_CHECKING: + import anndata + def _calculate_log_likelihoods(data, number_of_noise_barcodes): """Calculate log likelihoods for each hypothesis, negative, singlet, doublet diff --git a/scanpy/external/pp/_magic.py b/scanpy/external/pp/_magic.py index 5f84f06906..793181d0eb 100644 --- a/scanpy/external/pp/_magic.py +++ b/scanpy/external/pp/_magic.py @@ -1,16 +1,22 @@ """\ Denoise high-dimensional data using MAGIC """ -from typing import Union, Sequence, Optional, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal -from anndata import AnnData from packaging import version from ... import logging as logg from ..._settings import settings -from ..._utils import AnyRandom from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + + from ..._utils import AnyRandom MIN_VERSION = "2.0" @@ -18,21 +24,21 @@ @doctest_needs("magic") def magic( adata: AnnData, - name_list: Union[Literal["all_genes", "pca_only"], Sequence[str], None] = None, + name_list: Literal["all_genes", "pca_only"] | Sequence[str] | None = None, *, knn: int = 5, - decay: Optional[float] = 1, - knn_max: Optional[int] = None, - t: Union[Literal["auto"], int] = 3, - n_pca: Optional[int] = 100, + decay: float | None = 1, + knn_max: int | None = None, + t: Literal["auto"] | int = 3, + n_pca: int | None = 100, solver: Literal["exact", "approximate"] = "exact", knn_dist: str = "euclidean", random_state: AnyRandom = None, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, verbose: bool = False, - copy: Optional[bool] = None, + copy: bool | None = None, **kwargs, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Markov Affinity-based Graph Imputation of Cells (MAGIC) API [vanDijk18]_. diff --git a/scanpy/external/pp/_mnn_correct.py b/scanpy/external/pp/_mnn_correct.py index d103fd1fcb..ef6015a6c4 100644 --- a/scanpy/external/pp/_mnn_correct.py +++ b/scanpy/external/pp/_mnn_correct.py @@ -1,36 +1,41 @@ -from typing import Union, Collection, Optional, Any, Sequence, Tuple, List, Literal +from __future__ import annotations -import numpy as np -import pandas as pd -from anndata import AnnData +from typing import TYPE_CHECKING, Any, Literal from ..._settings import settings +if TYPE_CHECKING: + from collections.abc import Collection, Sequence + + import numpy as np + import pandas as pd + from anndata import AnnData + def mnn_correct( - *datas: Union[AnnData, np.ndarray], - var_index: Optional[Collection[str]] = None, - var_subset: Optional[Collection[str]] = None, + *datas: AnnData | np.ndarray, + var_index: Collection[str] | None = None, + var_subset: Collection[str] | None = None, batch_key: str = "batch", index_unique: str = "-", - batch_categories: Optional[Collection[Any]] = None, + batch_categories: Collection[Any] | None = None, k: int = 20, sigma: float = 1.0, cos_norm_in: bool = True, cos_norm_out: bool = True, - svd_dim: Optional[int] = None, + svd_dim: int | None = None, var_adj: bool = True, compute_angle: bool = False, - mnn_order: Optional[Sequence[int]] = None, + mnn_order: Sequence[int] | None = None, svd_mode: Literal["svd", "rsvd", "irlb"] = "rsvd", do_concatenate: bool = True, save_raw: bool = False, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, **kwargs, -) -> Tuple[ - Union[np.ndarray, AnnData], - List[pd.DataFrame], - Optional[List[Tuple[Optional[float], int]]], +) -> tuple[ + np.ndarray | AnnData, + list[pd.DataFrame], + list[tuple[float | None, int]] | None, ]: """\ Correct batch effects by matching mutual nearest neighbors [Haghverdi18]_ [Kang18]_. @@ -123,8 +128,8 @@ def mnn_correct( return datas, [], [] try: - from mnnpy import mnn_correct import mnnpy + from mnnpy import mnn_correct except ImportError: raise ImportError( "Please install the package mnnpy " diff --git a/scanpy/external/pp/_scanorama_integrate.py b/scanpy/external/pp/_scanorama_integrate.py index 8894dba378..4f57db0934 100644 --- a/scanpy/external/pp/_scanorama_integrate.py +++ b/scanpy/external/pp/_scanorama_integrate.py @@ -1,13 +1,17 @@ """ Use Scanorama to integrate cells from different experiments. """ +from __future__ import annotations -from anndata import AnnData -import numpy as np +from typing import TYPE_CHECKING +import numpy as np from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + @doctest_needs("scanorama") def scanorama_integrate( @@ -44,7 +48,7 @@ def scanorama_integrate( basis The name of the field in ``adata.obsm`` where the PCA table is stored. Defaults to ``'X_pca'``, which is the default for - ``sc.tl.pca()``. + ``sc.pp.pca()``. adjusted_basis The name of the field in ``adata.obsm`` where the integrated embeddings will be stored after running this function. Defaults @@ -80,7 +84,7 @@ def scanorama_integrate( >>> import scanpy.external as sce >>> adata = sc.datasets.pbmc3k() >>> sc.pp.recipe_zheng17(adata) - >>> sc.tl.pca(adata) + >>> sc.pp.pca(adata) We now arbitrarily assign a batch metadata variable to each cell for the sake of example, but during real usage there would already diff --git a/scanpy/external/pp/_scrublet.py b/scanpy/external/pp/_scrublet.py index 8a41bcf909..3cf0c7e158 100644 --- a/scanpy/external/pp/_scrublet.py +++ b/scanpy/external/pp/_scrublet.py @@ -1,10 +1,10 @@ -from anndata import AnnData -from typing import Optional +from __future__ import annotations + import numpy as np import pandas as pd +from anndata import AnnData from scipy import sparse - from ... import logging as logg from ... import preprocessing as pp from ...get import _get_obs_rep @@ -12,7 +12,7 @@ def scrublet( adata: AnnData, - adata_sim: Optional[AnnData] = None, + adata_sim: AnnData | None = None, batch_key: str = None, sim_doublet_ratio: float = 2.0, expected_doublet_rate: float = 0.05, @@ -25,12 +25,12 @@ def scrublet( n_prin_comps: int = 30, use_approx_neighbors: bool = True, get_doublet_neighbor_parents: bool = False, - n_neighbors: Optional[int] = None, - threshold: Optional[float] = None, + n_neighbors: int | None = None, + threshold: float | None = None, verbose: bool = True, copy: bool = False, random_state: int = 0, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Predict doublets using Scrublet [Wolock19]_. @@ -152,7 +152,7 @@ def scrublet( scores for observed transcriptomes and simulated doublets. """ try: - import scrublet as sl + import scrublet as sl # noqa: F401 except ImportError: raise ImportError( "Please install scrublet: `pip install scrublet` or `conda install scrublet`." @@ -279,7 +279,7 @@ def _run_scrublet(ad_obs, ad_sim=None): def _scrublet_call_doublets( adata_obs: AnnData, adata_sim: AnnData, - n_neighbors: Optional[int] = None, + n_neighbors: int | None = None, expected_doublet_rate: float = 0.05, stdev_doublet_rate: float = 0.02, mean_center: bool = True, @@ -288,7 +288,7 @@ def _scrublet_call_doublets( use_approx_neighbors: bool = True, knn_dist_metric: str = "euclidean", get_doublet_neighbor_parents: bool = False, - threshold: Optional[float] = None, + threshold: float | None = None, random_state: int = 0, verbose: bool = True, ) -> AnnData: diff --git a/scanpy/external/tl/__init__.py b/scanpy/external/tl/__init__.py index fc0c6c1be6..115b309a24 100644 --- a/scanpy/external/tl/__init__.py +++ b/scanpy/external/tl/__init__.py @@ -1,8 +1,23 @@ -from ._pypairs import cyclone, sandbag +from __future__ import annotations + +from ._harmony_timeseries import harmony_timeseries +from ._palantir import palantir, palantir_results from ._phate import phate from ._phenograph import phenograph -from ._palantir import palantir, palantir_results -from ._trimap import trimap -from ._harmony_timeseries import harmony_timeseries +from ._pypairs import cyclone, sandbag from ._sam import sam +from ._trimap import trimap from ._wishbone import wishbone + +__all__ = [ + "harmony_timeseries", + "palantir", + "palantir_results", + "phate", + "phenograph", + "cyclone", + "sandbag", + "sam", + "trimap", + "wishbone", +] diff --git a/scanpy/external/tl/_harmony_timeseries.py b/scanpy/external/tl/_harmony_timeseries.py index 26e2c53841..9c873868b7 100644 --- a/scanpy/external/tl/_harmony_timeseries.py +++ b/scanpy/external/tl/_harmony_timeseries.py @@ -2,26 +2,29 @@ Harmony time series for data visualization with augmented affinity matrix at discrete time points """ +from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from anndata import AnnData from ... import logging as logg from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + @doctest_needs("harmony") def harmony_timeseries( adata: AnnData, tp: str, n_neighbors: int = 30, - n_components: Optional[int] = 1000, + n_components: int | None = 1000, n_jobs: int = -2, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Harmony time series for data visualization with augmented affinity matrix at discrete time points [Nowotschin18i]_. diff --git a/scanpy/external/tl/_palantir.py b/scanpy/external/tl/_palantir.py index 968c625ec5..5510c0c287 100644 --- a/scanpy/external/tl/_palantir.py +++ b/scanpy/external/tl/_palantir.py @@ -1,14 +1,18 @@ """\ Run Diffusion maps using the adaptive anisotropic kernel """ -from typing import Optional, List +from __future__ import annotations + +from typing import TYPE_CHECKING import pandas as pd -from anndata import AnnData from ... import logging as logg from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + @doctest_needs("palantir") def palantir( @@ -17,12 +21,12 @@ def palantir( knn: int = 30, alpha: float = 0, use_adjacency_matrix: bool = False, - distances_key: Optional[str] = None, + distances_key: str | None = None, n_eigs: int = None, impute_data: bool = True, n_steps: int = 3, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Run Diffusion maps using the adaptive anisotropic kernel [Setty18]_. @@ -116,7 +120,7 @@ def palantir( *Principal component analysis* - >>> sc.tl.pca(adata, n_comps=300) + >>> sc.pp.pca(adata, n_comps=300) or, @@ -193,8 +197,8 @@ def palantir( _check_import() from palantir.utils import ( - run_diffusion_maps, determine_multiscale_space, + run_diffusion_maps, run_magic_imputation, ) @@ -243,14 +247,14 @@ def palantir_results( adata: AnnData, early_cell: str, ms_data: str = "X_palantir_multiscale", - terminal_states: List = None, + terminal_states: list = None, knn: int = 30, num_waypoints: int = 1200, n_jobs: int = -1, scale_components: bool = True, use_early_cell_as_start: bool = False, max_iterations: int = 25, -) -> Optional[AnnData]: +) -> AnnData | None: """\ **Running Palantir** @@ -309,6 +313,6 @@ def palantir_results( def _check_import(): try: - import palantir + import palantir # noqa: F401 except ImportError: raise ImportError("\nplease install palantir:\n\tpip install palantir") diff --git a/scanpy/external/tl/_phate.py b/scanpy/external/tl/_phate.py index f2e204bb1e..a4731cae69 100644 --- a/scanpy/external/tl/_phate.py +++ b/scanpy/external/tl/_phate.py @@ -1,15 +1,19 @@ """\ Embed high-dimensional data using PHATE """ -from typing import Optional, Union, Literal +from __future__ import annotations -from anndata import AnnData +from typing import TYPE_CHECKING, Literal from ... import logging as logg from ..._settings import settings -from ..._utils import AnyRandom from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + + from ..._utils import AnyRandom + @doctest_needs("phate") def phate( @@ -18,18 +22,18 @@ def phate( k: int = 5, a: int = 15, n_landmark: int = 2000, - t: Union[int, str] = "auto", + t: int | str = "auto", gamma: float = 1.0, n_pca: int = 100, knn_dist: str = "euclidean", mds_dist: str = "euclidean", mds: Literal["classic", "metric", "nonmetric"] = "metric", - n_jobs: Optional[int] = None, + n_jobs: int | None = None, random_state: AnyRandom = None, - verbose: Union[bool, int, None] = None, + verbose: bool | int | None = None, copy: bool = False, **kwargs, -) -> Optional[AnnData]: +) -> AnnData | None: """\ PHATE [Moon17]_. diff --git a/scanpy/external/tl/_phenograph.py b/scanpy/external/tl/_phenograph.py index c635a7906e..809a7dcab3 100644 --- a/scanpy/external/tl/_phenograph.py +++ b/scanpy/external/tl/_phenograph.py @@ -1,22 +1,27 @@ """\ Perform clustering using PhenoGraph """ -from typing import Union, Tuple, Optional, Type, Any, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal -import numpy as np import pandas as pd from anndata import AnnData -from scipy.sparse import spmatrix from ... import logging as logg -from ...tools._leiden import MutableVertexPartition from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + import numpy as np + from scipy.sparse import spmatrix + + from ...tools._leiden import MutableVertexPartition + @doctest_needs("phenograph") def phenograph( - adata: Union[AnnData, np.ndarray, spmatrix], - clustering_algo: Optional[Literal["louvain", "leiden"]] = "louvain", + adata: AnnData | np.ndarray | spmatrix, + clustering_algo: Literal["louvain", "leiden"] | None = "louvain", k: int = 30, directed: bool = False, prune: bool = False, @@ -32,14 +37,14 @@ def phenograph( q_tol: float = 1e-3, louvain_time_limit: int = 2000, nn_method: Literal["kdtree", "brute"] = "kdtree", - partition_type: Optional[Type[MutableVertexPartition]] = None, + partition_type: type[MutableVertexPartition] | None = None, resolution_parameter: float = 1, n_iterations: int = -1, use_weights: bool = True, - seed: Optional[int] = None, + seed: int | None = None, copy: bool = False, **kargs: Any, -) -> Tuple[Optional[np.ndarray], spmatrix, Optional[float]]: +) -> tuple[np.ndarray | None, spmatrix, float | None]: """\ PhenoGraph clustering [Levine15]_. @@ -146,7 +151,7 @@ def phenograph( Then do PCA: - >>> sc.tl.pca(adata, n_comps=100) + >>> sc.pp.pca(adata, n_comps=100) Compute phenograph clusters: @@ -181,7 +186,7 @@ def phenograph( >>> dframe = pd.DataFrame(df) >>> dframe.index, dframe.columns = (map(str, dframe.index), map(str, dframe.columns)) >>> adata = AnnData(dframe) - >>> sc.tl.pca(adata, n_comps=20) + >>> sc.pp.pca(adata, n_comps=20) >>> sce.tl.phenograph(adata, clustering_algo="leiden", k=50) >>> sc.tl.tsne(adata, random_state=1) >>> sc.pl.tsne( @@ -205,15 +210,13 @@ def phenograph( try: data = adata.obsm["X_pca"] except KeyError: - raise KeyError("Please run `sc.tl.pca` on `adata` and try again!") + raise KeyError("Please run `sc.pp.pca` on `adata` and try again!") else: data = adata copy = True comm_key = ( - "pheno_{}".format(clustering_algo) - if clustering_algo in ["louvain", "leiden"] - else "" + f"pheno_{clustering_algo}" if clustering_algo in ["louvain", "leiden"] else "" ) ig_key = "pheno_{}_ig".format("jaccard" if jaccard else "gaussian") q_key = "pheno_{}_q".format("jaccard" if jaccard else "gaussian") diff --git a/scanpy/external/tl/_pypairs.py b/scanpy/external/tl/_pypairs.py index 5a5ba37f87..9241ea86d7 100644 --- a/scanpy/external/tl/_pypairs.py +++ b/scanpy/external/tl/_pypairs.py @@ -1,15 +1,19 @@ """\ Calculate scores based on relative expression change of maker pairs """ -from typing import Mapping, Optional, Collection, Union, Tuple, List, Dict +from __future__ import annotations + +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Union -import pandas as pd -from anndata import AnnData from packaging import version from ..._settings import settings from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + import pandas as pd + from anndata import AnnData Genes = Collection[Union[str, int, bool]] @@ -17,12 +21,12 @@ @doctest_needs("pypairs") def sandbag( adata: AnnData, - annotation: Optional[Mapping[str, Genes]] = None, + annotation: Mapping[str, Genes] | None = None, *, fraction: float = 0.65, - filter_genes: Optional[Genes] = None, - filter_samples: Optional[Genes] = None, -) -> Dict[str, List[Tuple[str, str]]]: + filter_genes: Genes | None = None, + filter_samples: Genes | None = None, +) -> dict[str, list[tuple[str, str]]]: """\ Calculate marker pairs of genes. [Scialdone15]_ [Fechtner18]_. @@ -62,8 +66,8 @@ def sandbag( >>> marker_pairs = sandbag(adata, fraction=0.5) """ _check_import() - from pypairs.pairs import sandbag from pypairs import settings as pp_settings + from pypairs.pairs import sandbag pp_settings.verbosity = settings.verbosity pp_settings.n_jobs = settings.n_jobs @@ -82,7 +86,7 @@ def sandbag( def cyclone( adata: AnnData, - marker_pairs: Optional[Mapping[str, Collection[Tuple[str, str]]]] = None, + marker_pairs: Mapping[str, Collection[tuple[str, str]]] | None = None, *, iterations: int = 1000, min_iter: int = 100, @@ -125,8 +129,8 @@ def cyclone( Where category S is assigned to samples where G1 and G2M score are < 0.5. """ _check_import() - from pypairs.pairs import cyclone from pypairs import settings as pp_settings + from pypairs.pairs import cyclone pp_settings.verbosity = settings.verbosity pp_settings.n_jobs = settings.n_jobs diff --git a/scanpy/external/tl/_sam.py b/scanpy/external/tl/_sam.py index b3db2b6333..f74da3af9a 100644 --- a/scanpy/external/tl/_sam.py +++ b/scanpy/external/tl/_sam.py @@ -2,16 +2,14 @@ Run the Self-Assembling Manifold algorithm """ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, Tuple, Any, Literal -from anndata import AnnData +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: + from anndata import AnnData from samalg import SAM from ... import logging as logg - - from ...testing._doctests import doctest_needs @@ -25,12 +23,12 @@ def sam( standardization: Literal["Normalizer", "StandardScaler", "None"] = "StandardScaler", weight_pcs: bool = False, sparse_pca: bool = False, - n_pcs: Optional[int] = 150, - n_genes: Optional[int] = 3000, + n_pcs: int | None = 150, + n_genes: int | None = 3000, projection: Literal["umap", "tsne", "None"] = "umap", inplace: bool = True, verbose: bool = True, -) -> Union[SAM, Tuple[SAM, AnnData]]: +) -> SAM | tuple[SAM, AnnData]: """\ Self-Assembling Manifolds single-cell RNA sequencing analysis tool [Tarashansky19]_. diff --git a/scanpy/external/tl/_trimap.py b/scanpy/external/tl/_trimap.py index 6f36ca2451..d9a01a7921 100644 --- a/scanpy/external/tl/_trimap.py +++ b/scanpy/external/tl/_trimap.py @@ -1,15 +1,19 @@ """\ Embed high-dimensional data using TriMap """ -from typing import Optional, Union, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal -from anndata import AnnData import scipy.sparse as scp from ... import logging as logg from ..._settings import settings from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + @doctest_needs("trimap") def trimap( @@ -22,9 +26,9 @@ def trimap( weight_adj: float = 500.0, lr: float = 1000.0, n_iters: int = 400, - verbose: Union[bool, int, None] = None, + verbose: bool | int | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ TriMap: Large-scale Dimensionality Reduction Using Triplets [Amid19]_. diff --git a/scanpy/external/tl/_wishbone.py b/scanpy/external/tl/_wishbone.py index 6a0d720115..4f9f7d7f61 100644 --- a/scanpy/external/tl/_wishbone.py +++ b/scanpy/external/tl/_wishbone.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import collections.abc as cabc -from typing import Iterable, Collection, Union +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from anndata import AnnData from ... import logging from ...testing._doctests import doctest_needs +if TYPE_CHECKING: + from collections.abc import Collection, Iterable + + from anndata import AnnData + @doctest_needs("wishbone") def wishbone( @@ -16,7 +22,7 @@ def wishbone( branch: bool = True, k: int = 15, components: Iterable[int] = (1, 2, 3), - num_waypoints: Union[int, Collection] = 250, + num_waypoints: int | Collection = 250, ): """\ Wishbone identifies bifurcating developmental trajectories from single-cell data diff --git a/scanpy/get/__init__.py b/scanpy/get/__init__.py index 7ec81eb0cf..08567cfc6e 100644 --- a/scanpy/get/__init__.py +++ b/scanpy/get/__init__.py @@ -1,5 +1,21 @@ # Public -from .get import rank_genes_groups_df, obs_df, var_df - # Private -from .get import _get_obs_rep, _set_obs_rep, _check_mask +from __future__ import annotations + +from .get import ( + _check_mask, + _get_obs_rep, + _set_obs_rep, + obs_df, + rank_genes_groups_df, + var_df, +) + +__all__ = [ + "_check_mask", + "_get_obs_rep", + "_set_obs_rep", + "obs_df", + "rank_genes_groups_df", + "var_df", +] diff --git a/scanpy/get/get.py b/scanpy/get/get.py index 1b79ab7953..cb8b262dfb 100644 --- a/scanpy/get/get.py +++ b/scanpy/get/get.py @@ -1,14 +1,17 @@ """This module contains helper functions for accessing data.""" from __future__ import annotations -from typing import Optional, Iterable, Tuple, Union, List, Literal +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd -from numpy.typing import NDArray +from anndata import AnnData from scipy.sparse import spmatrix -from anndata import AnnData +if TYPE_CHECKING: + from collections.abc import Iterable + + from numpy.typing import NDArray # -------------------------------------------------------------------------------- # Plotting data helpers @@ -18,13 +21,13 @@ # TODO: implement diffxpy method, make singledispatch def rank_genes_groups_df( adata: AnnData, - group: Union[str, Iterable[str]], + group: str | Iterable[str], *, key: str = "rank_genes_groups", - pval_cutoff: Optional[float] = None, - log2fc_min: Optional[float] = None, - log2fc_max: Optional[float] = None, - gene_symbols: Optional[str] = None, + pval_cutoff: float | None = None, + log2fc_min: float | None = None, + log2fc_max: float | None = None, + gene_symbols: str | None = None, ) -> pd.DataFrame: """\ :func:`scanpy.tl.rank_genes_groups` results in the form of a @@ -104,10 +107,10 @@ def _check_indices( dim_df: pd.DataFrame, alt_index: pd.Index, dim: Literal["obs", "var"], - keys: List[str], - alias_index: Optional[pd.Index] = None, + keys: list[str], + alias_index: pd.Index | None = None, use_raw: bool = False, -) -> Tuple[List[str], List[str], List[str]]: +) -> tuple[list[str], list[str], list[str]]: """Common logic for checking indices for obs_df and var_df.""" if use_raw: alt_repr = "adata.raw" @@ -182,7 +185,7 @@ def _check_indices( def _get_array_values( X, dim_names: pd.Index, - keys: List[str], + keys: list[str], axis: Literal[0, 1], backed: bool, ): @@ -212,7 +215,7 @@ def _get_array_values( def obs_df( adata: AnnData, keys: Iterable[str] = (), - obsm_keys: Iterable[Tuple[str, int]] = (), + obsm_keys: Iterable[tuple[str, int]] = (), *, layer: str = None, gene_symbols: str = None, @@ -330,7 +333,7 @@ def obs_df( def var_df( adata: AnnData, keys: Iterable[str] = (), - varm_keys: Iterable[Tuple[str, int]] = (), + varm_keys: Iterable[tuple[str, int]] = (), *, layer: str = None, ) -> pd.DataFrame: diff --git a/scanpy/logging.py b/scanpy/logging.py index c1eef93cdf..9f836ec977 100644 --- a/scanpy/logging.py +++ b/scanpy/logging.py @@ -4,11 +4,11 @@ import logging import sys -from functools import update_wrapper, partial -from logging import CRITICAL, ERROR, WARNING, INFO, DEBUG -from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Optional, IO import warnings +from datetime import datetime, timedelta, timezone +from functools import partial, update_wrapper +from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING +from typing import IO, TYPE_CHECKING import anndata.logging @@ -31,9 +31,9 @@ def log( level: int, msg: str, *, - extra: Optional[dict] = None, + extra: dict | None = None, time: datetime = None, - deep: Optional[str] = None, + deep: str | None = None, ) -> datetime: from ._settings import settings @@ -163,7 +163,7 @@ def print_header(*, file=None): ) -def print_versions(*, file: Optional[IO[str]] = None): +def print_versions(*, file: IO[str] | None = None): """\ Print versions of imported packages, OS, and jupyter environment. @@ -218,8 +218,8 @@ def error( msg: str, *, time: datetime = None, - deep: Optional[str] = None, - extra: Optional[dict] = None, + deep: str | None = None, + extra: dict | None = None, ) -> datetime: """\ Log message with specific level and return current time. diff --git a/scanpy/metrics/__init__.py b/scanpy/metrics/__init__.py index 3c7524aa69..526ac56a80 100644 --- a/scanpy/metrics/__init__.py +++ b/scanpy/metrics/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations + from ._gearys_c import gearys_c from ._metrics import confusion_matrix from ._morans_i import morans_i + +__all__ = ["gearys_c", "morans_i", "confusion_matrix"] diff --git a/scanpy/metrics/_common.py b/scanpy/metrics/_common.py index cae256a66c..5f2d37183b 100644 --- a/scanpy/metrics/_common.py +++ b/scanpy/metrics/_common.py @@ -1,16 +1,18 @@ from __future__ import annotations -from functools import singledispatch -from typing import TypeVar import warnings +from functools import singledispatch +from typing import TYPE_CHECKING, TypeVar import numpy as np import pandas as pd -from numpy.typing import NDArray from scipy import sparse from .._compat import DaskArray +if TYPE_CHECKING: + from numpy.typing import NDArray + @singledispatch def _resolve_vals(val: NDArray | sparse.spmatrix) -> NDArray | sparse.csr_matrix: diff --git a/scanpy/metrics/_gearys_c.py b/scanpy/metrics/_gearys_c.py index 8d40760712..7027754f3a 100644 --- a/scanpy/metrics/_gearys_c.py +++ b/scanpy/metrics/_gearys_c.py @@ -1,29 +1,31 @@ from __future__ import annotations from functools import singledispatch -from typing import Optional, Union +from typing import TYPE_CHECKING -from anndata import AnnData import numba import numpy as np from scipy import sparse -from ..get import _get_obs_rep from .._compat import fullname -from ._common import _resolve_vals, _check_vals +from ..get import _get_obs_rep +from ._common import _check_vals, _resolve_vals + +if TYPE_CHECKING: + from anndata import AnnData @singledispatch def gearys_c( adata: AnnData, *, - vals: Optional[Union[np.ndarray, sparse.spmatrix]] = None, - use_graph: Optional[str] = None, - layer: Optional[str] = None, - obsm: Optional[str] = None, - obsp: Optional[str] = None, + vals: np.ndarray | sparse.spmatrix | None = None, + use_graph: str | None = None, + layer: str | None = None, + obsm: str | None = None, + obsp: str | None = None, use_raw: bool = False, -) -> Union[np.ndarray, float]: +) -> np.ndarray | float: r""" Calculate `Geary's C `_, as used by `VISION `_. diff --git a/scanpy/metrics/_metrics.py b/scanpy/metrics/_metrics.py index e001df7194..c75e33a3c2 100644 --- a/scanpy/metrics/_metrics.py +++ b/scanpy/metrics/_metrics.py @@ -1,18 +1,23 @@ """ Metrics which don't quite deserve their own file. """ -from typing import Optional, Sequence, Union +from __future__ import annotations +from typing import TYPE_CHECKING + +import numpy as np import pandas as pd -from pandas.api.types import CategoricalDtype from natsort import natsorted -import numpy as np +from pandas.api.types import CategoricalDtype + +if TYPE_CHECKING: + from collections.abc import Sequence def confusion_matrix( - orig: Union[pd.Series, np.ndarray, Sequence], - new: Union[pd.Series, np.ndarray, Sequence], - data: Optional[pd.DataFrame] = None, + orig: pd.Series | np.ndarray | Sequence, + new: pd.Series | np.ndarray | Sequence, + data: pd.DataFrame | None = None, *, normalize: bool = True, ) -> pd.DataFrame: diff --git a/scanpy/metrics/_morans_i.py b/scanpy/metrics/_morans_i.py index eecbefdae7..a069b44bb3 100644 --- a/scanpy/metrics/_morans_i.py +++ b/scanpy/metrics/_morans_i.py @@ -2,29 +2,31 @@ from __future__ import annotations from functools import singledispatch -from typing import Union, Optional +from typing import TYPE_CHECKING -from anndata import AnnData import numpy as np -from scipy import sparse from numba import njit, prange +from scipy import sparse -from ..get import _get_obs_rep from .._compat import fullname -from ._common import _resolve_vals, _check_vals +from ..get import _get_obs_rep +from ._common import _check_vals, _resolve_vals + +if TYPE_CHECKING: + from anndata import AnnData @singledispatch def morans_i( adata: AnnData, *, - vals: Optional[Union[np.ndarray, sparse.spmatrix]] = None, - use_graph: Optional[str] = None, - layer: Optional[str] = None, - obsm: Optional[str] = None, - obsp: Optional[str] = None, + vals: np.ndarray | sparse.spmatrix | None = None, + use_graph: str | None = None, + layer: str | None = None, + obsm: str | None = None, + obsp: str | None = None, use_raw: bool = False, -) -> Union[np.ndarray, float]: +) -> np.ndarray | float: r""" Calculate Moran’s I Global Autocorrelation Statistic. diff --git a/scanpy/neighbors/__init__.py b/scanpy/neighbors/__init__.py index d6322e478c..fbda62718e 100644 --- a/scanpy/neighbors/__init__.py +++ b/scanpy/neighbors/__init__.py @@ -1,41 +1,33 @@ from __future__ import annotations -from types import MappingProxyType -from typing import ( - TYPE_CHECKING, - TypedDict, - Union, - Optional, - Any, - NamedTuple, - Literal, - get_args, -) -from collections.abc import Mapping, MutableMapping, Callable +from collections.abc import Callable, Mapping, MutableMapping from textwrap import indent +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict, get_args from warnings import warn import numpy as np import scipy -from anndata import AnnData -from scipy.sparse import issparse, csr_matrix +from scipy.sparse import csr_matrix, issparse from sklearn.utils import check_random_state if TYPE_CHECKING: + from anndata import AnnData from igraph import Graph + from ._types import KnnTransformerLike +from .. import _utils +from .. import logging as logg +from .._settings import settings +from .._utils import AnyRandom, NeighborsView, _doc_params from . import _connectivity -from ._types import _Metric, _MetricFn, _Method, _KnownTransformer from ._common import ( - _has_self_column, _get_indices_distances_from_sparse_matrix, _get_sparse_matrix_from_indices_distances, ) -from .. import logging as logg -from .. import _utils, settings -from .._utils import _doc_params, AnyRandom, NeighborsView -from ..tools._utils import _choose_representation, doc_use_rep, doc_n_pcs +from ._doc import doc_n_pcs, doc_use_rep +from ._types import _KnownTransformer, _Method, _Metric, _MetricFn RPForestDict = Mapping[str, Mapping[str, np.ndarray]] @@ -53,7 +45,7 @@ class KwdsForTransformer(TypedDict): """ n_neighbors: int - metric: Union[_Metric, _MetricFn] + metric: _Metric | _MetricFn metric_params: Mapping[str, Any] random_state: AnyRandom @@ -62,18 +54,18 @@ class KwdsForTransformer(TypedDict): def neighbors( adata: AnnData, n_neighbors: int = 15, - n_pcs: Optional[int] = None, + n_pcs: int | None = None, *, - use_rep: Optional[str] = None, + use_rep: str | None = None, knn: bool = True, method: _Method = "umap", transformer: KnnTransformerLike | _KnownTransformer | None = None, - metric: Union[_Metric, _MetricFn] = "euclidean", + metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), random_state: AnyRandom = 0, - key_added: Optional[str] = None, + key_added: str | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Computes the nearest neighbors distance matrix and a neighborhood graph of observations [McInnes18]_. @@ -286,8 +278,8 @@ def __init__( shape: tuple[int, int], DC_start: int = 0, DC_end: int = -1, - rows: Optional[MutableMapping[Any, np.ndarray]] = None, - restrict_array: Optional[np.ndarray] = None, + rows: MutableMapping[Any, np.ndarray] | None = None, + restrict_array: np.ndarray | None = None, ): self.get_row = get_row self.shape = shape @@ -353,19 +345,19 @@ class Neighbors: def __init__( self, adata: AnnData, - n_dcs: Optional[int] = None, - neighbors_key: Optional[str] = None, + n_dcs: int | None = None, + neighbors_key: str | None = None, ): self._adata = adata self._init_iroot() # use the graph in adata info_str = "" - self.knn: Optional[bool] = None - self._distances: Union[np.ndarray, csr_matrix, None] = None - self._connectivities: Union[np.ndarray, csr_matrix, None] = None - self._transitions_sym: Union[np.ndarray, csr_matrix, None] = None - self._number_connected_components: Optional[int] = None - self._rp_forest: Optional[RPForestDict] = None + self.knn: bool | None = None + self._distances: np.ndarray | csr_matrix | None = None + self._connectivities: np.ndarray | csr_matrix | None = None + self._transitions_sym: np.ndarray | csr_matrix | None = None + self._number_connected_components: int | None = None + self._rp_forest: RPForestDict | None = None if neighbors_key is None: neighbors_key = "neighbors" if neighbors_key in adata.uns: @@ -382,7 +374,7 @@ def __init__( self.n_neighbors = neighbors["params"]["n_neighbors"] else: - def count_nonzero(a: Union[np.ndarray, csr_matrix]) -> int: + def count_nonzero(a: np.ndarray | csr_matrix) -> int: return a.count_nonzero() if issparse(a) else np.count_nonzero(a) # estimating n_neighbors @@ -409,10 +401,8 @@ def count_nonzero(a: Union[np.ndarray, csr_matrix]) -> int: if n_dcs is not None: if n_dcs > len(self._eigen_values): raise ValueError( - "Cannot instantiate using `n_dcs`={}. " - "Compute diffmap/spectrum with more components first.".format( - n_dcs - ) + f"Cannot instantiate using `n_dcs`={n_dcs}. " + "Compute diffmap/spectrum with more components first." ) self._eigen_values = self._eigen_values[:n_dcs] self._eigen_basis = self._eigen_basis[:, :n_dcs] @@ -426,21 +416,21 @@ def count_nonzero(a: Union[np.ndarray, csr_matrix]) -> int: logg.debug(f" initialized {info_str}") @property - def rp_forest(self) -> Optional[RPForestDict]: + def rp_forest(self) -> RPForestDict | None: return self._rp_forest @property - def distances(self) -> Union[np.ndarray, csr_matrix, None]: + def distances(self) -> np.ndarray | csr_matrix | None: """Distances between data points (sparse matrix).""" return self._distances @property - def connectivities(self) -> Union[np.ndarray, csr_matrix, None]: + def connectivities(self) -> np.ndarray | csr_matrix | None: """Connectivities between data points (sparse matrix).""" return self._connectivities @property - def transitions(self) -> Union[np.ndarray, csr_matrix]: + def transitions(self) -> np.ndarray | csr_matrix: """Transition matrix (sparse matrix). Is conjugate to the symmetrized transition matrix via:: @@ -461,7 +451,7 @@ def transitions(self) -> Union[np.ndarray, csr_matrix]: return self.Z @ self.transitions_sym @ Zinv @property - def transitions_sym(self) -> Union[np.ndarray, csr_matrix, None]: + def transitions_sym(self) -> np.ndarray | csr_matrix | None: """Symmetrized transition matrix (sparse matrix). Is conjugate to the transition matrix via:: @@ -501,13 +491,13 @@ def to_igraph(self) -> Graph: def compute_neighbors( self, n_neighbors: int = 30, - n_pcs: Optional[int] = None, + n_pcs: int | None = None, *, - use_rep: Optional[str] = None, + use_rep: str | None = None, knn: bool = True, method: _Method = "umap", transformer: KnnTransformerLike | _KnownTransformer | None = None, - metric: Union[_Metric, _MetricFn] = "euclidean", + metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), random_state: AnyRandom = 0, ) -> None: @@ -527,6 +517,8 @@ def compute_neighbors( ------- Writes sparse graph attributes `.distances` and `.connectivities`. """ + from ..tools._utils import _choose_representation + start_neighbors = logg.debug("computing neighbors") if transformer is not None and not isinstance(transformer, str): n_neighbors = transformer.get_params()["n_neighbors"] @@ -731,7 +723,7 @@ def compute_transitions(self, density_normalize: bool = True): def compute_eigen( self, n_comps: int = 15, - sym: Optional[bool] = None, + sym: bool | None = None, sort: Literal["decrease", "increase"] = "decrease", random_state: AnyRandom = 0, ): diff --git a/scanpy/neighbors/_backends/_common.py b/scanpy/neighbors/_backends/_common.py index 2be2bbf2bb..f9c73cbbfe 100644 --- a/scanpy/neighbors/_backends/_common.py +++ b/scanpy/neighbors/_backends/_common.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + class TransformerChecksMixin: def _transform_checks(self, X, *fitted_props, **check_params): from sklearn.utils.validation import check_is_fitted diff --git a/scanpy/neighbors/_backends/rapids.py b/scanpy/neighbors/_backends/rapids.py index 3be0ea7b98..78a6bb7359 100644 --- a/scanpy/neighbors/_backends/rapids.py +++ b/scanpy/neighbors/_backends/rapids.py @@ -1,17 +1,20 @@ from __future__ import annotations -from typing import Any, Literal -from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Literal import numpy as np -from numpy.typing import ArrayLike -from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator, TransformerMixin, check_is_fitted from sklearn.exceptions import NotFittedError from ..._settings import settings from ._common import TransformerChecksMixin +if TYPE_CHECKING: + from collections.abc import Mapping + + from numpy.typing import ArrayLike + from scipy.sparse import csr_matrix + _Algorithm = Literal["rbc", "brute", "ivfflat", "ivfpq"] _Metric = Literal[ "l1", diff --git a/scanpy/neighbors/_common.py b/scanpy/neighbors/_common.py index ce1b45d0ce..18c2eb2195 100644 --- a/scanpy/neighbors/_common.py +++ b/scanpy/neighbors/_common.py @@ -1,13 +1,16 @@ from __future__ import annotations -from math import dist + +from typing import TYPE_CHECKING from warnings import warn import numpy as np -from numpy.typing import NDArray from scipy.sparse import csr_matrix from scanpy._utils.compute.is_constant import is_constant +if TYPE_CHECKING: + from numpy.typing import NDArray + def _has_self_column( indices: NDArray[np.int32 | np.int64], diff --git a/scanpy/neighbors/_connectivity.py b/scanpy/neighbors/_connectivity.py index 9858a422f2..f6788d33d2 100644 --- a/scanpy/neighbors/_connectivity.py +++ b/scanpy/neighbors/_connectivity.py @@ -5,14 +5,13 @@ import numpy as np from numpy.typing import NDArray -from scipy.sparse import issparse, csr_matrix, coo_matrix +from scipy.sparse import coo_matrix, csr_matrix, issparse from ._common import ( _get_indices_distances_from_dense_matrix, _get_indices_distances_from_sparse_matrix, ) - D = TypeVar("D", NDArray[np.float32], csr_matrix) diff --git a/scanpy/neighbors/_doc.py b/scanpy/neighbors/_doc.py new file mode 100644 index 0000000000..28bcddca02 --- /dev/null +++ b/scanpy/neighbors/_doc.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +doc_use_rep = """\ +use_rep + Use the indicated representation. `'X'` or any key for `.obsm` is valid. + If `None`, the representation is chosen automatically: + For `.n_vars` < :attr:`~scanpy._settings.ScanpyConfig.N_PCS` (default: 50), `.X` is used, otherwise 'X_pca' is used. + If 'X_pca' is not present, it’s computed with default parameters or `n_pcs` if present.\ +""" + +doc_n_pcs = """\ +n_pcs + Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.\ +""" diff --git a/scanpy/neighbors/_types.py b/scanpy/neighbors/_types.py index 152eee6ce2..8d1c50dc73 100644 --- a/scanpy/neighbors/_types.py +++ b/scanpy/neighbors/_types.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Protocol, Union as _U, Callable as _C +from typing import TYPE_CHECKING, Any, Literal, Protocol +from typing import Callable as _C +from typing import Union as _U import numpy as np -from scipy.sparse import spmatrix if TYPE_CHECKING: from typing import Self + from scipy.sparse import spmatrix + _Method = Literal["umap", "gauss"] diff --git a/scanpy/plotting/__init__.py b/scanpy/plotting/__init__.py index 74a32de43e..0116771407 100644 --- a/scanpy/plotting/__init__.py +++ b/scanpy/plotting/__init__.py @@ -1,45 +1,100 @@ +from __future__ import annotations + +from . import palettes from ._anndata import ( - scatter, - violin, - ranking, clustermap, - tracksplot, - dendrogram, correlation_matrix, + dendrogram, heatmap, + ranking, + scatter, + tracksplot, + violin, ) from ._dotplot import DotPlot, dotplot from ._matrixplot import MatrixPlot, matrixplot -from ._stacked_violin import StackedViolin, stacked_violin from ._preprocessing import filter_genes_dispersion, highly_variable_genes - -from ._tools.scatterplots import ( - embedding, - pca, - diffmap, - draw_graph, - tsne, - umap, - spatial, -) -from ._tools import pca_loadings, pca_scatter, pca_overview, pca_variance_ratio -from ._tools.paga import paga, paga_adjacency, paga_compare, paga_path -from ._tools import dpt_timeseries, dpt_groups_pseudotime -from ._tools import rank_genes_groups, rank_genes_groups_violin +from ._qc import highest_expr_genes +from ._rcmod import set_rcParams_defaults, set_rcParams_scanpy +from ._stacked_violin import StackedViolin, stacked_violin from ._tools import ( + dpt_groups_pseudotime, + dpt_timeseries, + embedding_density, + pca_loadings, + pca_overview, + pca_scatter, + pca_variance_ratio, + rank_genes_groups, rank_genes_groups_dotplot, rank_genes_groups_heatmap, - rank_genes_groups_stacked_violin, rank_genes_groups_matrixplot, + rank_genes_groups_stacked_violin, rank_genes_groups_tracksplot, + rank_genes_groups_violin, + sim, ) -from ._tools import sim -from ._tools import embedding_density - -from ._rcmod import set_rcParams_scanpy, set_rcParams_defaults -from . import palettes - -from ._utils import matrix -from ._utils import timeseries, timeseries_subplot, timeseries_as_heatmap +from ._tools.paga import paga, paga_adjacency, paga_compare, paga_path +from ._tools.scatterplots import ( + diffmap, + draw_graph, + embedding, + pca, + spatial, + tsne, + umap, +) +from ._utils import matrix, timeseries, timeseries_as_heatmap, timeseries_subplot -from ._qc import highest_expr_genes +__all__ = [ + "palettes", + "clustermap", + "correlation_matrix", + "dendrogram", + "heatmap", + "ranking", + "scatter", + "tracksplot", + "violin", + "DotPlot", + "dotplot", + "MatrixPlot", + "matrixplot", + "filter_genes_dispersion", + "highly_variable_genes", + "highest_expr_genes", + "set_rcParams_defaults", + "set_rcParams_scanpy", + "StackedViolin", + "stacked_violin", + "dpt_groups_pseudotime", + "dpt_timeseries", + "embedding_density", + "pca_loadings", + "pca_overview", + "pca_scatter", + "pca_variance_ratio", + "rank_genes_groups", + "rank_genes_groups_dotplot", + "rank_genes_groups_heatmap", + "rank_genes_groups_matrixplot", + "rank_genes_groups_stacked_violin", + "rank_genes_groups_tracksplot", + "rank_genes_groups_violin", + "sim", + "paga", + "paga_adjacency", + "paga_compare", + "paga_path", + "diffmap", + "draw_graph", + "embedding", + "pca", + "spatial", + "tsne", + "umap", + "matrix", + "timeseries", + "timeseries_as_heatmap", + "timeseries_subplot", +] diff --git a/scanpy/plotting/_anndata.py b/scanpy/plotting/_anndata.py index 032dc41588..bdbf111d47 100755 --- a/scanpy/plotting/_anndata.py +++ b/scanpy/plotting/_anndata.py @@ -1,38 +1,46 @@ """Plotting functions for AnnData. """ +from __future__ import annotations + import collections.abc as cabc -from itertools import product from collections import OrderedDict -from typing import Optional, Union, Mapping, Literal # Special -from typing import Sequence, Collection, Iterable # ABCs -from typing import Tuple, List # Classes +from collections.abc import Collection, Iterable, Mapping, Sequence +from itertools import product +from typing import TYPE_CHECKING, Literal, Union import numpy as np import pandas as pd -from anndata import AnnData -from cycler import Cycler -from matplotlib.axes import Axes +from matplotlib import gridspec, patheffects, rcParams +from matplotlib import pyplot as plt +from matplotlib.colors import Colormap, ListedColormap, Normalize, is_color_like from pandas.api.types import CategoricalDtype, is_numeric_dtype from scipy.sparse import issparse -from matplotlib import pyplot as pl -from matplotlib import rcParams -from matplotlib import gridspec -from matplotlib import patheffects -from matplotlib.colors import is_color_like, Colormap, ListedColormap, Normalize from .. import get from .. import logging as logg from .._settings import settings -from .._utils import sanitize_anndata, _doc_params, _check_use_raw +from .._utils import _check_use_raw, _doc_params, sanitize_anndata from . import _utils -from ._utils import scatter_base, scatter_group, setup_axes, check_colornorm -from ._utils import ColorLike, _FontWeight, _FontSize from ._docs import ( + doc_common_plot_args, doc_scatter_basic, doc_show_save_ax, - doc_common_plot_args, doc_vboundnorm, ) +from ._utils import ( + ColorLike, + _FontSize, + _FontWeight, + check_colornorm, + scatter_base, + scatter_group, + setup_axes, +) + +if TYPE_CHECKING: + from anndata import AnnData + from cycler import Cycler + from matplotlib.axes import Axes VALID_LEGENDLOCS = { "none", @@ -60,32 +68,32 @@ @_doc_params(scatter_temp=doc_scatter_basic, show_save_ax=doc_show_save_ax) def scatter( adata: AnnData, - x: Optional[str] = None, - y: Optional[str] = None, - color: Union[str, Collection[str]] = None, - use_raw: Optional[bool] = None, - layers: Union[str, Collection[str]] = None, + x: str | None = None, + y: str | None = None, + color: str | Collection[str] = None, + use_raw: bool | None = None, + layers: str | Collection[str] = None, sort_order: bool = True, - alpha: Optional[float] = None, - basis: Optional[_Basis] = None, - groups: Union[str, Iterable[str]] = None, - components: Union[str, Collection[str]] = None, + alpha: float | None = None, + basis: _Basis | None = None, + groups: str | Iterable[str] = None, + components: str | Collection[str] = None, projection: Literal["2d", "3d"] = "2d", legend_loc: str = "right margin", - legend_fontsize: Union[int, float, _FontSize, None] = None, - legend_fontweight: Union[int, _FontWeight, None] = None, + legend_fontsize: int | float | _FontSize | None = None, + legend_fontweight: int | _FontWeight | None = None, legend_fontoutline: float = None, - color_map: Union[str, Colormap] = None, - palette: Union[Cycler, ListedColormap, ColorLike, Sequence[ColorLike]] = None, - frameon: Optional[bool] = None, - right_margin: Optional[float] = None, - left_margin: Optional[float] = None, - size: Union[int, float, None] = None, - marker: Union[str, Sequence[str]] = ".", - title: Optional[str] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - ax: Optional[Axes] = None, + color_map: str | Colormap = None, + palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] = None, + frameon: bool | None = None, + right_margin: float | None = None, + left_margin: float | None = None, + size: int | float | None = None, + marker: str | Sequence[str] = ".", + title: str | None = None, + show: bool | None = None, + save: str | bool | None = None, + ax: Axes | None = None, ): """\ Scatter plot along observations or variables axes. @@ -185,7 +193,6 @@ def _scatter_obs( ): """See docstring of scatter.""" sanitize_anndata(adata) - from scipy.sparse import issparse use_raw = _check_use_raw(adata, use_raw) @@ -509,7 +516,7 @@ def add_centroid(centroids, name, Y, mask): def ranking( adata: AnnData, attr: Literal["var", "obs", "uns", "varm", "obsm"], - keys: Union[str, Sequence[str]], + keys: str | Sequence[str], dictionary=None, indices=None, labels=None, @@ -562,7 +569,7 @@ def ranking( n_rows, n_cols = 1, n_panels else: n_rows, n_cols = 2, int(n_panels / 2 + 0.5) - _ = pl.figure( + _ = plt.figure( figsize=( n_cols * rcParams["figure.figsize"][0], n_rows * rcParams["figure.figsize"][1], @@ -579,7 +586,7 @@ def ranking( top=1 - (n_rows - 1) * bottom - 0.1 / n_rows, ) for iscore, score in enumerate(scores.T): - pl.subplot(gs[iscore]) + plt.subplot(gs[iscore]) order_scores = np.argsort(score)[::-1] if not include_lowest: indices = order_scores[: n_points + 1] @@ -594,26 +601,26 @@ def ranking( fontsize=8, ) for ig, g in enumerate(indices): - pl.text(ig, score[g], labels[g], **txt_args) + plt.text(ig, score[g], labels[g], **txt_args) if include_lowest: score_mid = (score[g] + score[neg_indices[0]]) / 2 if (len(indices) + len(neg_indices)) < len(order_scores): - pl.text(len(indices), score_mid, "⋮", **txt_args) + plt.text(len(indices), score_mid, "⋮", **txt_args) for ig, g in enumerate(neg_indices): - pl.text(ig + len(indices) + 2, score[g], labels[g], **txt_args) + plt.text(ig + len(indices) + 2, score[g], labels[g], **txt_args) else: for ig, g in enumerate(neg_indices): - pl.text(ig + len(indices), score[g], labels[g], **txt_args) - pl.xticks([]) - pl.title(keys[iscore].replace("_", " ")) + plt.text(ig + len(indices), score[g], labels[g], **txt_args) + plt.xticks([]) + plt.title(keys[iscore].replace("_", " ")) if n_panels <= 5 or iscore > n_cols: - pl.xlabel("ranking") - pl.xlim(-0.9, n_points + 0.9 + (1 if include_lowest else 0)) + plt.xlabel("ranking") + plt.xlim(-0.9, n_points + 0.9 + (1 if include_lowest else 0)) score_min, score_max = ( np.min(score[neg_indices if include_lowest else indices]), np.max(score[indices]), ) - pl.ylim( + plt.ylim( (0.95 if score_min > 0 else 1.05) * score_min, (1.05 if score_max > 0 else 0.95) * score_max, ) @@ -625,23 +632,23 @@ def ranking( @_doc_params(show_save_ax=doc_show_save_ax) def violin( adata: AnnData, - keys: Union[str, Sequence[str]], - groupby: Optional[str] = None, + keys: str | Sequence[str], + groupby: str | None = None, log: bool = False, - use_raw: Optional[bool] = None, + use_raw: bool | None = None, stripplot: bool = True, - jitter: Union[float, bool] = True, + jitter: float | bool = True, size: int = 1, - layer: Optional[str] = None, + layer: str | None = None, scale: Literal["area", "count", "width"] = "width", - order: Optional[Sequence[str]] = None, - multi_panel: Optional[bool] = None, + order: Sequence[str] | None = None, + multi_panel: bool | None = None, xlabel: str = "", - ylabel: Optional[Union[str, Sequence[str]]] = None, - rotation: Optional[float] = None, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - ax: Optional[Axes] = None, + ylabel: str | Sequence[str] | None = None, + rotation: float | None = None, + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, **kwds, ): """\ @@ -887,9 +894,9 @@ def violin( def clustermap( adata: AnnData, obs_keys: str = None, - use_raw: Optional[bool] = None, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, + use_raw: bool | None = None, + show: bool | None = None, + save: bool | str | None = None, **kwds, ): """\ @@ -952,7 +959,7 @@ def clustermap( show = settings.autoshow if show is None else show _utils.savefig_or_show("clustermap", show=show, save=save) if show: - pl.show() + plt.show() else: return g @@ -964,27 +971,27 @@ def clustermap( ) def heatmap( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - dendrogram: Union[bool, str] = False, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, - standard_scale: Optional[Literal["var", "obs"]] = None, + dendrogram: bool | str = False, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, + standard_scale: Literal["var", "obs"] | None = None, swap_axes: bool = False, - show_gene_labels: Optional[bool] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - figsize: Optional[Tuple[float, float]] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + show_gene_labels: bool | None = None, + show: bool | None = None, + save: str | bool | None = None, + figsize: tuple[float, float] | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): """\ @@ -1173,7 +1180,7 @@ def heatmap( dendro_width, colorbar_width, ] - fig = pl.figure(figsize=(width, height)) + fig = plt.figure(figsize=(width, height)) axs = gridspec.GridSpec( nrows=2, @@ -1277,7 +1284,7 @@ def heatmap( else: width_ratios = [width, 0, colorbar_width] - fig = pl.figure(figsize=(width, height)) + fig = plt.figure(figsize=(width, height)) axs = gridspec.GridSpec( nrows=3, ncols=3, @@ -1377,18 +1384,18 @@ def heatmap( @_doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def tracksplot( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, - dendrogram: Union[bool, str] = False, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - layer: Optional[str] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - figsize: Optional[Tuple[float, float]] = None, + dendrogram: bool | str = False, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + layer: str | None = None, + show: bool | None = None, + save: str | bool | None = None, + figsize: tuple[float, float] | None = None, **kwds, ): """\ @@ -1519,7 +1526,7 @@ def tracksplot( obs_tidy = obs_tidy.T - fig = pl.figure(figsize=(width, height)) + fig = plt.figure(figsize=(width, height)) axs = gridspec.GridSpec( ncols=2, nrows=num_rows, @@ -1638,12 +1645,12 @@ def dendrogram( adata: AnnData, groupby: str, *, - dendrogram_key: Optional[str] = None, + dendrogram_key: str | None = None, orientation: Literal["top", "bottom", "left", "right"] = "top", remove_labels: bool = False, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - ax: Optional[Axes] = None, + show: bool | None = None, + save: str | bool | None = None, + ax: Axes | None = None, ): """\ Plots a dendrogram of the categories defined in `groupby`. @@ -1685,7 +1692,7 @@ def dendrogram( """ if ax is None: - _, ax = pl.subplots() + _, ax = plt.subplots() _plot_dendrogram( ax, adata, @@ -1703,17 +1710,17 @@ def correlation_matrix( adata: AnnData, groupby: str, show_correlation_numbers: bool = False, - dendrogram: Union[bool, str, None] = None, - figsize: Optional[Tuple[float, float]] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - ax: Optional[Axes] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + dendrogram: bool | str | None = None, + figsize: tuple[float, float] | None = None, + show: bool | None = None, + save: str | bool | None = None, + ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, -) -> Union[Axes, List[Axes]]: +) -> Axes | list[Axes]: """\ Plots the correlation matrix computed as part of `sc.tl.dendrogram`. @@ -1782,7 +1789,7 @@ def correlation_matrix( width, height = figsize corr_matrix_height = height - colorbar_height - fig = pl.figure(figsize=(width, height)) if ax is None else None + fig = plt.figure(figsize=(width, height)) if ax is None else None # layout with 2 rows and 2 columns: # row 1: dendrogram + correlation matrix # row 2: nothing + colormap bar (horizontal) @@ -1856,7 +1863,7 @@ def correlation_matrix( if ax is None: # Plot colorbar colormap_ax = fig.add_subplot(gs[3]) - cobar = pl.colorbar(img_mat, cax=colormap_ax, orientation="horizontal") + cobar = plt.colorbar(img_mat, cax=colormap_ax, orientation="horizontal") cobar.solids.set_edgecolor("face") axs.append(colormap_ax) @@ -1868,13 +1875,13 @@ def correlation_matrix( def _prepare_dataframe( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Optional[Union[str, Sequence[str]]] = None, - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str] | None = None, + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, layer=None, - gene_symbols: Optional[str] = None, + gene_symbols: str | None = None, ): """ Given the anndata object, prepares a data frame in which the row index are the categories @@ -1990,11 +1997,11 @@ def _prepare_dataframe( def _plot_gene_groups_brackets( gene_groups_ax: Axes, - group_positions: Iterable[Tuple[int, int]], + group_positions: Iterable[tuple[int, int]], group_labels: Sequence[str], left_adjustment: float = -0.3, right_adjustment: float = 0.3, - rotation: Optional[float] = None, + rotation: float | None = None, orientation: Literal["top", "right"] = "top", ): """\ @@ -2105,7 +2112,7 @@ def _plot_gene_groups_brackets( fontsize="small", ) except Exception as e: - print("problems {}".format(e)) + print(f"problems {e}") pass path = Path(verts, codes) @@ -2267,10 +2274,10 @@ def _plot_dendrogram( dendro_ax: Axes, adata: AnnData, groupby: str, - dendrogram_key: Optional[str] = None, + dendrogram_key: str | None = None, orientation: Literal["top", "bottom", "left", "right"] = "right", remove_labels: bool = True, - ticks: Optional[Collection[float]] = None, + ticks: Collection[float] | None = None, ): """\ Plots a dendrogram on the given ax using the precomputed dendrogram @@ -2415,10 +2422,10 @@ def _plot_categories_as_colorblocks( """ groupby = obs_tidy.index.name - from matplotlib.colors import ListedColormap, BoundaryNorm + from matplotlib.colors import BoundaryNorm, ListedColormap if colors is None: - groupby_cmap = pl.get_cmap(cmap_name) + groupby_cmap = plt.get_cmap(cmap_name) else: groupby_cmap = ListedColormap(colors, groupby + "_cmap") norm = BoundaryNorm(np.arange(groupby_cmap.N + 1) - 0.5, groupby_cmap.N) @@ -2527,7 +2534,7 @@ def _plot_colorbar(mappable, fig, subplot_spec, max_cbar_height: float = 4.0): heatmap_cbar_ax = fig.add_subplot(axs2[1]) else: heatmap_cbar_ax = fig.add_subplot(subplot_spec) - pl.colorbar(mappable, cax=heatmap_cbar_ax) + plt.colorbar(mappable, cax=heatmap_cbar_ax) return heatmap_cbar_ax diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index 004eb0282e..65ebfa64b5 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -1,23 +1,25 @@ """BasePlot for dotplot, matrixplot and stacked_violin """ +from __future__ import annotations + import collections.abc as cabc from collections import namedtuple -from typing import Optional, Union, Mapping, Literal # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Literal, Union +from warnings import warn import numpy as np -from anndata import AnnData -from matplotlib.axes import Axes -from matplotlib import pyplot as pl from matplotlib import gridspec -from matplotlib.colors import Normalize -from warnings import warn +from matplotlib import pyplot as plt from .. import logging as logg -from ._utils import make_grid_spec, check_colornorm -from ._utils import ColorLike, _AxesSubplot -from ._anndata import _plot_dendrogram, _get_dendrogram_key, _prepare_dataframe +from ._anndata import _get_dendrogram_key, _plot_dendrogram, _prepare_dataframe +from ._utils import ColorLike, _AxesSubplot, check_colornorm, make_grid_spec + +if TYPE_CHECKING: + from anndata import AnnData + from matplotlib.axes import Axes + from matplotlib.colors import Normalize _VarNames = Union[str, Sequence[str]] @@ -41,7 +43,7 @@ """ -class BasePlot(object): +class BasePlot: """\ Generic class for the visualization of AnnData categories and selected `var` (features or genes). @@ -73,24 +75,24 @@ class BasePlot(object): def __init__( self, adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - categories_order: Optional[Sequence[str]] = None, - title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, - ax: Optional[_AxesSubplot] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + categories_order: Sequence[str] | None = None, + title: str | None = None, + figsize: tuple[float, float] | None = None, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, + ax: _AxesSubplot | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): self.var_names = var_names @@ -171,7 +173,7 @@ def __init__( self.ax_dict = None self.ax = ax - def swap_axes(self, swap_axes: Optional[bool] = True): + def swap_axes(self, swap_axes: bool | None = True): """ Plots a transposed image. @@ -200,9 +202,9 @@ def swap_axes(self, swap_axes: Optional[bool] = True): def add_dendrogram( self, - show: Optional[bool] = True, - dendrogram_key: Optional[str] = None, - size: Optional[float] = 0.8, + show: bool | None = True, + dendrogram_key: str | None = None, + size: float | None = 0.8, ): r"""\ Show dendrogram based on the hierarchical clustering between the `groupby` @@ -285,10 +287,10 @@ def add_dendrogram( def add_totals( self, - show: Optional[bool] = True, + show: bool | None = True, sort: Literal["ascending", "descending"] = None, - size: Optional[float] = 0.8, - color: Optional[Union[ColorLike, Sequence[ColorLike]]] = None, + size: float | None = 0.8, + color: ColorLike | Sequence[ColorLike] | None = None, ): r"""\ Show barplot for the number of cells in in `groupby` category. @@ -363,7 +365,7 @@ def add_totals( } return self - def style(self, cmap: Optional[str] = DEFAULT_COLORMAP): + def style(self, cmap: str | None = DEFAULT_COLORMAP): """\ Set visual style parameters @@ -381,9 +383,9 @@ def style(self, cmap: Optional[str] = DEFAULT_COLORMAP): def legend( self, - show: Optional[bool] = True, - title: Optional[str] = DEFAULT_COLOR_LEGEND_TITLE, - width: Optional[float] = DEFAULT_LEGENDS_WIDTH, + show: bool | None = True, + title: str | None = DEFAULT_COLOR_LEGEND_TITLE, + width: float | None = DEFAULT_LEGENDS_WIDTH, ): r"""\ Configure legend parameters @@ -524,7 +526,7 @@ def _plot_colorbar(self, color_legend_ax: Axes, normalize): ------- None, updates color_legend_ax """ - cmap = pl.get_cmap(self.cmap) + cmap = plt.get_cmap(self.cmap) import matplotlib.colorbar from matplotlib.cm import ScalarMappable @@ -765,7 +767,7 @@ def make_figure(self): self.ax_dict = return_ax_dict - def show(self, return_axes: Optional[bool] = None): + def show(self, return_axes: bool | None = None): """ Show the figure @@ -798,9 +800,9 @@ def show(self, return_axes: Optional[bool] = None): if return_axes: return self.ax_dict else: - pl.show() + plt.show() - def savefig(self, filename: str, bbox_inches: Optional[str] = "tight", **kwargs): + def savefig(self, filename: str, bbox_inches: str | None = "tight", **kwargs): """ Save the current figure @@ -827,7 +829,7 @@ def savefig(self, filename: str, bbox_inches: Optional[str] = "tight", **kwargs) >>> sc.pl._baseplot_class.BasePlot(adata, markers, groupby='bulk_labels').savefig('plot.pdf') """ self.make_figure() - pl.savefig(filename, bbox_inches=bbox_inches, **kwargs) + plt.savefig(filename, bbox_inches=bbox_inches, **kwargs) def _reorder_categories_after_dendrogram(self, dendrogram): """\ @@ -924,11 +926,11 @@ def _format_first_three_categories(_categories): @staticmethod def _plot_var_groups_brackets( gene_groups_ax: Axes, - group_positions: Iterable[Tuple[int, int]], + group_positions: Iterable[tuple[int, int]], group_labels: Sequence[str], left_adjustment: float = -0.3, right_adjustment: float = 0.3, - rotation: Optional[float] = None, + rotation: float | None = None, orientation: Literal["top", "right"] = "top", ): """\ diff --git a/scanpy/plotting/_docs.py b/scanpy/plotting/_docs.py index e447c34870..c1078090b0 100644 --- a/scanpy/plotting/_docs.py +++ b/scanpy/plotting/_docs.py @@ -1,7 +1,7 @@ """\ Shared docstrings for plotting function parameters. """ - +from __future__ import annotations doc_adata_color_etc = """\ adata diff --git a/scanpy/plotting/_dotplot.py b/scanpy/plotting/_dotplot.py index 15a27f8f28..8d21cf384b 100644 --- a/scanpy/plotting/_dotplot.py +++ b/scanpy/plotting/_dotplot.py @@ -1,23 +1,34 @@ -from typing import Optional, Union, Mapping, Literal # Special -from typing import Sequence # ABCs -from typing import Tuple # Classes +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal import numpy as np -import pandas as pd -from anndata import AnnData -from matplotlib.axes import Axes -from matplotlib import pyplot as pl -from matplotlib.colors import Normalize +from matplotlib import pyplot as plt from .. import logging as logg from .._settings import settings from .._utils import _doc_params -from ._utils import make_grid_spec, fix_kwds, check_colornorm -from ._utils import ColorLike, _AxesSubplot -from ._utils import savefig_or_show - +from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._baseplot_class import BasePlot, doc_common_groupby_plot_args, _VarNames +from ._utils import ( + ColorLike, + _AxesSubplot, + check_colornorm, + fix_kwds, + make_grid_spec, + savefig_or_show, +) + +if TYPE_CHECKING: + from collections.abc import ( + Mapping, # Special + Sequence, # ABCs + ) + + import pandas as pd + from anndata import AnnData + from matplotlib.axes import Axes + from matplotlib.colors import Normalize @_doc_params(common_plot_args=doc_common_plot_args) @@ -103,29 +114,29 @@ class DotPlot(BasePlot): def __init__( self, adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - categories_order: Optional[Sequence[str]] = None, - title: Optional[str] = None, - figsize: Optional[Tuple[float, float]] = None, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, + categories_order: Sequence[str] | None = None, + title: str | None = None, + figsize: tuple[float, float] | None = None, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, expression_cutoff: float = 0.0, mean_only_expressed: bool = False, standard_scale: Literal["var", "group"] = None, - dot_color_df: Optional[pd.DataFrame] = None, - dot_size_df: Optional[pd.DataFrame] = None, - ax: Optional[_AxesSubplot] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + dot_color_df: pd.DataFrame | None = None, + dot_size_df: pd.DataFrame | None = None, + ax: _AxesSubplot | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): BasePlot.__init__( @@ -243,17 +254,17 @@ def __init__( def style( self, cmap: str = DEFAULT_COLORMAP, - color_on: Optional[Literal["dot", "square"]] = DEFAULT_COLOR_ON, - dot_max: Optional[float] = DEFAULT_DOT_MAX, - dot_min: Optional[float] = DEFAULT_DOT_MIN, - smallest_dot: Optional[float] = DEFAULT_SMALLEST_DOT, - largest_dot: Optional[float] = DEFAULT_LARGEST_DOT, - dot_edge_color: Optional[ColorLike] = DEFAULT_DOT_EDGECOLOR, - dot_edge_lw: Optional[float] = DEFAULT_DOT_EDGELW, - size_exponent: Optional[float] = DEFAULT_SIZE_EXPONENT, - grid: Optional[float] = False, - x_padding: Optional[float] = DEFAULT_PLOT_X_PADDING, - y_padding: Optional[float] = DEFAULT_PLOT_Y_PADDING, + color_on: Literal["dot", "square"] | None = DEFAULT_COLOR_ON, + dot_max: float | None = DEFAULT_DOT_MAX, + dot_min: float | None = DEFAULT_DOT_MIN, + smallest_dot: float | None = DEFAULT_SMALLEST_DOT, + largest_dot: float | None = DEFAULT_LARGEST_DOT, + dot_edge_color: ColorLike | None = DEFAULT_DOT_EDGECOLOR, + dot_edge_lw: float | None = DEFAULT_DOT_EDGELW, + size_exponent: float | None = DEFAULT_SIZE_EXPONENT, + grid: float | None = False, + x_padding: float | None = DEFAULT_PLOT_X_PADDING, + y_padding: float | None = DEFAULT_PLOT_Y_PADDING, ): r"""\ Modifies plot visual parameters @@ -358,12 +369,12 @@ def style( def legend( self, - show: Optional[bool] = True, - show_size_legend: Optional[bool] = True, - show_colorbar: Optional[bool] = True, - size_title: Optional[str] = DEFAULT_SIZE_LEGEND_TITLE, - colorbar_title: Optional[str] = DEFAULT_COLOR_LEGEND_TITLE, - width: Optional[float] = DEFAULT_LEGENDS_WIDTH, + show: bool | None = True, + show_size_legend: bool | None = True, + show_colorbar: bool | None = True, + size_title: str | None = DEFAULT_SIZE_LEGEND_TITLE, + colorbar_title: str | None = DEFAULT_COLOR_LEGEND_TITLE, + width: float | None = DEFAULT_LEGENDS_WIDTH, ): """\ Configures dot size and the colorbar legends @@ -447,9 +458,7 @@ def _plot_size_legend(self, size_legend_ax: Axes): zorder=100, ) size_legend_ax.set_xticks(np.arange(len(size)) + 0.5) - labels = [ - "{}".format(np.round((x * 100), decimals=0).astype(int)) for x in size_range - ] + labels = [f"{np.round((x * 100), decimals=0).astype(int)}" for x in size_range] size_legend_ax.set_xticklabels(labels, fontsize="small") # remove y ticks and labels @@ -560,23 +569,23 @@ def _dotplot( dot_color, dot_ax, cmap: str = "Reds", - color_on: Optional[str] = "dot", - y_label: Optional[str] = None, - dot_max: Optional[float] = None, - dot_min: Optional[float] = None, + color_on: str | None = "dot", + y_label: str | None = None, + dot_max: float | None = None, + dot_min: float | None = None, standard_scale: Literal["var", "group"] = None, - smallest_dot: Optional[float] = 0.0, - largest_dot: Optional[float] = 200, - size_exponent: Optional[float] = 2, - edge_color: Optional[ColorLike] = None, - edge_lw: Optional[float] = None, - grid: Optional[bool] = False, - x_padding: Optional[float] = 0.8, - y_padding: Optional[float] = 1.0, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + smallest_dot: float | None = 0.0, + largest_dot: float | None = 200, + size_exponent: float | None = 2, + edge_color: ColorLike | None = None, + edge_lw: float | None = None, + grid: bool | None = False, + x_padding: float | None = 0.8, + y_padding: float | None = 1.0, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): """\ @@ -674,7 +683,7 @@ def _dotplot( x = x.flatten() + 0.5 frac = dot_size.values.flatten() mean_flat = dot_color.values.flatten() - cmap = pl.get_cmap(kwds.get("cmap", cmap)) + cmap = plt.get_cmap(kwds.get("cmap", cmap)) if "cmap" in kwds: del kwds["cmap"] if dot_max is None: @@ -797,40 +806,40 @@ def _dotplot( ) def dotplot( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, expression_cutoff: float = 0.0, mean_only_expressed: bool = False, cmap: str = "Reds", - dot_max: Optional[float] = DotPlot.DEFAULT_DOT_MAX, - dot_min: Optional[float] = DotPlot.DEFAULT_DOT_MIN, - standard_scale: Optional[Literal["var", "group"]] = None, - smallest_dot: Optional[float] = DotPlot.DEFAULT_SMALLEST_DOT, - title: Optional[str] = None, - colorbar_title: Optional[str] = DotPlot.DEFAULT_COLOR_LEGEND_TITLE, - size_title: Optional[str] = DotPlot.DEFAULT_SIZE_LEGEND_TITLE, - figsize: Optional[Tuple[float, float]] = None, - dendrogram: Union[bool, str] = False, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, - swap_axes: Optional[bool] = False, - dot_color_df: Optional[pd.DataFrame] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - ax: Optional[_AxesSubplot] = None, - return_fig: Optional[bool] = False, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + dot_max: float | None = DotPlot.DEFAULT_DOT_MAX, + dot_min: float | None = DotPlot.DEFAULT_DOT_MIN, + standard_scale: Literal["var", "group"] | None = None, + smallest_dot: float | None = DotPlot.DEFAULT_SMALLEST_DOT, + title: str | None = None, + colorbar_title: str | None = DotPlot.DEFAULT_COLOR_LEGEND_TITLE, + size_title: str | None = DotPlot.DEFAULT_SIZE_LEGEND_TITLE, + figsize: tuple[float, float] | None = None, + dendrogram: bool | str = False, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, + swap_axes: bool | None = False, + dot_color_df: pd.DataFrame | None = None, + show: bool | None = None, + save: str | bool | None = None, + ax: _AxesSubplot | None = None, + return_fig: bool | None = False, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, -) -> Union[DotPlot, dict, None]: +) -> DotPlot | dict | None: """\ Makes a *dot plot* of the expression values of `var_names`. diff --git a/scanpy/plotting/_matrixplot.py b/scanpy/plotting/_matrixplot.py index 509d744514..fe1b7844da 100644 --- a/scanpy/plotting/_matrixplot.py +++ b/scanpy/plotting/_matrixplot.py @@ -1,26 +1,31 @@ -from typing import Optional, Union, Mapping, Literal # Special -from typing import Sequence # ABCs -from typing import Tuple # Classes +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal import numpy as np -import pandas as pd -from anndata import AnnData -from matplotlib import pyplot as pl +from matplotlib import pyplot as plt from matplotlib import rcParams -from matplotlib.colors import Normalize from .. import logging as logg -from .._utils import _doc_params -from ._utils import fix_kwds, check_colornorm -from ._utils import ColorLike, _AxesSubplot -from ._utils import savefig_or_show from .._settings import settings +from .._utils import _doc_params +from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args from ._docs import ( doc_common_plot_args, doc_show_save_ax, doc_vboundnorm, ) -from ._baseplot_class import BasePlot, doc_common_groupby_plot_args, _VarNames +from ._utils import ColorLike, _AxesSubplot, check_colornorm, fix_kwds, savefig_or_show + +if TYPE_CHECKING: + from collections.abc import ( + Mapping, # Special + Sequence, # ABCs + ) + + import pandas as pd + from anndata import AnnData + from matplotlib.colors import Normalize @_doc_params(common_plot_args=doc_common_plot_args) @@ -92,26 +97,26 @@ class MatrixPlot(BasePlot): def __init__( self, adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - categories_order: Optional[Sequence[str]] = None, - title: Optional[str] = None, - figsize: Optional[Tuple[float, float]] = None, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, + categories_order: Sequence[str] | None = None, + title: str | None = None, + figsize: tuple[float, float] | None = None, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, standard_scale: Literal["var", "group"] = None, - ax: Optional[_AxesSubplot] = None, - values_df: Optional[pd.DataFrame] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + ax: _AxesSubplot | None = None, + values_df: pd.DataFrame | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): BasePlot.__init__( @@ -162,8 +167,8 @@ def __init__( def style( self, cmap: str = DEFAULT_COLORMAP, - edge_color: Optional[ColorLike] = DEFAULT_EDGE_COLOR, - edge_lw: Optional[float] = DEFAULT_EDGE_LW, + edge_color: ColorLike | None = DEFAULT_EDGE_COLOR, + edge_lw: float | None = DEFAULT_EDGE_LW, ): """\ Modifies plot visual parameters. @@ -230,7 +235,7 @@ def _mainplot(self, ax): if self.are_axes_swapped: _color_df = _color_df.T - cmap = pl.get_cmap(self.kwds.get("cmap", self.cmap)) + cmap = plt.get_cmap(self.kwds.get("cmap", self.cmap)) if "cmap" in self.kwds: del self.kwds["cmap"] normalize = check_colornorm( @@ -283,34 +288,34 @@ def _mainplot(self, ax): ) def matrixplot( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - figsize: Optional[Tuple[float, float]] = None, - dendrogram: Union[bool, str] = False, - title: Optional[str] = None, - cmap: Optional[str] = MatrixPlot.DEFAULT_COLORMAP, - colorbar_title: Optional[str] = MatrixPlot.DEFAULT_COLOR_LEGEND_TITLE, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, + figsize: tuple[float, float] | None = None, + dendrogram: bool | str = False, + title: str | None = None, + cmap: str | None = MatrixPlot.DEFAULT_COLORMAP, + colorbar_title: str | None = MatrixPlot.DEFAULT_COLOR_LEGEND_TITLE, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, standard_scale: Literal["var", "group"] = None, - values_df: Optional[pd.DataFrame] = None, + values_df: pd.DataFrame | None = None, swap_axes: bool = False, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, - ax: Optional[_AxesSubplot] = None, - return_fig: Optional[bool] = False, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + show: bool | None = None, + save: str | bool | None = None, + ax: _AxesSubplot | None = None, + return_fig: bool | None = False, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, -) -> Union[MatrixPlot, dict, None]: +) -> MatrixPlot | dict | None: """\ Creates a heatmap of the mean expression values per group of each var_names. diff --git a/scanpy/plotting/_preprocessing.py b/scanpy/plotting/_preprocessing.py index 749f315814..f470d73187 100644 --- a/scanpy/plotting/_preprocessing.py +++ b/scanpy/plotting/_preprocessing.py @@ -1,10 +1,11 @@ -from typing import Optional, Union +from __future__ import annotations import numpy as np import pandas as pd -from matplotlib import pyplot as pl -from matplotlib import rcParams from anndata import AnnData +from matplotlib import pyplot as plt +from matplotlib import rcParams + from . import _utils # -------------------------------------------------------------------------------- @@ -13,10 +14,10 @@ def highly_variable_genes( - adata_or_result: Union[AnnData, pd.DataFrame, np.recarray], + adata_or_result: AnnData | pd.DataFrame | np.recarray, log: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, + show: bool | None = None, + save: bool | str | None = None, highly_variable_genes: bool = True, ): """Plot dispersions or normalized variance versus means for genes. @@ -59,10 +60,10 @@ def highly_variable_genes( var_or_disp = result.dispersions var_or_disp_norm = result.dispersions_norm size = rcParams["figure.figsize"] - pl.figure(figsize=(2 * size[0], size[1])) - pl.subplots_adjust(wspace=0.3) + plt.figure(figsize=(2 * size[0], size[1])) + plt.subplots_adjust(wspace=0.3) for idx, d in enumerate([var_or_disp_norm, var_or_disp]): - pl.subplot(1, 2, idx + 1) + plt.subplot(1, 2, idx + 1) for label, color, mask in zip( ["highly variable genes", "other genes"], ["black", "grey"], @@ -72,35 +73,35 @@ def highly_variable_genes( means_, var_or_disps_ = np.log10(means[mask]), np.log10(d[mask]) else: means_, var_or_disps_ = means[mask], d[mask] - pl.scatter(means_, var_or_disps_, label=label, c=color, s=1) + plt.scatter(means_, var_or_disps_, label=label, c=color, s=1) if log: # there's a bug in autoscale - pl.xscale("log") - pl.yscale("log") + plt.xscale("log") + plt.yscale("log") y_min = np.min(var_or_disp) y_min = 0.95 * y_min if y_min > 0 else 1e-1 - pl.xlim(0.95 * np.min(means), 1.05 * np.max(means)) - pl.ylim(y_min, 1.05 * np.max(var_or_disp)) + plt.xlim(0.95 * np.min(means), 1.05 * np.max(means)) + plt.ylim(y_min, 1.05 * np.max(var_or_disp)) if idx == 0: - pl.legend() - pl.xlabel(("$log_{10}$ " if False else "") + "mean expressions of genes") + plt.legend() + plt.xlabel(("$log_{10}$ " if False else "") + "mean expressions of genes") data_type = "dispersions" if not seurat_v3_flavor else "variances" - pl.ylabel( + plt.ylabel( ("$log_{10}$ " if False else "") - + "{} of genes".format(data_type) + + f"{data_type} of genes" + (" (normalized)" if idx == 0 else " (not normalized)") ) _utils.savefig_or_show("filter_genes_dispersion", show=show, save=save) if show is False: - return pl.gca() + return plt.gca() # backwards compat def filter_genes_dispersion( result: np.recarray, log: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, + show: bool | None = None, + save: bool | str | None = None, ): """\ Plot dispersions versus means for genes. diff --git a/scanpy/plotting/_qc.py b/scanpy/plotting/_qc.py index eee9695de8..35f3b63c53 100644 --- a/scanpy/plotting/_qc.py +++ b/scanpy/plotting/_qc.py @@ -1,25 +1,29 @@ -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from anndata import AnnData from matplotlib import pyplot as plt -from matplotlib.axes import Axes +from .._utils import _doc_params +from ..preprocessing._normalization import normalize_total from . import _utils from ._docs import doc_show_save_ax -from ..preprocessing._normalization import normalize_total -from .._utils import _doc_params + +if TYPE_CHECKING: + from anndata import AnnData + from matplotlib.axes import Axes @_doc_params(show_save_ax=doc_show_save_ax) def highest_expr_genes( adata: AnnData, n_top: int = 30, - show: Optional[bool] = None, - save: Optional[Union[str, bool]] = None, - ax: Optional[Axes] = None, - gene_symbols: Optional[str] = None, + show: bool | None = None, + save: str | bool | None = None, + ax: Axes | None = None, + gene_symbols: str | None = None, log: bool = False, **kwds, ): diff --git a/scanpy/plotting/_rcmod.py b/scanpy/plotting/_rcmod.py index 5e8a53e4a4..cfe6da13bf 100644 --- a/scanpy/plotting/_rcmod.py +++ b/scanpy/plotting/_rcmod.py @@ -1,9 +1,10 @@ """Set the default matplotlib.rcParams. """ +from __future__ import annotations -import matplotlib -from matplotlib import rcParams +import matplotlib as mpl from cycler import cycler +from matplotlib import rcParams from . import palettes @@ -69,4 +70,4 @@ def set_rcParams_scanpy(fontsize=14, color_map=None): def set_rcParams_defaults(): """Reset `matplotlib.rcParams` to defaults.""" - rcParams.update(matplotlib.rcParamsDefault) + rcParams.update(mpl.rcParamsDefault) diff --git a/scanpy/plotting/_stacked_violin.py b/scanpy/plotting/_stacked_violin.py index 0c089f5642..e058358400 100644 --- a/scanpy/plotting/_stacked_violin.py +++ b/scanpy/plotting/_stacked_violin.py @@ -1,21 +1,26 @@ -from typing import Optional, Union, Mapping, Literal # Special -from typing import Sequence # ABCs -from typing import Tuple # Classes +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd -from anndata import AnnData -from matplotlib import pyplot as pl -from matplotlib.colors import is_color_like, Normalize +from matplotlib import pyplot as plt +from matplotlib.colors import Normalize, is_color_like + from .. import logging as logg from .._settings import settings from .._utils import _doc_params -from ._utils import make_grid_spec, check_colornorm -from ._utils import _AxesSubplot -from ._utils import savefig_or_show - +from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._baseplot_class import BasePlot, doc_common_groupby_plot_args, _VarNames +from ._utils import _AxesSubplot, check_colornorm, make_grid_spec, savefig_or_show + +if TYPE_CHECKING: + from collections.abc import ( + Mapping, # Special + Sequence, # ABCs + ) + + from anndata import AnnData @_doc_params(common_plot_args=doc_common_plot_args) @@ -128,25 +133,25 @@ class StackedViolin(BasePlot): def __init__( self, adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], - use_raw: Optional[bool] = None, + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], + use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - categories_order: Optional[Sequence[str]] = None, - title: Optional[str] = None, - figsize: Optional[Tuple[float, float]] = None, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, + categories_order: Sequence[str] | None = None, + title: str | None = None, + figsize: tuple[float, float] | None = None, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, standard_scale: Literal["var", "group"] = None, - ax: Optional[_AxesSubplot] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + ax: _AxesSubplot | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, ): BasePlot.__init__( @@ -202,17 +207,17 @@ def __init__( def style( self, - cmap: Optional[str] = DEFAULT_COLORMAP, - stripplot: Optional[bool] = DEFAULT_STRIPPLOT, - jitter: Optional[Union[float, bool]] = DEFAULT_JITTER, - jitter_size: Optional[int] = DEFAULT_JITTER_SIZE, - linewidth: Optional[float] = DEFAULT_LINE_WIDTH, - row_palette: Optional[str] = DEFAULT_ROW_PALETTE, - scale: Optional[Literal["area", "count", "width"]] = DEFAULT_SCALE, - yticklabels: Optional[bool] = DEFAULT_PLOT_YTICKLABELS, - ylim: Optional[Tuple[float, float]] = DEFAULT_YLIM, - x_padding: Optional[float] = DEFAULT_PLOT_X_PADDING, - y_padding: Optional[float] = DEFAULT_PLOT_Y_PADDING, + cmap: str | None = DEFAULT_COLORMAP, + stripplot: bool | None = DEFAULT_STRIPPLOT, + jitter: float | bool | None = DEFAULT_JITTER, + jitter_size: int | None = DEFAULT_JITTER_SIZE, + linewidth: float | None = DEFAULT_LINE_WIDTH, + row_palette: str | None = DEFAULT_ROW_PALETTE, + scale: Literal["area", "count", "width"] | None = DEFAULT_SCALE, + yticklabels: bool | None = DEFAULT_PLOT_YTICKLABELS, + ylim: tuple[float, float] | None = DEFAULT_YLIM, + x_padding: float | None = DEFAULT_PLOT_X_PADDING, + y_padding: float | None = DEFAULT_PLOT_Y_PADDING, ): r"""\ Modifies plot visual parameters @@ -327,7 +332,7 @@ def _mainplot(self, ax): if self.are_axes_swapped: _color_df = _color_df.T - cmap = pl.get_cmap(self.kwds.get("cmap", self.cmap)) + cmap = plt.get_cmap(self.kwds.get("cmap", self.cmap)) if "cmap" in self.kwds: del self.kwds["cmap"] normalize = check_colornorm( @@ -554,40 +559,40 @@ def _setup_violin_axes_ticks(self, row_ax, num_cols): ) def stacked_violin( adata: AnnData, - var_names: Union[_VarNames, Mapping[str, _VarNames]], - groupby: Union[str, Sequence[str]], + var_names: _VarNames | Mapping[str, _VarNames], + groupby: str | Sequence[str], log: bool = False, - use_raw: Optional[bool] = None, + use_raw: bool | None = None, num_categories: int = 7, - title: Optional[str] = None, - colorbar_title: Optional[str] = StackedViolin.DEFAULT_COLOR_LEGEND_TITLE, - figsize: Optional[Tuple[float, float]] = None, - dendrogram: Union[bool, str] = False, - gene_symbols: Optional[str] = None, - var_group_positions: Optional[Sequence[Tuple[int, int]]] = None, - var_group_labels: Optional[Sequence[str]] = None, - standard_scale: Optional[Literal["var", "obs"]] = None, - var_group_rotation: Optional[float] = None, - layer: Optional[str] = None, + title: str | None = None, + colorbar_title: str | None = StackedViolin.DEFAULT_COLOR_LEGEND_TITLE, + figsize: tuple[float, float] | None = None, + dendrogram: bool | str = False, + gene_symbols: str | None = None, + var_group_positions: Sequence[tuple[int, int]] | None = None, + var_group_labels: Sequence[str] | None = None, + standard_scale: Literal["var", "obs"] | None = None, + var_group_rotation: float | None = None, + layer: str | None = None, stripplot: bool = StackedViolin.DEFAULT_STRIPPLOT, - jitter: Union[float, bool] = StackedViolin.DEFAULT_JITTER, + jitter: float | bool = StackedViolin.DEFAULT_JITTER, size: int = StackedViolin.DEFAULT_JITTER_SIZE, scale: Literal["area", "count", "width"] = StackedViolin.DEFAULT_SCALE, - yticklabels: Optional[bool] = StackedViolin.DEFAULT_PLOT_YTICKLABELS, - order: Optional[Sequence[str]] = None, + yticklabels: bool | None = StackedViolin.DEFAULT_PLOT_YTICKLABELS, + order: Sequence[str] | None = None, swap_axes: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - return_fig: Optional[bool] = False, - row_palette: Optional[str] = StackedViolin.DEFAULT_ROW_PALETTE, - cmap: Optional[str] = StackedViolin.DEFAULT_COLORMAP, - ax: Optional[_AxesSubplot] = None, - vmin: Optional[float] = None, - vmax: Optional[float] = None, - vcenter: Optional[float] = None, - norm: Optional[Normalize] = None, + show: bool | None = None, + save: bool | str | None = None, + return_fig: bool | None = False, + row_palette: str | None = StackedViolin.DEFAULT_ROW_PALETTE, + cmap: str | None = StackedViolin.DEFAULT_COLORMAP, + ax: _AxesSubplot | None = None, + vmin: float | None = None, + vmax: float | None = None, + vcenter: float | None = None, + norm: Normalize | None = None, **kwds, -) -> Union[StackedViolin, dict, None]: +) -> StackedViolin | dict | None: """\ Stacked violin plots. diff --git a/scanpy/plotting/_tools/__init__.py b/scanpy/plotting/_tools/__init__.py index 2fd627bd74..0d22821020 100644 --- a/scanpy/plotting/_tools/__init__.py +++ b/scanpy/plotting/_tools/__init__.py @@ -1,34 +1,44 @@ +from __future__ import annotations + import collections.abc as cabc +from collections.abc import Iterable, Mapping, Sequence from copy import copy +from typing import TYPE_CHECKING, Literal + import numpy as np import pandas as pd -from cycler import Cycler -from matplotlib.axes import Axes -from matplotlib.figure import Figure -from matplotlib.colors import Normalize -from matplotlib import pyplot as pl -from matplotlib import rcParams, colormaps -from anndata import AnnData -from typing import Union, Optional, List, Sequence, Iterable, Mapping, Literal - -from .._utils import savefig_or_show -from ..._utils import _doc_params, sanitize_anndata, subsample +from matplotlib import colormaps, rcParams +from matplotlib import pyplot as plt + +from scanpy.get import obs_df + from ... import logging as logg -from .._anndata import ranking -from .._utils import timeseries, timeseries_subplot, timeseries_as_heatmap from ..._settings import settings +from ..._utils import _doc_params, sanitize_anndata, subsample +from ...get import rank_genes_groups_df +from .._anndata import ranking from .._docs import ( - doc_scatter_embedding, - doc_show_save_ax, + doc_panels, doc_rank_genes_groups_plot_args, doc_rank_genes_groups_values_to_plot, + doc_scatter_embedding, + doc_show_save_ax, doc_vbound_percentile, - doc_panels, ) -from ...get import rank_genes_groups_df -from .scatterplots import pca, embedding, _panel_grid -from matplotlib.colors import Colormap -from scanpy.get import obs_df +from .._utils import ( + savefig_or_show, + timeseries, + timeseries_as_heatmap, + timeseries_subplot, +) +from .scatterplots import _panel_grid, embedding, pca + +if TYPE_CHECKING: + from anndata import AnnData + from cycler import Cycler + from matplotlib.axes import Axes + from matplotlib.colors import Colormap, Normalize + from matplotlib.figure import Figure # ------------------------------------------------------------------------------ # PCA @@ -72,7 +82,6 @@ def pca_overview(adata: AnnData, **params): See also -------- - tl.pca pp.pca """ show = params["show"] if "show" in params else None @@ -89,11 +98,11 @@ def pca_overview(adata: AnnData, **params): def pca_loadings( adata: AnnData, - components: Union[str, Sequence[int], None] = None, + components: str | Sequence[int] | None = None, include_lowest: bool = True, - n_points: Union[int, None] = None, - show: Optional[bool] = None, - save: Union[str, bool, None] = None, + n_points: int | None = None, + show: bool | None = None, + save: str | bool | None = None, ): """\ Rank genes according to contributions to PCs. @@ -164,8 +173,8 @@ def pca_variance_ratio( adata: AnnData, n_pcs: int = 30, log: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, + show: bool | None = None, + save: bool | str | None = None, ): """\ Plot the variance ratio. @@ -203,11 +212,11 @@ def pca_variance_ratio( def dpt_timeseries( adata: AnnData, - color_map: Union[str, Colormap] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, + color_map: str | Colormap = None, + show: bool | None = None, + save: bool | None = None, as_heatmap: bool = True, - marker: Union[str, Sequence[str]] = ".", + marker: str | Sequence[str] = ".", ): """\ Heatmap of pseudotime series. @@ -240,20 +249,20 @@ def dpt_timeseries( xlim=[0, 1.3 * adata.X.shape[0]], marker=marker, ) - pl.xlabel("dpt order") + plt.xlabel("dpt order") savefig_or_show("dpt_timeseries", save=save, show=show) def dpt_groups_pseudotime( adata: AnnData, - color_map: Union[str, Colormap, None] = None, - palette: Union[Sequence[str], Cycler, None] = None, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - marker: Union[str, Sequence[str]] = ".", + color_map: str | Colormap | None = None, + palette: Sequence[str] | Cycler | None = None, + show: bool | None = None, + save: bool | str | None = None, + marker: str | Sequence[str] = ".", ): """Plot groups and pseudotime.""" - _, (ax_grp, ax_ord) = pl.subplots(2, 1) + _, (ax_grp, ax_ord) = plt.subplots(2, 1) timeseries_subplot( adata.obs["dpt_groups"].cat.codes, time=adata.obs["dpt_order"].values, @@ -287,16 +296,16 @@ def dpt_groups_pseudotime( @_doc_params(show_save_ax=doc_show_save_ax) def rank_genes_groups( adata: AnnData, - groups: Union[str, Sequence[str]] = None, + groups: str | Sequence[str] | None = None, n_genes: int = 20, - gene_symbols: Optional[str] = None, - key: Optional[str] = "rank_genes_groups", + gene_symbols: str | None = None, + key: str | None = "rank_genes_groups", fontsize: int = 8, ncols: int = 4, sharey: bool = True, - show: Optional[bool] = None, - save: Optional[bool] = None, - ax: Optional[Axes] = None, + show: bool | None = None, + save: bool | None = None, + ax: Axes | None = None, **kwds, ): """\ @@ -367,7 +376,7 @@ def rank_genes_groups( from matplotlib import gridspec - fig = pl.figure( + fig = plt.figure( figsize=( n_panels_x * rcParams["figure.figsize"][0], n_panels_y * rcParams["figure.figsize"][1], @@ -421,7 +430,7 @@ def rank_genes_groups( fontsize=fontsize, ) - ax.set_title("{} vs. {}".format(group_name, reference)) + ax.set_title(f"{group_name} vs. {reference}") if count >= n_panels_x * (n_panels_y - 1): ax.set_xlabel("ranking") @@ -454,17 +463,17 @@ def _fig_show_save_or_axes(plot_obj, return_fig, show, save): def _rank_genes_groups_plot( adata: AnnData, plot_type: str = "heatmap", - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - values_to_plot: Optional[str] = None, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - min_logfoldchange: Optional[float] = None, - key: Optional[str] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, - return_fig: Optional[bool] = False, - gene_symbols: Optional[str] = None, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + values_to_plot: str | None = None, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, + return_fig: bool | None = False, + gene_symbols: str | None = None, **kwds, ): """\ @@ -620,15 +629,15 @@ def _rank_genes_groups_plot( @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_heatmap( adata: AnnData, - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - gene_symbols: Optional[str] = None, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - min_logfoldchange: Optional[float] = None, - key: str = None, - show: Optional[bool] = None, - save: Optional[bool] = None, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + gene_symbols: str | None = None, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, **kwds, ): """\ @@ -693,15 +702,15 @@ def rank_genes_groups_heatmap( @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_tracksplot( adata: AnnData, - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - gene_symbols: Optional[str] = None, - min_logfoldchange: Optional[float] = None, - key: Optional[str] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + gene_symbols: str | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, **kwds, ): """\ @@ -750,26 +759,25 @@ def rank_genes_groups_tracksplot( ) def rank_genes_groups_dotplot( adata: AnnData, - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - values_to_plot: Optional[ - Literal[ - "scores", - "logfoldchanges", - "pvals", - "pvals_adj", - "log10_pvals", - "log10_pvals_adj", - ] - ] = None, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - gene_symbols: Optional[str] = None, - min_logfoldchange: Optional[float] = None, - key: Optional[str] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, - return_fig: Optional[bool] = False, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + values_to_plot: Literal[ + "scores", + "logfoldchanges", + "pvals", + "pvals_adj", + "log10_pvals", + "log10_pvals_adj", + ] + | None = None, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + gene_symbols: str | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, + return_fig: bool | None = False, **kwds, ): """\ @@ -897,17 +905,17 @@ def rank_genes_groups_dotplot( @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_stacked_violin( adata: AnnData, - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - gene_symbols: Optional[str] = None, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + gene_symbols: str | None = None, *, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - min_logfoldchange: Optional[float] = None, - key: Optional[str] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, - return_fig: Optional[bool] = False, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, + return_fig: bool | None = False, **kwds, ): """\ @@ -964,26 +972,25 @@ def rank_genes_groups_stacked_violin( ) def rank_genes_groups_matrixplot( adata: AnnData, - groups: Union[str, Sequence[str]] = None, - n_genes: Optional[int] = None, - groupby: Optional[str] = None, - values_to_plot: Optional[ - Literal[ - "scores", - "logfoldchanges", - "pvals", - "pvals_adj", - "log10_pvals", - "log10_pvals_adj", - ] - ] = None, - var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None, - gene_symbols: Optional[str] = None, - min_logfoldchange: Optional[float] = None, - key: Optional[str] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, - return_fig: Optional[bool] = False, + groups: str | Sequence[str] | None = None, + n_genes: int | None = None, + groupby: str | None = None, + values_to_plot: Literal[ + "scores", + "logfoldchanges", + "pvals", + "pvals_adj", + "log10_pvals", + "log10_pvals_adj", + ] + | None = None, + var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, + gene_symbols: str | None = None, + min_logfoldchange: float | None = None, + key: str | None = None, + show: bool | None = None, + save: bool | None = None, + return_fig: bool | None = False, **kwds, ): """\ @@ -1094,20 +1101,20 @@ def rank_genes_groups_matrixplot( @_doc_params(show_save_ax=doc_show_save_ax) def rank_genes_groups_violin( adata: AnnData, - groups: Optional[Sequence[str]] = None, + groups: Sequence[str] | None = None, n_genes: int = 20, - gene_names: Optional[Iterable[str]] = None, - gene_symbols: Optional[str] = None, - use_raw: Optional[bool] = None, - key: Optional[str] = None, + gene_names: Iterable[str] | None = None, + gene_symbols: str | None = None, + use_raw: bool | None = None, + key: str | None = None, split: bool = True, scale: str = "width", strip: bool = True, - jitter: Union[int, float, bool] = True, + jitter: int | float | bool = True, size: int = 1, - ax: Optional[Axes] = None, - show: Optional[bool] = None, - save: Optional[bool] = None, + ax: Axes | None = None, + show: bool | None = None, + save: bool | None = None, ): """\ Plot ranking of genes for all tested comparisons. @@ -1198,7 +1205,7 @@ def rank_genes_groups_violin( ax=_ax, ) _ax.set_xlabel("genes") - _ax.set_title("{} vs. {}".format(group_name, reference)) + _ax.set_title(f"{group_name} vs. {reference}") _ax.legend_.remove() _ax.set_ylabel("expression") _ax.set_xticklabels(new_gene_names, rotation="vertical") @@ -1215,12 +1222,12 @@ def rank_genes_groups_violin( def sim( adata, - tmax_realization: Optional[int] = None, + tmax_realization: int | None = None, as_heatmap: bool = False, shuffle: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - marker: Union[str, Sequence[str]] = ".", + show: bool | None = None, + save: bool | str | None = None, + marker: str | Sequence[str] = ".", ): """\ Plot results of simulation. @@ -1265,7 +1272,7 @@ def sim( var_names=adata.var_names, highlights_x=np.arange(tmax, n_realizations * tmax, tmax), ) - pl.xticks( + plt.xticks( np.arange(0, n_realizations * tmax, tmax), np.arange(n_realizations).astype(int) + 1, ) @@ -1292,26 +1299,26 @@ def embedding_density( adata: AnnData, # on purpose, there is no asterisk here (for backward compat) basis: str = "umap", # was positional before 1.4.5 - key: Optional[str] = None, # was positional before 1.4.5 - groupby: Optional[str] = None, - group: Optional[Union[str, List[str], None]] = "all", - color_map: Union[Colormap, str] = "YlOrRd", - bg_dotsize: Optional[int] = 80, - fg_dotsize: Optional[int] = 180, - vmax: Optional[int] = 1, - vmin: Optional[int] = 0, - vcenter: Optional[int] = None, - norm: Optional[Normalize] = None, - ncols: Optional[int] = 4, - hspace: Optional[float] = 0.25, - wspace: Optional[None] = None, - title: str = None, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - ax: Optional[Axes] = None, - return_fig: Optional[bool] = None, + key: str | None = None, # was positional before 1.4.5 + groupby: str | None = None, + group: str | Sequence[str] | None | None = "all", + color_map: Colormap | str = "YlOrRd", + bg_dotsize: int | None = 80, + fg_dotsize: int | None = 180, + vmax: int | None = 1, + vmin: int | None = 0, + vcenter: int | None = None, + norm: Normalize | None = None, + ncols: int | None = 4, + hspace: float | None = 0.25, + wspace: None = None, + title: str | None = None, + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, + return_fig: bool | None = None, **kwargs, -) -> Union[Figure, Axes, None]: +) -> Figure | Axes | None: """\ Plot the density of cells in an embedding (per condition). @@ -1486,7 +1493,7 @@ def embedding_density( f"Invalid group name: {group_name}" ) - ax = pl.subplot(gs[count]) + ax = plt.subplot(gs[count]) # Define plotting data dot_sizes = np.ones(adata.n_obs) * bg_dotsize group_mask = adata.obs[groupby] == group_name @@ -1575,9 +1582,9 @@ def _get_values_to_plot( "log10_pvals_adj", ], gene_names: Sequence[str], - groups: Optional[Sequence[str]] = None, - key: Optional[str] = "rank_genes_groups", - gene_symbols: Optional[str] = None, + groups: Sequence[str] | None = None, + key: str | None = "rank_genes_groups", + gene_symbols: str | None = None, ): """ If rank_genes_groups has been called, this function diff --git a/scanpy/plotting/_tools/paga.py b/scanpy/plotting/_tools/paga.py index 5078cc89d5..5fab4eae02 100644 --- a/scanpy/plotting/_tools/paga.py +++ b/scanpy/plotting/_tools/paga.py @@ -1,25 +1,32 @@ -import warnings +from __future__ import annotations + import collections.abc as cabc +import warnings from pathlib import Path from types import MappingProxyType -from typing import Optional, Union, List, Sequence, Mapping, Any, Tuple, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import pandas as pd import scipy -from anndata import AnnData +from matplotlib import patheffects, rcParams, ticker +from matplotlib import pyplot as plt +from matplotlib.colors import Colormap, is_color_like from pandas.api.types import CategoricalDtype -from matplotlib import pyplot as pl, rcParams, ticker -from matplotlib import patheffects -from matplotlib.axes import Axes -from matplotlib.colors import is_color_like, Colormap from scipy.sparse import issparse from sklearn.utils import check_random_state -from .. import _utils -from .._utils import matrix, _IGraphLayout, _FontWeight, _FontSize -from ... import _utils as _sc_utils, logging as logg +from ... import _utils as _sc_utils +from ... import logging as logg from ..._settings import settings +from .. import _utils +from .._utils import _FontSize, _FontWeight, _IGraphLayout, matrix + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from anndata import AnnData + from matplotlib.axes import Axes def paga_compare( @@ -32,8 +39,8 @@ def paga_compare( components=None, projection: Literal["2d", "3d"] = "2d", legend_loc="on data", - legend_fontsize: Union[int, float, _FontSize, None] = None, - legend_fontweight: Union[int, _FontWeight] = "bold", + legend_fontsize: int | float | _FontSize | None = None, + legend_fontweight: int | _FontWeight = "bold", legend_fontoutline=None, color_map=None, palette=None, @@ -96,7 +103,7 @@ def paga_compare( else: basis = "umap" - from .scatterplots import embedding, _get_basis, _components_to_dimensions + from .scatterplots import _components_to_dimensions, _get_basis, embedding embedding( adata, @@ -162,7 +169,7 @@ def paga_compare( **paga_graph_params, ) if suptitle is not None: - pl.suptitle(suptitle) + plt.suptitle(suptitle) _utils.savefig_or_show("paga_compare", show=show, save=save) if show is False: return axs @@ -178,6 +185,7 @@ def _compute_pos( layout_kwds: Mapping[str, Any] = MappingProxyType({}), ): import random + import networkx as nx random_state = check_random_state(random_state) @@ -276,47 +284,47 @@ def _compute_pos( def paga( adata: AnnData, - threshold: Optional[float] = None, - color: Optional[Union[str, Mapping[Union[str, int], Mapping[Any, float]]]] = None, - layout: Optional[_IGraphLayout] = None, + threshold: float | None = None, + color: str | Mapping[str | int, Mapping[Any, float]] | None = None, + layout: _IGraphLayout | None = None, layout_kwds: Mapping[str, Any] = MappingProxyType({}), - init_pos: Optional[np.ndarray] = None, - root: Union[int, str, Sequence[int], None] = 0, - labels: Union[str, Sequence[str], Mapping[str, str], None] = None, + init_pos: np.ndarray | None = None, + root: int | str | Sequence[int] | None = 0, + labels: str | Sequence[str] | Mapping[str, str] | None = None, single_component: bool = False, solid_edges: str = "connectivities", - dashed_edges: Optional[str] = None, - transitions: Optional[str] = None, - fontsize: Optional[int] = None, + dashed_edges: str | None = None, + transitions: str | None = None, + fontsize: int | None = None, fontweight: str = "bold", - fontoutline: Optional[int] = None, + fontoutline: int | None = None, text_kwds: Mapping[str, Any] = MappingProxyType({}), node_size_scale: float = 1.0, node_size_power: float = 0.5, edge_width_scale: float = 1.0, - min_edge_width: Optional[float] = None, - max_edge_width: Optional[float] = None, + min_edge_width: float | None = None, + max_edge_width: float | None = None, arrowsize: int = 30, - title: Optional[str] = None, + title: str | None = None, left_margin: float = 0.01, - random_state: Optional[int] = 0, - pos: Union[np.ndarray, str, Path, None] = None, + random_state: int | None = 0, + pos: np.ndarray | str | Path | None = None, normalize_to_color: bool = False, - cmap: Union[str, Colormap] = None, - cax: Optional[Axes] = None, + cmap: str | Colormap = None, + cax: Axes | None = None, colorbar=None, # TODO: this seems to be unused cb_kwds: Mapping[str, Any] = MappingProxyType({}), - frameon: Optional[bool] = None, + frameon: bool | None = None, add_pos: bool = True, export_to_gexf: bool = False, use_raw: bool = True, colors=None, # backwards compat groups=None, # backwards compat plot: bool = True, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - ax: Optional[Axes] = None, -) -> Union[Axes, List[Axes], None]: + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, +) -> Axes | list[Axes] | None: """\ Plot the PAGA graph through thresholding low-connectivity edges. @@ -613,12 +621,12 @@ def is_flat(x): width = 0.006 * draw_region_width / len(colors) left = panel_pos[2][2 * icolor + 1] + 0.2 * width rectangle = [left, bottom, width, height] - fig = pl.gcf() + fig = plt.gcf() ax_cb = fig.add_axes(rectangle) else: ax_cb = cax[icolor] - _ = pl.colorbar( + _ = plt.colorbar( sct, format=ticker.FuncFormatter(_utils.ticks_formatter), cax=ax_cb, @@ -982,34 +990,34 @@ def _paga_graph( def paga_path( adata: AnnData, - nodes: Sequence[Union[str, int]], + nodes: Sequence[str | int], keys: Sequence[str], use_raw: bool = True, annotations: Sequence[str] = ("dpt_pseudotime",), - color_map: Union[str, Colormap, None] = None, - color_maps_annotations: Mapping[str, Union[str, Colormap]] = MappingProxyType( + color_map: str | Colormap | None = None, + color_maps_annotations: Mapping[str, str | Colormap] = MappingProxyType( dict(dpt_pseudotime="Greys") ), - palette_groups: Optional[Sequence[str]] = None, + palette_groups: Sequence[str] | None = None, n_avg: int = 1, - groups_key: Optional[str] = None, - xlim: Tuple[Optional[int], Optional[int]] = (None, None), - title: Optional[str] = None, + groups_key: str | None = None, + xlim: tuple[int | None, int | None] = (None, None), + title: str | None = None, left_margin=None, - ytick_fontsize: Optional[int] = None, - title_fontsize: Optional[int] = None, + ytick_fontsize: int | None = None, + title_fontsize: int | None = None, show_node_names: bool = True, show_yticks: bool = True, show_colorbar: bool = True, - legend_fontsize: Union[int, float, _FontSize, None] = None, - legend_fontweight: Union[int, _FontWeight, None] = None, + legend_fontsize: int | float | _FontSize | None = None, + legend_fontweight: int | _FontWeight | None = None, normalize_to_zero_one: bool = False, as_heatmap: bool = True, return_data: bool = False, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - ax: Optional[Axes] = None, -) -> Optional[Axes]: + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, +) -> Axes | None: """\ Gene expression and annotation changes along paths in the abstracted graph. @@ -1092,7 +1100,7 @@ def paga_path( def moving_average(a): return _sc_utils.moving_average(a, n_avg) - ax = pl.gca() if ax is None else ax + ax = plt.gca() if ax is None else ax X = [] x_tick_locs = [0] @@ -1181,13 +1189,13 @@ def moving_average(a): ax.tick_params(axis="both", which="both", length=0) ax.grid(False) if show_colorbar: - pl.colorbar(img, ax=ax) + plt.colorbar(img, ax=ax) left_margin = 0.2 if left_margin is None else left_margin - pl.subplots_adjust(left=left_margin) + plt.subplots_adjust(left=left_margin) else: left_margin = 0.4 if left_margin is None else left_margin if len(keys) > 1: - pl.legend( + plt.legend( frameon=False, loc="center left", bbox_to_anchor=(-left_margin, 0.5), @@ -1196,15 +1204,15 @@ def moving_average(a): xlabel = groups_key if not as_heatmap: ax.set_xlabel(xlabel) - pl.yticks([]) + plt.yticks([]) if len(keys) == 1: - pl.ylabel(keys[0] + " (a.u.)") + plt.ylabel(keys[0] + " (a.u.)") else: import matplotlib.colors # groups bar ax_bounds = ax.get_position().bounds - groups_axis = pl.axes( + groups_axis = plt.axes( ( ax_bounds[0], ax_bounds[1] - ax_bounds[3] / len(keys), @@ -1250,7 +1258,7 @@ def moving_average(a): for ianno, anno in enumerate(annotations): if ianno > 0: y_shift = ax_bounds[3] / len(keys) / 2 - anno_axis = pl.axes( + anno_axis = plt.axes( ( ax_bounds[0], ax_bounds[1] - (ianno + 2) * y_shift, @@ -1314,14 +1322,14 @@ def paga_adjacency( matrix(connectivity, color_map=color_map, show=False) for i in range(connectivity_select.shape[0]): neighbors = connectivity_select[i].nonzero()[1] - pl.scatter([i for j in neighbors], neighbors, color="black", s=1) + plt.scatter([i for j in neighbors], neighbors, color="black", s=1) # as a stripplot else: - pl.figure() + plt.figure() for i, cs in enumerate(connectivity): x = [i for j, d in enumerate(cs) if i != j] y = [c for j, c in enumerate(cs) if i != j] - pl.scatter(x, y, color="gray", s=1) + plt.scatter(x, y, color="gray", s=1) neighbors = connectivity_select[i].nonzero()[1] - pl.scatter([i for j in neighbors], cs[neighbors], color="black", s=1) + plt.scatter([i for j in neighbors], cs[neighbors], color="black", s=1) _utils.savefig_or_show("paga_connectivity", show=show, save=save) diff --git a/scanpy/plotting/_tools/scatterplots.py b/scanpy/plotting/_tools/scatterplots.py index 655a106190..b74524d259 100644 --- a/scanpy/plotting/_tools/scatterplots.py +++ b/scanpy/plotting/_tools/scatterplots.py @@ -1,48 +1,32 @@ from __future__ import annotations -import sys -import inspect import collections.abc as cabc +import inspect +import sys +from collections.abc import Mapping, Sequence # noqa: TCH003 from copy import copy -from numbers import Integral -from itertools import combinations, product from functools import partial -from typing import ( - Collection, - Union, - Optional, - Sequence, - Any, - Mapping, - List, - Tuple, - Literal, -) +from itertools import combinations, product +from numbers import Integral +from typing import TYPE_CHECKING, Any, Literal import numpy as np import pandas as pd -from anndata import AnnData -from numpy.typing import NDArray -from pandas.api.types import CategoricalDtype -from cycler import Cycler -from matplotlib.axes import Axes -from matplotlib.figure import Figure -from matplotlib import pyplot as pl, colors, colormaps -from matplotlib import rcParams -from matplotlib import patheffects +from anndata import AnnData # noqa: TCH002 +from cycler import Cycler # noqa: TCH002 +from matplotlib import colormaps, colors, patheffects, rcParams +from matplotlib import pyplot as plt +from matplotlib.axes import Axes # noqa: TCH002 from matplotlib.colors import Colormap, Normalize +from matplotlib.figure import Figure # noqa: TCH002 +from numpy.typing import NDArray # noqa: TCH002 +from pandas.api.types import CategoricalDtype +from ... import logging as logg +from ..._settings import settings +from ..._utils import Empty, _doc_params, _empty, sanitize_anndata +from ...get import _check_mask from .. import _utils -from .._utils import ( - _IGraphLayout, - _FontWeight, - _FontSize, - ColorLike, - VBound, - circles, - check_projection, - check_colornorm, -) from .._docs import ( doc_adata_color_etc, doc_edges_arrows, @@ -50,10 +34,19 @@ doc_scatter_spatial, doc_show_save_ax, ) -from ... import logging as logg -from ..._settings import settings -from ..._utils import sanitize_anndata, _doc_params, Empty, _empty -from ...get import _check_mask +from .._utils import ( + ColorLike, + VBound, + _FontSize, + _FontWeight, + _IGraphLayout, + check_colornorm, + check_projection, + circles, +) + +if TYPE_CHECKING: + from collections.abc import Collection @_doc_params( @@ -66,53 +59,53 @@ def embedding( adata: AnnData, basis: str, *, - color: Union[str, Sequence[str], None] = None, + color: str | Sequence[str] | None = None, mask: NDArray[np.bool_] | str | None = None, - gene_symbols: Optional[str] = None, - use_raw: Optional[bool] = None, + gene_symbols: str | None = None, + use_raw: bool | None = None, sort_order: bool = True, edges: bool = False, edges_width: float = 0.1, - edges_color: Union[str, Sequence[float], Sequence[str]] = "grey", - neighbors_key: Optional[str] = None, + edges_color: str | Sequence[float] | Sequence[str] = "grey", + neighbors_key: str | None = None, arrows: bool = False, - arrows_kwds: Optional[Mapping[str, Any]] = None, + arrows_kwds: Mapping[str, Any] | None = None, groups: str | Sequence[str] | None = None, components: str | Sequence[str] | None = None, - dimensions: Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]] = None, - layer: Optional[str] = None, + dimensions: tuple[int, int] | Sequence[tuple[int, int]] | None = None, + layer: str | None = None, projection: Literal["2d", "3d"] = "2d", - scale_factor: Optional[float] = None, - color_map: Union[Colormap, str, None] = None, - cmap: Union[Colormap, str, None] = None, - palette: Union[str, Sequence[str], Cycler, None] = None, + scale_factor: float | None = None, + color_map: Colormap | str | None = None, + cmap: Colormap | str | None = None, + palette: str | Sequence[str] | Cycler | None = None, na_color: ColorLike = "lightgray", na_in_legend: bool = True, - size: Union[float, Sequence[float], None] = None, - frameon: Optional[bool] = None, - legend_fontsize: Union[int, float, _FontSize, None] = None, - legend_fontweight: Union[int, _FontWeight] = "bold", + size: float | Sequence[float] | None = None, + frameon: bool | None = None, + legend_fontsize: int | float | _FontSize | None = None, + legend_fontweight: int | _FontWeight = "bold", legend_loc: str = "right margin", - legend_fontoutline: Optional[int] = None, - colorbar_loc: Optional[str] = "right", - vmax: Union[VBound, Sequence[VBound], None] = None, - vmin: Union[VBound, Sequence[VBound], None] = None, - vcenter: Union[VBound, Sequence[VBound], None] = None, - norm: Union[Normalize, Sequence[Normalize], None] = None, - add_outline: Optional[bool] = False, - outline_width: Tuple[float, float] = (0.3, 0.05), - outline_color: Tuple[str, str] = ("black", "white"), + legend_fontoutline: int | None = None, + colorbar_loc: str | None = "right", + vmax: VBound | Sequence[VBound] | None = None, + vmin: VBound | Sequence[VBound] | None = None, + vcenter: VBound | Sequence[VBound] | None = None, + norm: Normalize | Sequence[Normalize] | None = None, + add_outline: bool | None = False, + outline_width: tuple[float, float] = (0.3, 0.05), + outline_color: tuple[str, str] = ("black", "white"), ncols: int = 4, hspace: float = 0.25, - wspace: Optional[float] = None, - title: Union[str, Sequence[str], None] = None, - show: Optional[bool] = None, - save: Union[bool, str, None] = None, - ax: Optional[Axes] = None, - return_fig: Optional[bool] = None, - marker: Union[str, Sequence[str]] = ".", + wspace: float | None = None, + title: str | Sequence[str] | None = None, + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, + return_fig: bool | None = None, + marker: str | Sequence[str] = ".", **kwargs, -) -> Union[Figure, Axes, None]: +) -> Figure | Axes | None: """\ Scatter plot for user specified embedding basis (e.g. umap, pca, etc) @@ -255,7 +248,7 @@ def embedding( else: grid = None if ax is None: - fig = pl.figure() + fig = plt.figure() ax = fig.add_subplot(111, **args_3d) ############ @@ -305,7 +298,7 @@ def embedding( # if plotting multiple panels, get the ax from the grid spec # else use the ax value (either user given or created previously) if grid: - ax = pl.subplot(grid[count], **args_3d) + ax = plt.subplot(grid[count], **args_3d) axs.append(ax) if not (settings._frameon if frameon is None else frameon): ax.axis("off") @@ -466,7 +459,7 @@ def embedding( multi_panel=bool(grid), ) elif colorbar_loc is not None: - pl.colorbar( + plt.colorbar( cax, ax=ax, pad=0.01, fraction=0.08, aspect=30, location=colorbar_loc ) @@ -484,7 +477,7 @@ def _panel_grid(hspace, wspace, ncols, num_panels): n_panels_x = min(ncols, num_panels) n_panels_y = np.ceil(num_panels / n_panels_x).astype(int) # each panel will have the size of rcParams['figure.figsize'] - fig = pl.figure( + fig = plt.figure( figsize=( n_panels_x * rcParams["figure.figsize"][0] * (1 + wspace), n_panels_y * rcParams["figure.figsize"][1], @@ -512,7 +505,7 @@ def _get_vboundnorm( norm: Sequence[Normalize], index: int, color_vector: Sequence[float], -) -> Tuple[Union[float, None], Union[float, None]]: +) -> tuple[float | None, float | None]: """ Evaluates the value of vmin, vmax and vcenter, which could be a str in which case is interpreted as a percentile and should @@ -634,7 +627,7 @@ def _wraps_plot_scatter(wrapper): scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def umap(adata, **kwargs) -> Union[Axes, List[Axes], None]: +def umap(adata, **kwargs) -> Axes | list[Axes] | None: """\ Scatter plot in UMAP basis. @@ -696,7 +689,7 @@ def umap(adata, **kwargs) -> Union[Axes, List[Axes], None]: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def tsne(adata, **kwargs) -> Union[Axes, List[Axes], None]: +def tsne(adata, **kwargs) -> Axes | list[Axes] | None: """\ Scatter plot in tSNE basis. @@ -736,7 +729,7 @@ def tsne(adata, **kwargs) -> Union[Axes, List[Axes], None]: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def diffmap(adata, **kwargs) -> Union[Axes, List[Axes], None]: +def diffmap(adata, **kwargs) -> Axes | list[Axes] | None: """\ Scatter plot in Diffusion Map basis. @@ -777,8 +770,8 @@ def diffmap(adata, **kwargs) -> Union[Axes, List[Axes], None]: show_save_ax=doc_show_save_ax, ) def draw_graph( - adata: AnnData, *, layout: Optional[_IGraphLayout] = None, **kwargs -) -> Union[Axes, List[Axes], None]: + adata: AnnData, *, layout: _IGraphLayout | None = None, **kwargs +) -> Axes | list[Axes] | None: """\ Scatter plot in graph-drawing basis. @@ -833,11 +826,11 @@ def pca( adata, *, annotate_var_explained: bool = False, - show: Optional[bool] = None, - return_fig: Optional[bool] = None, - save: Union[bool, str, None] = None, + show: bool | None = None, + return_fig: bool | None = None, + save: bool | str | None = None, **kwargs, -) -> Union[Axes, List[Axes], None]: +) -> Axes | list[Axes] | None: """\ Scatter plot in PCA coordinates. @@ -882,7 +875,6 @@ def pca( See also -------- - tl.pca pp.pca """ if not annotate_var_explained: @@ -897,7 +889,7 @@ def pca( ) label_dict = { - "PC{}".format(i + 1): "PC{} ({}%)".format(i + 1, round(v * 100, 2)) + f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)" for i, v in enumerate(adata.uns["pca"]["variance_ratio"]) } @@ -937,21 +929,21 @@ def spatial( adata, *, basis: str = "spatial", - img: Union[np.ndarray, None] = None, - img_key: Union[str, None, Empty] = _empty, - library_id: Union[str, None, Empty] = _empty, - crop_coord: Tuple[int, int, int, int] = None, + img: np.ndarray | None = None, + img_key: str | None | Empty = _empty, + library_id: str | None | Empty = _empty, + crop_coord: tuple[int, int, int, int] | None = None, alpha_img: float = 1.0, - bw: Optional[bool] = False, + bw: bool | None = False, size: float = 1.0, - scale_factor: Optional[float] = None, - spot_size: Optional[float] = None, - na_color: Optional[ColorLike] = None, - show: Optional[bool] = None, - return_fig: Optional[bool] = None, - save: Union[bool, str, None] = None, + scale_factor: float | None = None, + spot_size: float | None = None, + na_color: ColorLike | None = None, + show: bool | None = None, + return_fig: bool | None = None, + save: bool | str | None = None, **kwargs, -) -> Union[Axes, List[Axes], None]: +) -> Axes | list[Axes] | None: """\ Scatter plot in spatial coordinates. @@ -1052,12 +1044,12 @@ def spatial( # Helpers def _components_to_dimensions( - components: Optional[Union[str, Collection[str]]], - dimensions: Optional[Union[Collection[int], Collection[Collection[int]]]], + components: str | Collection[str] | None, + dimensions: Collection[int] | Collection[Collection[int]] | None, *, projection: Literal["2d", "3d"] = "2d", total_dims: int, -) -> List[Collection[int]]: +) -> list[Collection[int]]: """Normalize components/ dimensions args for embedding plots.""" # TODO: Deprecate components kwarg ndims = {"2d": 2, "3d": 3}[projection] @@ -1229,7 +1221,7 @@ def _color_vector( values: np.ndarray | pd.api.extensions.ExtensionArray, palette: str | Sequence[str] | Cycler | None, na_color: ColorLike = "lightgray", -) -> Tuple[np.ndarray | pd.api.extensions.ExtensionArray, bool]: +) -> tuple[np.ndarray | pd.api.extensions.ExtensionArray, bool]: """ Map array of values to array of hex (plus alpha) codes. @@ -1286,9 +1278,7 @@ def _basis2name(basis): return component_name -def _check_spot_size( - spatial_data: Optional[Mapping], spot_size: Optional[float] -) -> float: +def _check_spot_size(spatial_data: Mapping | None, spot_size: float | None) -> float: """ Resolve spot_size value. @@ -1306,9 +1296,9 @@ def _check_spot_size( def _check_scale_factor( - spatial_data: Optional[Mapping], - img_key: Optional[str], - scale_factor: Optional[float], + spatial_data: Mapping | None, + img_key: str | None, + scale_factor: float | None, ) -> float: """Resolve scale_factor, defaults to 1.""" if scale_factor is not None: @@ -1320,8 +1310,8 @@ def _check_scale_factor( def _check_spatial_data( - uns: Mapping, library_id: Union[str, None, Empty] -) -> Tuple[Optional[str], Optional[Mapping]]: + uns: Mapping, library_id: str | None | Empty +) -> tuple[str | None, Mapping | None]: """ Given a mapping, try and extract a library id/ mapping with spatial data. @@ -1346,11 +1336,11 @@ def _check_spatial_data( def _check_img( - spatial_data: Optional[Mapping], - img: Optional[np.ndarray], - img_key: Union[None, str, Empty], + spatial_data: Mapping | None, + img: np.ndarray | None, + img_key: None | str | Empty, bw: bool = False, -) -> Tuple[Optional[np.ndarray], Optional[str]]: +) -> tuple[np.ndarray | None, str | None]: """ Resolve image for spatial plots. """ @@ -1366,9 +1356,9 @@ def _check_img( def _check_crop_coord( - crop_coord: Optional[tuple], + crop_coord: tuple | None, scale_factor: float, -) -> Tuple[float, float, float, float]: +) -> tuple[float, float, float, float]: """Handle cropping with image or basis.""" if crop_coord is None: return None @@ -1379,7 +1369,7 @@ def _check_crop_coord( def _check_na_color( - na_color: Optional[ColorLike], *, img: Optional[np.ndarray] = None + na_color: ColorLike | None, *, img: np.ndarray | None = None ) -> ColorLike: if na_color is None: if img is not None: @@ -1391,7 +1381,6 @@ def _check_na_color( def _broadcast_args(*args): """Broadcasts arguments to a common length.""" - from itertools import repeat lens = [len(arg) for arg in args] longest = max(lens) diff --git a/scanpy/plotting/_utils.py b/scanpy/plotting/_utils.py index 958952e086..d87bb92efd 100644 --- a/scanpy/plotting/_utils.py +++ b/scanpy/plotting/_utils.py @@ -1,33 +1,38 @@ from __future__ import annotations -import warnings import collections.abc as cabc -from typing import Union, List, Sequence, Tuple, Collection, Optional, Callable, Literal -import anndata +import warnings +from collections.abc import Collection, Sequence +from typing import TYPE_CHECKING, Callable, Literal +from typing import Union as _U -import numpy as np import matplotlib as mpl -from matplotlib import pyplot as pl -from matplotlib import rcParams, ticker, gridspec, axes +import numpy as np +from cycler import Cycler, cycler +from matplotlib import axes, gridspec, rcParams, ticker +from matplotlib import pyplot as plt from matplotlib.axes import Axes +from matplotlib.collections import PatchCollection from matplotlib.colors import is_color_like -from matplotlib.figure import SubplotParams as sppars, Figure +from matplotlib.figure import Figure +from matplotlib.figure import SubplotParams as sppars from matplotlib.patches import Circle -from matplotlib.collections import PatchCollection -from cycler import Cycler, cycler from .. import logging as logg from .._settings import settings from .._utils import NeighborsView from . import palettes -ColorLike = Union[str, Tuple[float, ...]] +if TYPE_CHECKING: + import anndata + +ColorLike = _U[str, tuple[float, ...]] _IGraphLayout = Literal["fa", "fr", "rt", "rt_circular", "drl", "eq_tree", ...] _FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] _FontSize = Literal[ "xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large" ] -VBound = Union[str, float, Callable[[Sequence[float]], float]] +VBound = _U[str, float, Callable[[Sequence[float]], float]] class _AxesSubplot(Axes, axes.SubplotBase): @@ -54,7 +59,7 @@ def matrix( ): """Plot a matrix.""" if ax is None: - ax = pl.gca() + ax = plt.gca() img = ax.imshow(matrix, cmap=color_map) if xlabel is not None: ax.set_xlabel(xlabel) @@ -66,7 +71,7 @@ def matrix( ax.set_xticks(range(len(xticks)), xticks, rotation="vertical") if yticks is not None: ax.set_yticks(range(len(yticks)), yticks) - pl.colorbar( + plt.colorbar( img, shrink=colorbar_shrink, ax=ax ) # need a figure instance for colorbar savefig_or_show("matrix", show=show, save=save) @@ -74,7 +79,7 @@ def matrix( def timeseries(X, **kwargs): """Plot X. See timeseries_subplot.""" - pl.figure( + plt.figure( figsize=tuple(2 * s for s in rcParams["figure.figsize"]), subplotpars=sppars(left=0.12, right=0.98, bottom=0.13), ) @@ -92,10 +97,10 @@ def timeseries_subplot( yticks=None, xlim=None, legend=True, - palette: Union[Sequence[str], Cycler, None] = None, + palette: Sequence[str] | Cycler | None = None, color_map="viridis", - ax: Optional[Axes] = None, - marker: Union[str, Sequence[str]] = ".", + ax: Axes | None = None, + marker: str | Sequence[str] = ".", ): """\ Plot X. @@ -132,7 +137,7 @@ def timeseries_subplot( marker = [marker[0] for _ in range(len(subsets))] if ax is None: - ax = pl.subplot() + ax = plt.subplot() for i, (x, y) in enumerate(subsets): ax.scatter( x, @@ -199,19 +204,19 @@ def timeseries_as_heatmap( hold = h x_new[:, _hold:] = X[:, hold:] - _, ax = pl.subplots(figsize=(1.5 * 4, 2 * 4)) + _, ax = plt.subplots(figsize=(1.5 * 4, 2 * 4)) img = ax.imshow( np.array(X, dtype=np.float_), aspect="auto", interpolation="nearest", cmap=color_map, ) - pl.colorbar(img, shrink=0.5) - pl.yticks(range(X.shape[0]), var_names) + plt.colorbar(img, shrink=0.5) + plt.yticks(range(X.shape[0]), var_names) for h in highlights_x: - pl.plot([h, h], [0, X.shape[0]], "--", color="black") - pl.xlim([0, X.shape[1] - 1]) - pl.ylim([0, X.shape[0] - 1]) + plt.plot([h, h], [0, X.shape[0]], "--", color="black") + plt.xlim([0, X.shape[1] - 1]) + plt.ylim([0, X.shape[0] - 1]) # ------------------------------------------------------------------------------- @@ -289,15 +294,15 @@ def savefig(writekey, dpi=None, ext=None): filename = settings.figdir / f"{writekey}{settings.plot_suffix}.{ext}" # output the following msg at warning level; it's really important for the user logg.warning(f"saving figure to file {filename}") - pl.savefig(filename, dpi=dpi, bbox_inches="tight") + plt.savefig(filename, dpi=dpi, bbox_inches="tight") def savefig_or_show( writekey: str, - show: Optional[bool] = None, - dpi: Optional[int] = None, + show: bool | None = None, + dpi: int | None = None, ext: str = None, - save: Union[bool, str, None] = None, + save: bool | str | None = None, ): if isinstance(save, str): # check whether `save` contains a figure extension @@ -315,13 +320,13 @@ def savefig_or_show( if save: savefig(writekey, dpi=dpi, ext=ext) if show: - pl.show() + plt.show() if save: - pl.close() # clear figure + plt.close() # clear figure def default_palette( - palette: Union[str, Sequence[str], Cycler, None] = None + palette: str | Sequence[str] | Cycler | None = None ) -> str | Cycler: if palette is None: return rcParams["axes.prop_cycle"] @@ -365,7 +370,7 @@ def _validate_palette(adata: anndata.AnnData, key: str) -> None: def _set_colors_for_categorical_obs( - adata, value_to_plot, palette: Union[str, Sequence[str], Cycler] + adata, value_to_plot, palette: str | Sequence[str] | Cycler ): """ Sets the adata.uns[value_to_plot + '_colors'] according to the given palette @@ -394,9 +399,9 @@ def _set_colors_for_categorical_obs( else: categories = adata.obs[value_to_plot].cat.categories # check is palette is a valid matplotlib colormap - if isinstance(palette, str) and palette in pl.colormaps(): + if isinstance(palette, str) and palette in plt.colormaps(): # this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20' - cmap = pl.get_cmap(palette) + cmap = plt.get_cmap(palette) colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, len(categories)))] elif isinstance(palette, cabc.Mapping): colors_list = [to_hex(palette[k], keep_alpha=True) for k in categories] @@ -573,7 +578,7 @@ def scatter_group( color = rgb2hex(adata.uns[key + "_colors"][imask]) if not is_color_like(color): - raise ValueError('"{}" is not a valid matplotlib color.'.format(color)) + raise ValueError(f'"{color}" is not a valid matplotlib color.') data = [Y[mask, 0], Y[mask, 1]] if projection == "3d": data.append(Y[mask, 2]) @@ -591,7 +596,7 @@ def scatter_group( def setup_axes( - ax: Union[Axes, Sequence[Axes]] = None, + ax: Axes | Sequence[Axes] = None, panels="blue", colorbars=(False,), right_margin=None, @@ -641,7 +646,7 @@ def setup_axes( ) if ax is None: - pl.figure( + plt.figure( figsize=(figure_width, height), subplotpars=sppars(left=0, right=1, bottom=bottom_offset), ) @@ -662,9 +667,9 @@ def setup_axes( width = draw_region_width / figure_width height = panel_pos[1][0] - bottom if projection == "2d": - ax = pl.axes([left, bottom, width, height]) + ax = plt.axes([left, bottom, width, height]) elif projection == "3d": - ax = pl.axes([left, bottom, width, height], projection="3d") + ax = plt.axes([left, bottom, width, height], projection="3d") axs.append(ax) else: axs = ax if isinstance(ax, cabc.Sequence) else [ax] @@ -691,7 +696,7 @@ def scatter_base( color_map="viridis", show_ticks=True, ax=None, -) -> Union[Axes, List[Axes]]: +) -> Axes | list[Axes]: """Plot scatter plot of data. Parameters @@ -763,9 +768,9 @@ def scatter_base( + (1.2 if projection == "3d" else 0.2) * width ) rectangle = [left, bottom, width, height] - fig = pl.gcf() + fig = plt.gcf() ax_cb = fig.add_axes(rectangle) - _ = pl.colorbar( + _ = plt.colorbar( sct, format=ticker.FuncFormatter(ticks_formatter), cax=ax_cb ) # set the title @@ -1147,14 +1152,14 @@ def circles( def make_grid_spec( - ax_or_figsize: Union[Tuple[int, int], _AxesSubplot], + ax_or_figsize: tuple[int, int] | _AxesSubplot, nrows: int, ncols: int, - wspace: Optional[float] = None, - hspace: Optional[float] = None, - width_ratios: Optional[Sequence[float]] = None, - height_ratios: Optional[Sequence[float]] = None, -) -> Tuple[Figure, gridspec.GridSpecBase]: + wspace: float | None = None, + hspace: float | None = None, + width_ratios: Sequence[float] | None = None, + height_ratios: Sequence[float] | None = None, +) -> tuple[Figure, gridspec.GridSpecBase]: kw = dict( wspace=wspace, hspace=hspace, @@ -1162,7 +1167,7 @@ def make_grid_spec( height_ratios=height_ratios, ) if isinstance(ax_or_figsize, tuple): - fig = pl.figure(figsize=ax_or_figsize) + fig = plt.figure(figsize=ax_or_figsize) return fig, gridspec.GridSpec(nrows, ncols, **kw) else: ax = ax_or_figsize diff --git a/scanpy/plotting/palettes.py b/scanpy/plotting/palettes.py index 6809c5b214..172501619d 100644 --- a/scanpy/plotting/palettes.py +++ b/scanpy/plotting/palettes.py @@ -1,8 +1,13 @@ """Color palettes in addition to matplotlib's palettes.""" +from __future__ import annotations + +from typing import TYPE_CHECKING -from typing import Mapping, Sequence from matplotlib import cm, colors +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + # Colorblindness adjusted vega_10 # See https://github.com/scverse/scanpy/issues/387 vega_10 = list(map(colors.to_hex, cm.tab10.colors)) @@ -182,9 +187,9 @@ def _plot_color_cycle(clists: Mapping[str, Sequence[str]]): - import numpy as np import matplotlib.pyplot as plt - from matplotlib.colors import ListedColormap, BoundaryNorm + import numpy as np + from matplotlib.colors import BoundaryNorm, ListedColormap fig, axes = plt.subplots(nrows=len(clists)) # type: plt.Figure, plt.Axes fig.subplots_adjust(top=0.95, bottom=0.01, left=0.3, right=0.99) diff --git a/scanpy/preprocessing/__init__.py b/scanpy/preprocessing/__init__.py index b811a89cf0..88b4460256 100644 --- a/scanpy/preprocessing/__init__.py +++ b/scanpy/preprocessing/__init__.py @@ -1,11 +1,43 @@ -from ._recipes import recipe_zheng17, recipe_weinreb17, recipe_seurat -from ._simple import filter_cells, filter_genes +from __future__ import annotations + +from ..neighbors import neighbors +from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion from ._highly_variable_genes import highly_variable_genes -from ._simple import log1p, sqrt, scale, subsample -from ._simple import normalize_per_cell, regress_out, downsample_counts +from ._normalization import normalize_total from ._pca import pca from ._qc import calculate_qc_metrics -from ._combat import combat -from ._normalization import normalize_total -from ..neighbors import neighbors +from ._recipes import recipe_seurat, recipe_weinreb17, recipe_zheng17 +from ._simple import ( + downsample_counts, + filter_cells, + filter_genes, + log1p, + normalize_per_cell, + regress_out, + scale, + sqrt, + subsample, +) + +__all__ = [ + "neighbors", + "combat", + "filter_genes_dispersion", + "highly_variable_genes", + "normalize_total", + "pca", + "calculate_qc_metrics", + "recipe_seurat", + "recipe_weinreb17", + "recipe_zheng17", + "downsample_counts", + "filter_cells", + "filter_genes", + "log1p", + "normalize_per_cell", + "regress_out", + "scale", + "sqrt", + "subsample", +] diff --git a/scanpy/preprocessing/_combat.py b/scanpy/preprocessing/_combat.py index 0e1b49d979..550e35829e 100644 --- a/scanpy/preprocessing/_combat.py +++ b/scanpy/preprocessing/_combat.py @@ -1,14 +1,20 @@ -from typing import Collection, Tuple, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING -import pandas as pd import numpy as np +import pandas as pd from numpy import linalg as la from scipy.sparse import issparse -from anndata import AnnData from .. import logging as logg from .._utils import sanitize_anndata +if TYPE_CHECKING: + from collections.abc import Collection + + from anndata import AnnData + def _design_matrix( model: pd.DataFrame, batch_key: str, batch_levels: Collection[str] @@ -32,7 +38,7 @@ def _design_matrix( import patsy design = patsy.dmatrix( - "~ 0 + C(Q('{}'), levels=batch_levels)".format(batch_key), + f"~ 0 + C(Q('{batch_key}'), levels=batch_levels)", model, return_type="dataframe", ) @@ -43,9 +49,9 @@ def _design_matrix( other_cols = [c for c in model.columns.values if c not in numerical_covariates] if other_cols: - col_repr = " + ".join("Q('{}')".format(x) for x in other_cols) + col_repr = " + ".join(f"Q('{x}')" for x in other_cols) factor_matrix = patsy.dmatrix( - "~ 0 + {}".format(col_repr), model[other_cols], return_type="dataframe" + f"~ 0 + {col_repr}", model[other_cols], return_type="dataframe" ) design = pd.concat((design, factor_matrix), axis=1) @@ -64,7 +70,7 @@ def _design_matrix( def _standardize_data( model: pd.DataFrame, data: pd.DataFrame, batch_key: str -) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]: +) -> tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]: """\ Standardizes the data per gene. @@ -131,9 +137,9 @@ def _standardize_data( def combat( adata: AnnData, key: str = "batch", - covariates: Optional[Collection[str]] = None, + covariates: Collection[str] | None = None, inplace: bool = True, -) -> Union[np.ndarray, None]: +) -> np.ndarray | None: """\ ComBat function for batch effect correction [Johnson07]_ [Leek12]_ [Pedersen12]_. @@ -171,14 +177,14 @@ def combat( # check the input if key not in adata.obs_keys(): - raise ValueError("Could not find the key {!r} in adata.obs".format(key)) + raise ValueError(f"Could not find the key {key!r} in adata.obs") if covariates is not None: cov_exist = np.isin(covariates, adata.obs_keys()) if np.any(~cov_exist): missing_cov = np.array(covariates)[~cov_exist].tolist() raise ValueError( - "Could not find the covariate(s) {!r} in adata.obs".format(missing_cov) + f"Could not find the covariate(s) {missing_cov!r} in adata.obs" ) if key in covariates: @@ -287,7 +293,7 @@ def _it_sol( a: float, b: float, conv: float = 0.0001, -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """\ Iteratively compute the conditional posterior means for gamma and delta. diff --git a/scanpy/preprocessing/_deprecated/__init__.py b/scanpy/preprocessing/_deprecated/__init__.py index 79663a88f6..9fff5a3f3c 100644 --- a/scanpy/preprocessing/_deprecated/__init__.py +++ b/scanpy/preprocessing/_deprecated/__init__.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import numpy as np -from scipy.sparse import issparse, csr_matrix +from scipy.sparse import csr_matrix, issparse def normalize_per_cell_weinreb16_deprecated( diff --git a/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/scanpy/preprocessing/_deprecated/highly_variable_genes.py index f49c22e4d8..c7c89c45e1 100644 --- a/scanpy/preprocessing/_deprecated/highly_variable_genes.py +++ b/scanpy/preprocessing/_deprecated/highly_variable_genes.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import warnings -from typing import Optional, Literal +from typing import Literal import numpy as np import pandas as pd -from scipy.sparse import issparse from anndata import AnnData +from scipy.sparse import issparse from ... import logging as logg from .._distributed import materialize_as_ndarray @@ -14,12 +16,12 @@ def filter_genes_dispersion( data: AnnData, flavor: Literal["seurat", "cell_ranger"] = "seurat", - min_disp: Optional[float] = None, - max_disp: Optional[float] = None, - min_mean: Optional[float] = None, - max_mean: Optional[float] = None, + min_disp: float | None = None, + max_disp: float | None = None, + min_mean: float | None = None, + max_mean: float | None = None, n_bins: int = 20, - n_top_genes: Optional[int] = None, + n_top_genes: int | None = None, log: bool = True, subset: bool = True, copy: bool = False, diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index b5e7161ba5..a134efe758 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np # install dask if available diff --git a/scanpy/preprocessing/_docs.py b/scanpy/preprocessing/_docs.py index c85eef852b..10b02b4b14 100644 --- a/scanpy/preprocessing/_docs.py +++ b/scanpy/preprocessing/_docs.py @@ -1,5 +1,6 @@ """Shared docstrings for preprocessing function parameters. """ +from __future__ import annotations doc_adata_basic = """\ adata diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 1943448eac..5d13722854 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -1,29 +1,31 @@ +from __future__ import annotations + import warnings -from typing import Optional, Literal +from typing import Literal + import numpy as np import pandas as pd import scipy.sparse as sp_sparse from anndata import AnnData - from .. import logging as logg -from .._settings import settings, Verbosity -from .._utils import sanitize_anndata, check_nonnegative_integers -from ._utils import _get_mean_var +from .._settings import Verbosity, settings +from .._utils import check_nonnegative_integers, sanitize_anndata from ._distributed import materialize_as_ndarray from ._simple import filter_genes +from ._utils import _get_mean_var def _highly_variable_genes_seurat_v3( adata: AnnData, - layer: Optional[str] = None, + layer: str | None = None, n_top_genes: int = 2000, - batch_key: Optional[str] = None, + batch_key: str | None = None, check_values: bool = True, span: float = 0.3, subset: bool = False, inplace: bool = True, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """\ See `highly_variable_genes`. @@ -175,12 +177,12 @@ def _highly_variable_genes_seurat_v3( def _highly_variable_genes_single_batch( adata: AnnData, - layer: Optional[str] = None, - min_disp: Optional[float] = 0.5, - max_disp: Optional[float] = np.inf, - min_mean: Optional[float] = 0.0125, - max_mean: Optional[float] = 3, - n_top_genes: Optional[int] = None, + layer: str | None = None, + min_disp: float | None = 0.5, + max_disp: float | None = np.inf, + min_mean: float | None = 0.0125, + max_mean: float | None = 3, + n_top_genes: int | None = None, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", ) -> pd.DataFrame: @@ -298,20 +300,20 @@ def _highly_variable_genes_single_batch( def highly_variable_genes( adata: AnnData, - layer: Optional[str] = None, - n_top_genes: Optional[int] = None, - min_disp: Optional[float] = 0.5, - max_disp: Optional[float] = np.inf, - min_mean: Optional[float] = 0.0125, - max_mean: Optional[float] = 3, - span: Optional[float] = 0.3, + layer: str | None = None, + n_top_genes: int | None = None, + min_disp: float | None = 0.5, + max_disp: float | None = np.inf, + min_mean: float | None = 0.0125, + max_mean: float | None = 3, + span: float | None = 0.3, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger", "seurat_v3"] = "seurat", subset: bool = False, inplace: bool = True, - batch_key: Optional[str] = None, + batch_key: str | None = None, check_values: bool = True, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """\ Annotate highly variable genes [Satija15]_ [Zheng17]_ [Stuart19]_. diff --git a/scanpy/preprocessing/_normalization.py b/scanpy/preprocessing/_normalization.py index 34ad240024..74ac636a93 100644 --- a/scanpy/preprocessing/_normalization.py +++ b/scanpy/preprocessing/_normalization.py @@ -1,16 +1,21 @@ -from typing import Optional, Union, Iterable, Dict, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal from warnings import warn import numpy as np -from anndata import AnnData from scipy.sparse import issparse from sklearn.utils import sparsefuncs - from .. import logging as logg +from .._compat import DaskArray from .._utils import view_to_actual from ..get import _get_obs_rep, _set_obs_rep -from .._compat import DaskArray + +if TYPE_CHECKING: + from collections.abc import Iterable + + from anndata import AnnData def _normalize_data(X, counts, after=None, copy=False): @@ -36,16 +41,16 @@ def _normalize_data(X, counts, after=None, copy=False): def normalize_total( adata: AnnData, - target_sum: Optional[float] = None, + target_sum: float | None = None, exclude_highly_expressed: bool = False, max_fraction: float = 0.05, - key_added: Optional[str] = None, - layer: Optional[str] = None, - layers: Union[Literal["all"], Iterable[str]] = None, - layer_norm: Optional[str] = None, + key_added: str | None = None, + layer: str | None = None, + layers: Literal["all"] | Iterable[str] = None, + layer_norm: str | None = None, inplace: bool = True, copy: bool = False, -) -> Optional[Dict[str, np.ndarray]]: +) -> dict[str, np.ndarray] | None: """\ Normalize counts per cell. diff --git a/scanpy/preprocessing/_pca.py b/scanpy/preprocessing/_pca.py index a11c2bd34b..f5d7a34249 100644 --- a/scanpy/preprocessing/_pca.py +++ b/scanpy/preprocessing/_pca.py @@ -1,37 +1,35 @@ from __future__ import annotations import warnings -from typing import Optional, Union from warnings import warn import numpy as np +from anndata import AnnData from packaging import version from scipy.sparse import issparse, spmatrix from scipy.sparse.linalg import LinearOperator, svds from sklearn.utils import check_array, check_random_state from sklearn.utils.extmath import svd_flip -from anndata import AnnData - from .. import logging as logg -from .._settings import settings from .._compat import DaskArray, pkg_version -from .._utils import AnyRandom, Empty, _empty, _doc_params +from .._settings import settings +from .._utils import AnyRandom, Empty, _doc_params, _empty from ..get import _check_mask, _get_obs_rep -from ._utils import _get_mean_var from ._docs import doc_mask_hvg +from ._utils import _get_mean_var @_doc_params( mask_hvg=doc_mask_hvg, ) def pca( - data: Union[AnnData, np.ndarray, spmatrix], - n_comps: Optional[int] = None, + data: AnnData | np.ndarray | spmatrix, + n_comps: int | None = None, *, - layer: Optional[str] = None, - zero_center: Optional[bool] = True, - svd_solver: Optional[str] = None, + layer: str | None = None, + zero_center: bool | None = True, + svd_solver: str | None = None, random_state: AnyRandom = 0, return_info: bool = False, mask: np.ndarray | str | None | Empty = _empty, @@ -39,8 +37,8 @@ def pca( dtype: str = "float32", copy: bool = False, chunked: bool = False, - chunk_size: Optional[int] = None, -) -> Union[AnnData, np.ndarray, spmatrix]: + chunk_size: int | None = None, +) -> AnnData | np.ndarray | spmatrix: """\ Principal component analysis [Pedregosa11]_. @@ -210,15 +208,15 @@ def pca( incremental_pca_kwargs = dict() if is_dask: - from dask_ml.decomposition import IncrementalPCA from dask.array import zeros + from dask_ml.decomposition import IncrementalPCA incremental_pca_kwargs["svd_solver"] = _handle_dask_ml_args( svd_solver, "IncrementalPCA" ) else: - from sklearn.decomposition import IncrementalPCA from numpy import zeros + from sklearn.decomposition import IncrementalPCA X_pca = zeros((X.shape[0], n_comps), X.dtype) diff --git a/scanpy/preprocessing/_qc.py b/scanpy/preprocessing/_qc.py index 98cc7ccdff..2c0a61e87b 100644 --- a/scanpy/preprocessing/_qc.py +++ b/scanpy/preprocessing/_qc.py @@ -1,23 +1,28 @@ -from typing import Optional, Tuple, Collection, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from warnings import warn import numba import numpy as np import pandas as pd -from scipy.sparse import issparse, isspmatrix_csr, isspmatrix_coo -from scipy.sparse import spmatrix, csr_matrix +from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr, spmatrix from sklearn.utils.sparsefuncs import mean_variance_axis -from anndata import AnnData +from .._utils import _doc_params from ._docs import ( + doc_adata_basic, doc_expr_reps, doc_obs_qc_args, - doc_qc_metric_naming, doc_obs_qc_returns, + doc_qc_metric_naming, doc_var_qc_returns, - doc_adata_basic, ) -from .._utils import _doc_params + +if TYPE_CHECKING: + from collections.abc import Collection + + from anndata import AnnData def _choose_mtx_rep(adata, use_raw=False, layer=None): @@ -48,14 +53,14 @@ def describe_obs( expr_type: str = "counts", var_type: str = "genes", qc_vars: Collection[str] = (), - percent_top: Optional[Collection[int]] = (50, 100, 200, 500), - layer: Optional[str] = None, + percent_top: Collection[int] | None = (50, 100, 200, 500), + layer: str | None = None, use_raw: bool = False, - log1p: Optional[str] = True, + log1p: str | None = True, inplace: bool = False, X=None, parallel=None, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """\ Describe observations of anndata. @@ -148,12 +153,12 @@ def describe_var( *, expr_type: str = "counts", var_type: str = "genes", - layer: Optional[str] = None, + layer: str | None = None, use_raw: bool = False, inplace=False, log1p=True, X=None, -) -> Optional[pd.DataFrame]: +) -> pd.DataFrame | None: """\ Describe variables of anndata. @@ -229,13 +234,13 @@ def calculate_qc_metrics( expr_type: str = "counts", var_type: str = "genes", qc_vars: Collection[str] = (), - percent_top: Optional[Collection[int]] = (50, 100, 200, 500), - layer: Optional[str] = None, + percent_top: Collection[int] | None = (50, 100, 200, 500), + layer: str | None = None, use_raw: bool = False, inplace: bool = False, log1p: bool = True, - parallel: Optional[bool] = None, -) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]: + parallel: bool | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame] | None: """\ Calculate quality control metrics. @@ -326,7 +331,7 @@ def calculate_qc_metrics( return obs_metrics, var_metrics -def top_proportions(mtx: Union[np.array, spmatrix], n: int): +def top_proportions(mtx: np.array | spmatrix, n: int): """\ Calculates cumulative proportions of top expressed genes @@ -378,7 +383,7 @@ def top_proportions_sparse_csr(data, indptr, n): def top_segment_proportions( - mtx: Union[np.array, spmatrix], ns: Collection[int] + mtx: np.array | spmatrix, ns: Collection[int] ) -> np.ndarray: """ Calculates total percentage of counts in top ns genes. @@ -404,7 +409,7 @@ def top_segment_proportions( def top_segment_proportions_dense( - mtx: Union[np.array, spmatrix], ns: Collection[int] + mtx: np.array | spmatrix, ns: Collection[int] ) -> np.ndarray: # Currently ns is considered to be 1 indexed ns = np.sort(ns) diff --git a/scanpy/preprocessing/_recipes.py b/scanpy/preprocessing/_recipes.py index c400b71627..8a70b3094e 100644 --- a/scanpy/preprocessing/_recipes.py +++ b/scanpy/preprocessing/_recipes.py @@ -1,17 +1,20 @@ """Preprocessing recipes from the literature""" -from typing import Optional - -from anndata import AnnData +from __future__ import annotations +from typing import TYPE_CHECKING +from .. import logging as logg from .. import preprocessing as pp from ._deprecated.highly_variable_genes import ( - filter_genes_dispersion, filter_genes_cv_deprecated, + filter_genes_dispersion, ) from ._normalization import normalize_total -from .. import logging as logg -from .._utils import AnyRandom + +if TYPE_CHECKING: + from anndata import AnnData + + from .._utils import AnyRandom def recipe_weinreb17( @@ -23,7 +26,7 @@ def recipe_weinreb17( svd_solver="randomized", random_state: AnyRandom = 0, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Normalization and filtering as of [Weinreb17]_. @@ -39,9 +42,10 @@ def recipe_weinreb17( copy Return a copy if true. """ - from ._deprecated import normalize_per_cell_weinreb16_deprecated, zscore_deprecated from scipy.sparse import issparse + from ._deprecated import normalize_per_cell_weinreb16_deprecated, zscore_deprecated + if issparse(adata.X): raise ValueError("`recipe_weinreb16 does not support sparse matrices.") if copy: @@ -66,7 +70,7 @@ def recipe_weinreb17( def recipe_seurat( adata: AnnData, log: bool = True, plot: bool = False, copy: bool = False -) -> Optional[AnnData]: +) -> AnnData | None: """\ Normalization and filtering as of Seurat [Satija15]_. @@ -86,7 +90,7 @@ def recipe_seurat( if plot: from ..plotting import ( _preprocessing as ppp, - ) # should not import at the top of the file + ) ppp.filter_genes_dispersion(filter_result, log=not log) adata._inplace_subset_var(filter_result.gene_subset) # filter genes @@ -102,7 +106,7 @@ def recipe_zheng17( log: bool = True, plot: bool = False, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Normalization and filtering as of [Zheng17]_. diff --git a/scanpy/preprocessing/_simple.py b/scanpy/preprocessing/_simple.py index 362cef94ce..6b2f28bd52 100644 --- a/scanpy/preprocessing/_simple.py +++ b/scanpy/preprocessing/_simple.py @@ -4,30 +4,28 @@ """ from __future__ import annotations -from functools import singledispatch -from numbers import Number import warnings -from typing import Union, Optional, Tuple, Collection, Sequence, Iterable, Literal +from functools import singledispatch +from typing import TYPE_CHECKING, Literal import numba import numpy as np import scipy as sp -from numpy.typing import NDArray -from scipy.sparse import issparse, isspmatrix_csr, csr_matrix, spmatrix -from sklearn.utils import sparsefuncs, check_array -from pandas.api.types import CategoricalDtype from anndata import AnnData +from pandas.api.types import CategoricalDtype +from scipy.sparse import csr_matrix, issparse, isspmatrix_csr, spmatrix +from sklearn.utils import check_array, sparsefuncs from .. import logging as logg from .._settings import settings as sett from .._utils import ( - sanitize_anndata, - deprecated_arg_names, - view_to_actual, AnyRandom, _check_array_function_arguments, + deprecated_arg_names, + sanitize_anndata, + view_to_actual, ) -from ..get import _get_obs_rep, _set_obs_rep, _check_mask +from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray from ._utils import _get_mean_var @@ -38,18 +36,24 @@ da = None # backwards compat -from ._deprecated.highly_variable_genes import filter_genes_dispersion +from ._deprecated.highly_variable_genes import filter_genes_dispersion # noqa: F401 + +if TYPE_CHECKING: + from collections.abc import Collection, Iterable, Sequence + from numbers import Number + + from numpy.typing import NDArray def filter_cells( data: AnnData, - min_counts: Optional[int] = None, - min_genes: Optional[int] = None, - max_counts: Optional[int] = None, - max_genes: Optional[int] = None, + min_counts: int | None = None, + min_genes: int | None = None, + max_counts: int | None = None, + max_genes: int | None = None, inplace: bool = True, copy: bool = False, -) -> Optional[Tuple[np.ndarray, np.ndarray]]: +) -> tuple[np.ndarray, np.ndarray] | None: """\ Filter cell outliers based on counts and numbers of genes expressed. @@ -177,13 +181,13 @@ def filter_cells( def filter_genes( data: AnnData, - min_counts: Optional[int] = None, - min_cells: Optional[int] = None, - max_counts: Optional[int] = None, - max_cells: Optional[int] = None, + min_counts: int | None = None, + min_cells: int | None = None, + max_counts: int | None = None, + max_cells: int | None = None, inplace: bool = True, copy: bool = False, -) -> Union[AnnData, None, Tuple[np.ndarray, np.ndarray]]: +) -> AnnData | None | tuple[np.ndarray, np.ndarray]: """\ Filter genes based on number of cells or counts. @@ -285,14 +289,14 @@ def filter_genes( @singledispatch def log1p( - X: Union[AnnData, np.ndarray, spmatrix], + X: AnnData | np.ndarray | spmatrix, *, - base: Optional[Number] = None, + base: Number | None = None, copy: bool = False, chunked: bool = None, - chunk_size: Optional[int] = None, - layer: Optional[str] = None, - obsm: Optional[str] = None, + chunk_size: int | None = None, + layer: str | None = None, + obsm: str | None = None, ): """\ Logarithmize the data matrix. @@ -331,7 +335,7 @@ def log1p( @log1p.register(spmatrix) -def log1p_sparse(X, *, base: Optional[Number] = None, copy: bool = False): +def log1p_sparse(X, *, base: Number | None = None, copy: bool = False): X = check_array( X, accept_sparse=("csr", "csc"), dtype=(np.float64, np.float32), copy=copy ) @@ -340,7 +344,7 @@ def log1p_sparse(X, *, base: Optional[Number] = None, copy: bool = False): @log1p.register(np.ndarray) -def log1p_array(X, *, base: Optional[Number] = None, copy: bool = False): +def log1p_array(X, *, base: Number | None = None, copy: bool = False): # Can force arrays to be np.ndarrays, but would be useful to not # X = check_array(X, dtype=(np.float64, np.float32), ensure_2d=False, copy=copy) if copy: @@ -360,13 +364,13 @@ def log1p_array(X, *, base: Optional[Number] = None, copy: bool = False): def log1p_anndata( adata, *, - base: Optional[Number] = None, + base: Number | None = None, copy: bool = False, chunked: bool = False, - chunk_size: Optional[int] = None, - layer: Optional[str] = None, - obsm: Optional[str] = None, -) -> Optional[AnnData]: + chunk_size: int | None = None, + layer: str | None = None, + obsm: str | None = None, +) -> AnnData | None: if "log1p" in adata.uns_keys(): logg.warning("adata.X seems to be already log-transformed.") @@ -394,8 +398,8 @@ def sqrt( data: AnnData, copy: bool = False, chunked: bool = False, - chunk_size: Optional[int] = None, -) -> Optional[AnnData]: + chunk_size: int | None = None, +) -> AnnData | None: """\ Square root the data matrix. @@ -435,15 +439,15 @@ def sqrt( def normalize_per_cell( - data: Union[AnnData, np.ndarray, spmatrix], - counts_per_cell_after: Optional[float] = None, - counts_per_cell: Optional[np.ndarray] = None, + data: AnnData | np.ndarray | spmatrix, + counts_per_cell_after: float | None = None, + counts_per_cell: np.ndarray | None = None, key_n_counts: str = "n_counts", copy: bool = False, - layers: Union[Literal["all"], Iterable[str]] = (), - use_rep: Optional[Literal["after", "X"]] = None, + layers: Literal["all"] | Iterable[str] = (), + use_rep: Literal["after", "X"] | None = None, min_counts: int = 1, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Normalize total counts per cell. @@ -572,11 +576,11 @@ def normalize_per_cell( def regress_out( adata: AnnData, - keys: Union[str, Sequence[str]], - layer: Optional[str] = None, - n_jobs: Optional[int] = None, + keys: str | Sequence[str], + layer: str | None = None, + n_jobs: int | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Regress out (mostly) unwanted sources of variation. @@ -903,12 +907,12 @@ def scale_anndata( def subsample( - data: Union[AnnData, np.ndarray, spmatrix], - fraction: Optional[float] = None, - n_obs: Optional[int] = None, + data: AnnData | np.ndarray | spmatrix, + fraction: float | None = None, + n_obs: int | None = None, random_state: AnyRandom = 0, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Subsample to a fraction of the number of observations. @@ -966,13 +970,13 @@ def subsample( @deprecated_arg_names({"target_counts": "counts_per_cell"}) def downsample_counts( adata: AnnData, - counts_per_cell: Optional[Union[int, Collection[int]]] = None, - total_counts: Optional[int] = None, + counts_per_cell: int | Collection[int] | None = None, + total_counts: int | None = None, *, random_state: AnyRandom = 0, replace: bool = False, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Downsample counts from count matrix. diff --git a/scanpy/preprocessing/_utils.py b/scanpy/preprocessing/_utils.py index d826d40af6..9df47cb981 100644 --- a/scanpy/preprocessing/_utils.py +++ b/scanpy/preprocessing/_utils.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING, Literal -import numpy as np import numba -from numpy.typing import NDArray +import numpy as np from scipy import sparse from .._utils import _SupportedArray, elem_mul +if TYPE_CHECKING: + from numpy.typing import NDArray + def _get_mean_var( X: _SupportedArray, *, axis: Literal[0, 1] = 0 diff --git a/scanpy/queries/__init__.py b/scanpy/queries/__init__.py index 5caf500a4c..6080f88ef8 100644 --- a/scanpy/queries/__init__.py +++ b/scanpy/queries/__init__.py @@ -1,6 +1,16 @@ +# Biomart queries +from __future__ import annotations + from ._queries import ( biomart_annotations, + enrich, # gprofiler queries gene_coordinates, mitochondrial_genes, -) # Biomart queries -from ._queries import enrich # gprofiler queries +) + +__all__ = [ + "biomart_annotations", + "enrich", + "gene_coordinates", + "mitochondrial_genes", +] diff --git a/scanpy/queries/_queries.py b/scanpy/queries/_queries.py index e4b6610feb..b9fe05a2ba 100644 --- a/scanpy/queries/_queries.py +++ b/scanpy/queries/_queries.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import collections.abc as cabc from functools import singledispatch from types import MappingProxyType -from typing import Any, Union, Optional, Iterable, Dict, Mapping +from typing import TYPE_CHECKING, Any -import pandas as pd from anndata import AnnData -from ..testing._doctests import doctest_needs -from ..get import rank_genes_groups_df from .._utils import _doc_params +from ..get import rank_genes_groups_df +from ..testing._doctests import doctest_needs + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + import pandas as pd _doc_org = """\ org @@ -33,9 +38,9 @@ @_doc_params(doc_org=_doc_org, doc_host=_doc_host, doc_use_cache=_doc_use_cache) def simple_query( org: str, - attrs: Union[Iterable[str], str], + attrs: Iterable[str] | str, *, - filters: Optional[Dict[str, Any]] = None, + filters: dict[str, Any] | None = None, host: str = "www.ensembl.org", use_cache: bool = False, ) -> pd.DataFrame: @@ -65,9 +70,7 @@ def simple_query( "This method requires the `pybiomart` module to be installed." ) server = Server(host, use_cache=use_cache) - dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets[ - "{}_gene_ensembl".format(org) - ] + dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets[f"{org}_gene_ensembl"] res = dataset.query(attributes=attrs, filters=filters, use_attr_names=True) return res @@ -205,7 +208,7 @@ def mitochondrial_genes( @singledispatch @_doc_params(doc_org=_doc_org) def enrich( - container: Union[Iterable[str], Mapping[str, Iterable[str]]], + container: Iterable[str] | Mapping[str, Iterable[str]], *, org: str = "hsapiens", gprofiler_kwargs: Mapping[str, Any] = MappingProxyType({}), @@ -288,12 +291,12 @@ def _enrich_anndata( adata: AnnData, group: str, *, - org: Optional[str] = "hsapiens", + org: str | None = "hsapiens", key: str = "rank_genes_groups", pval_cutoff: float = 0.05, - log2fc_min: Optional[float] = None, - log2fc_max: Optional[float] = None, - gene_symbols: Optional[str] = None, + log2fc_min: float | None = None, + log2fc_max: float | None = None, + gene_symbols: str | None = None, gprofiler_kwargs: Mapping[str, Any] = MappingProxyType({}), ) -> pd.DataFrame: de = rank_genes_groups_df( diff --git a/scanpy/readwrite.py b/scanpy/readwrite.py index 43a4d980f5..b94b540050 100644 --- a/scanpy/readwrite.py +++ b/scanpy/readwrite.py @@ -1,28 +1,30 @@ """Reading and Writing """ +from __future__ import annotations + +import json from pathlib import Path, PurePath -from typing import Union, Dict, Optional, Tuple, BinaryIO, Literal +from typing import BinaryIO, Literal +import anndata import h5py -import json import numpy as np import pandas as pd -from matplotlib.image import imread -import anndata from anndata import ( AnnData, read_csv, - read_text, read_excel, - read_mtx, - read_loom, read_hdf, + read_loom, + read_mtx, + read_text, ) from anndata import read as read_h5ad +from matplotlib.image import imread +from . import logging as logg from ._settings import settings from ._utils import Empty, _empty -from . import logging as logg # .gz and .bz2 suffixes are also allowed for text formats text_exts = { @@ -51,15 +53,15 @@ def read( - filename: Union[Path, str], - backed: Optional[Literal["r", "r+"]] = None, - sheet: Optional[str] = None, - ext: Optional[str] = None, - delimiter: Optional[str] = None, + filename: Path | str, + backed: Literal["r", "r+"] | None = None, + sheet: str | None = None, + ext: str | None = None, + delimiter: str | None = None, first_column_names: bool = False, - backup_url: Optional[str] = None, + backup_url: str | None = None, cache: bool = False, - cache_compression: Union[Literal["gzip", "lzf"], None, Empty] = _empty, + cache_compression: Literal["gzip", "lzf"] | None | Empty = _empty, **kwargs, ) -> AnnData: """\ @@ -135,10 +137,10 @@ def read( def read_10x_h5( - filename: Union[str, Path], - genome: Optional[str] = None, + filename: str | Path, + genome: str | None = None, gex_only: bool = True, - backup_url: Optional[str] = None, + backup_url: str | None = None, ) -> AnnData: """\ Read 10x-Genomics-formatted hdf5 file. @@ -334,13 +336,13 @@ def _read_v3_10x_h5(filename, *, start=None): def read_visium( - path: Union[str, Path], - genome: Optional[str] = None, + path: str | Path, + genome: str | None = None, *, count_file: str = "filtered_feature_bc_matrix.h5", - library_id: Optional[str] = None, - load_images: Optional[bool] = True, - source_image_path: Optional[Union[str, Path]] = None, + library_id: str | None = None, + load_images: bool | None = True, + source_image_path: str | Path | None = None, ) -> AnnData: """\ Read 10x-Genomics-formatted visum dataset. @@ -493,11 +495,11 @@ def read_visium( def read_10x_mtx( - path: Union[Path, str], + path: Path | str, var_names: Literal["gene_symbols", "gene_ids"] = "gene_symbols", make_unique: bool = True, cache: bool = False, - cache_compression: Union[Literal["gzip", "lzf"], None, Empty] = _empty, + cache_compression: Literal["gzip", "lzf"] | None | Empty = _empty, gex_only: bool = True, *, prefix: str = None, @@ -626,11 +628,11 @@ def _read_v3_10x_mtx( def write( - filename: Union[str, Path], + filename: str | Path, adata: AnnData, - ext: Optional[Literal["h5", "csv", "txt", "npz"]] = None, - compression: Optional[Literal["gzip", "lzf"]] = "gzip", - compression_opts: Optional[int] = None, + ext: Literal["h5", "csv", "txt", "npz"] | None = None, + compression: Literal["gzip", "lzf"] | None = "gzip", + compression_opts: int | None = None, ): """\ Write :class:`~anndata.AnnData` objects to file. @@ -682,8 +684,8 @@ def write( def read_params( - filename: Union[Path, str], asheader: bool = False -) -> Dict[str, Union[int, float, bool, str, None]]: + filename: Path | str, asheader: bool = False +) -> dict[str, int | float | bool | str | None]: """\ Read parameter dictionary from text file. @@ -705,11 +707,11 @@ def read_params( ------- Dictionary that stores parameters. """ - filename = str(filename) # allow passing pathlib.Path objects + filename = Path(filename) # allow passing str objects from collections import OrderedDict params = OrderedDict([]) - for line in open(filename): + for line in filename.open(): if "=" in line: if not asheader or line.startswith("#"): line = line[1:] if line.startswith("#") else line @@ -720,7 +722,7 @@ def read_params( return params -def write_params(path: Union[Path, str], *args, **maps): +def write_params(path: Path | str, *args, **maps): """\ Write parameters to file, so that it's readable by read_params. @@ -832,7 +834,7 @@ def _read( return adata -def _slugify(path: Union[str, PurePath]) -> str: +def _slugify(path: str | PurePath) -> str: """Make a path into a filename.""" if not isinstance(path, PurePath): path = PurePath(path) @@ -847,7 +849,7 @@ def _slugify(path: Union[str, PurePath]) -> str: return filename -def _read_softgz(filename: Union[str, bytes, Path, BinaryIO]) -> AnnData: +def _read_softgz(filename: str | bytes | Path | BinaryIO) -> AnnData: """\ Read a SOFT format data file. @@ -933,7 +935,7 @@ def is_int(string: str) -> bool: return False -def convert_bool(string: str) -> Tuple[bool, bool]: +def convert_bool(string: str) -> tuple[bool, bool]: """Check whether string is boolean.""" if string == "True": return True, True @@ -943,7 +945,7 @@ def convert_bool(string: str) -> Tuple[bool, bool]: return False, False -def convert_string(string: str) -> Union[int, float, bool, str, None]: +def convert_string(string: str) -> int | float | bool | str | None: """Convert string to int, float or bool.""" if is_int(string): return int(string) @@ -989,13 +991,13 @@ def _get_filename_from_key(key, ext=None) -> Path: def _download(url: str, path: Path): try: - import ipywidgets + import ipywidgets # noqa: F401 from tqdm.auto import tqdm except ImportError: from tqdm import tqdm - from urllib.request import urlopen, Request from urllib.error import URLError + from urllib.request import Request, urlopen blocksize = 1024 * 8 blocknum = 0 @@ -1010,9 +1012,10 @@ def _download(url: str, path: Path): "Failed to open the url with default certificates, trying with certifi." ) - from certifi import where from ssl import create_default_context + from certifi import where + open_url = urlopen(req, context=create_default_context(cafile=where())) with open_url as resp: diff --git a/scanpy/testing/_doctests.py b/scanpy/testing/_doctests.py index 45f52da253..ac711d4263 100644 --- a/scanpy/testing/_doctests.py +++ b/scanpy/testing/_doctests.py @@ -1,9 +1,10 @@ from __future__ import annotations from types import FunctionType -from typing import TypeVar -from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar +if TYPE_CHECKING: + from collections.abc import Callable F = TypeVar("F", bound=FunctionType) diff --git a/scanpy/testing/_helpers/__init__.py b/scanpy/testing/_helpers/__init__.py index b89c9ba256..939ce60c53 100644 --- a/scanpy/testing/_helpers/__init__.py +++ b/scanpy/testing/_helpers/__init__.py @@ -1,14 +1,15 @@ """ This file contains helper functions for the scanpy test suite. """ +from __future__ import annotations +import warnings from itertools import permutations -import scanpy as sc import numpy as np -import warnings from anndata.tests.helpers import asarray, assert_equal +import scanpy as sc # TODO: Report more context on the fields being compared on error # TODO: Allow specifying paths to ignore on comparison diff --git a/scanpy/testing/_helpers/data.py b/scanpy/testing/_helpers/data.py index 66cc5649bd..a0f109aa87 100644 --- a/scanpy/testing/_helpers/data.py +++ b/scanpy/testing/_helpers/data.py @@ -14,9 +14,12 @@ def cache(func): return lru_cache(maxsize=None)(func) -from anndata import AnnData +from typing import TYPE_CHECKING + import scanpy as sc +if TYPE_CHECKING: + from anndata import AnnData # Functions returning the same objects (easy to misuse) diff --git a/scanpy/testing/_pytest/__init__.py b/scanpy/testing/_pytest/__init__.py index 6df15005ce..a87b752daa 100644 --- a/scanpy/testing/_pytest/__init__.py +++ b/scanpy/testing/_pytest/__init__.py @@ -1,14 +1,15 @@ """A private pytest plugin""" from __future__ import annotations -from collections.abc import Iterable, Generator import sys -from typing import Any +from typing import TYPE_CHECKING, Any import pytest from .fixtures import * # noqa: F403 +if TYPE_CHECKING: + from collections.abc import Generator, Iterable doctest_env_marker = pytest.mark.usefixtures("doctest_env") @@ -17,17 +18,18 @@ @pytest.fixture(autouse=True) def global_test_context() -> Generator[None, None, None]: """Switch to agg backend, reset settings, and close all figures at teardown.""" - from matplotlib import pyplot + from matplotlib import pyplot as plt + from scanpy import settings - pyplot.switch_backend("agg") + plt.switch_backend("agg") settings.logfile = sys.stderr settings.verbosity = "hint" settings.autoshow = True yield - pyplot.close("all") + plt.close("all") def pytest_addoption(parser: pytest.Parser) -> None: diff --git a/scanpy/testing/_pytest/fixtures/__init__.py b/scanpy/testing/_pytest/fixtures/__init__.py index 1fed07c646..780568d3b4 100644 --- a/scanpy/testing/_pytest/fixtures/__init__.py +++ b/scanpy/testing/_pytest/fixtures/__init__.py @@ -4,16 +4,19 @@ """ from __future__ import annotations -from pathlib import Path +from typing import TYPE_CHECKING -import pytest import numpy as np +import pytest + from .data import ( _pbmc3ks_parametrized_session, pbmc3k_parametrized, pbmc3k_parametrized_small, ) +if TYPE_CHECKING: + from pathlib import Path __all__ = [ "float_dtype", diff --git a/scanpy/testing/_pytest/fixtures/data.py b/scanpy/testing/_pytest/fixtures/data.py index f66c03917a..58523e9208 100644 --- a/scanpy/testing/_pytest/fixtures/data.py +++ b/scanpy/testing/_pytest/fixtures/data.py @@ -3,11 +3,16 @@ from __future__ import annotations from itertools import product -from collections.abc import Callable -import pytest +from typing import TYPE_CHECKING + import numpy as np +import pytest from scipy import sparse -from anndata import AnnData + +if TYPE_CHECKING: + from collections.abc import Callable + + from anndata import AnnData @pytest.fixture( diff --git a/scanpy/testing/_pytest/marks.py b/scanpy/testing/_pytest/marks.py index 476161e0ff..99a4bfce4c 100644 --- a/scanpy/testing/_pytest/marks.py +++ b/scanpy/testing/_pytest/marks.py @@ -1,8 +1,8 @@ from __future__ import annotations +import sys from enum import Enum, auto from importlib.util import find_spec -import sys import pytest diff --git a/scanpy/testing/_pytest/params.py b/scanpy/testing/_pytest/params.py index 2e5171c1ac..3ff543bfa7 100644 --- a/scanpy/testing/_pytest/params.py +++ b/scanpy/testing/_pytest/params.py @@ -2,16 +2,17 @@ from __future__ import annotations -from collections.abc import Iterable from typing import TYPE_CHECKING, Literal import pytest +from anndata.tests.helpers import as_dense_dask_array, as_sparse_dask_array, asarray from scipy import sparse -from anndata.tests.helpers import asarray, as_dense_dask_array, as_sparse_dask_array from .._pytest.marks import needs if TYPE_CHECKING: + from collections.abc import Iterable + from _pytest.mark.structures import ParameterSet diff --git a/scanpy/tests/_scripts/scanpy-testbin b/scanpy/tests/_scripts/scanpy-testbin index 5422f0d1e6..f7ed95336e 100755 --- a/scanpy/tests/_scripts/scanpy-testbin +++ b/scanpy/tests/_scripts/scanpy-testbin @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations import sys - print("test", *sys.argv[1:]) diff --git a/scanpy/tests/conftest.py b/scanpy/tests/conftest.py index 67879fabfe..5e91634b6d 100644 --- a/scanpy/tests/conftest.py +++ b/scanpy/tests/conftest.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import sys from pathlib import Path from typing import TYPE_CHECKING @@ -11,6 +10,8 @@ import scanpy as _sc # noqa: F401 if TYPE_CHECKING: # So editors understand that we’re using those fixtures + import os + from scanpy.testing._pytest.fixtures import * # noqa: F403 # define this after importing scanpy but before running tests @@ -71,7 +72,7 @@ def fmt_descr(descr): @pytest.fixture def image_comparer(check_same_image): - from matplotlib import pyplot + from matplotlib import pyplot as plt def save_and_compare(*path_parts: Path | os.PathLike, tol: int): base_pth = Path(*path_parts) @@ -80,8 +81,8 @@ def save_and_compare(*path_parts: Path | os.PathLike, tol: int): base_pth.mkdir() expected_pth = base_pth / "expected.png" actual_pth = base_pth / "actual.png" - pyplot.savefig(actual_pth, dpi=40) - pyplot.close() + plt.savefig(actual_pth, dpi=40) + plt.close() if not expected_pth.is_file(): raise OSError(f"No expected output found at {expected_pth}.") check_same_image(expected_pth, actual_pth, tol=tol) @@ -91,9 +92,9 @@ def save_and_compare(*path_parts: Path | os.PathLike, tol: int): @pytest.fixture def plt(): - from matplotlib import pyplot + from matplotlib import pyplot as plt - return pyplot + return plt @pytest.fixture diff --git a/scanpy/tests/external/test_harmony_integrate.py b/scanpy/tests/external/test_harmony_integrate.py index b379521936..42993dbaa1 100644 --- a/scanpy/tests/external/test_harmony_integrate.py +++ b/scanpy/tests/external/test_harmony_integrate.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import scanpy as sc import scanpy.external as sce from scanpy.testing._helpers.data import pbmc3k from scanpy.testing._pytest.marks import needs - pytestmark = [needs.harmonypy] @@ -17,7 +18,7 @@ def test_harmony_integrate(): """ adata = pbmc3k() sc.pp.recipe_zheng17(adata) - sc.tl.pca(adata) + sc.pp.pca(adata) adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] sce.pp.harmony_integrate(adata, "batch") assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape diff --git a/scanpy/tests/external/test_harmony_timeseries.py b/scanpy/tests/external/test_harmony_timeseries.py index be15c71dd7..9eb0ecd7be 100644 --- a/scanpy/tests/external/test_harmony_timeseries.py +++ b/scanpy/tests/external/test_harmony_timeseries.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from itertools import product from anndata import AnnData @@ -7,7 +9,6 @@ from scanpy.testing._helpers.data import pbmc3k from scanpy.testing._pytest.marks import needs - pytestmark = [needs.harmony] diff --git a/scanpy/tests/external/test_hashsolo.py b/scanpy/tests/external/test_hashsolo.py index 7051d49977..499a204b21 100644 --- a/scanpy/tests/external/test_hashsolo.py +++ b/scanpy/tests/external/test_hashsolo.py @@ -1,12 +1,16 @@ -from anndata import AnnData +from __future__ import annotations + import numpy as np +from anndata import AnnData + import scanpy.external as sce def test_cell_demultiplexing(): - from scipy import stats import random + from scipy import stats + random.seed(52) signal = stats.poisson.rvs(1000, 1, 990) doublet_signal = stats.poisson.rvs(1000, 1, 10) diff --git a/scanpy/tests/external/test_magic.py b/scanpy/tests/external/test_magic.py index 07eb243422..a4d7e11124 100644 --- a/scanpy/tests/external/test_magic.py +++ b/scanpy/tests/external/test_magic.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import numpy as np from anndata import AnnData import scanpy as sc from scanpy.testing._pytest.marks import needs - pytestmark = [needs.magic] A_list = [ diff --git a/scanpy/tests/external/test_palantir.py b/scanpy/tests/external/test_palantir.py index e1bb308293..e6d1f83ab3 100644 --- a/scanpy/tests/external/test_palantir.py +++ b/scanpy/tests/external/test_palantir.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import scanpy.external as sce from scanpy.testing._helpers.data import pbmc3k_processed from scanpy.testing._pytest.marks import needs - pytestmark = [needs.palantir] diff --git a/scanpy/tests/external/test_phenograph.py b/scanpy/tests/external/test_phenograph.py index e077e2e5f6..bb137f0e13 100644 --- a/scanpy/tests/external/test_phenograph.py +++ b/scanpy/tests/external/test_phenograph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pandas as pd from anndata import AnnData @@ -6,7 +8,6 @@ import scanpy.external as sce from scanpy.testing._pytest.marks import needs - pytestmark = [needs.phenograph] @@ -15,6 +16,6 @@ def test_phenograph(): dframe = pd.DataFrame(df) dframe.index, dframe.columns = (map(str, dframe.index), map(str, dframe.columns)) adata = AnnData(dframe) - sc.tl.pca(adata, n_comps=20) + sc.pp.pca(adata, n_comps=20) sce.tl.phenograph(adata, clustering_algo="leiden", k=50) assert adata.obs["pheno_leiden"].shape[0], "phenograph_Community Detection Error!" diff --git a/scanpy/tests/external/test_sam.py b/scanpy/tests/external/test_sam.py index 309b1741ad..edb4e860c5 100644 --- a/scanpy/tests/external/test_sam.py +++ b/scanpy/tests/external/test_sam.py @@ -1,11 +1,12 @@ -import scanpy as sc -import scanpy.external as sce +from __future__ import annotations + import numpy as np +import scanpy as sc +import scanpy.external as sce from scanpy.testing._helpers.data import pbmc3k from scanpy.testing._pytest.marks import needs - pytestmark = [needs.samalg] diff --git a/scanpy/tests/external/test_scanorama_integrate.py b/scanpy/tests/external/test_scanorama_integrate.py index 7c5a6b3ef4..2635289427 100644 --- a/scanpy/tests/external/test_scanorama_integrate.py +++ b/scanpy/tests/external/test_scanorama_integrate.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import scanpy as sc import scanpy.external as sce from scanpy.testing._helpers.data import pbmc68k_reduced from scanpy.testing._pytest.marks import needs - pytestmark = [needs.scanorama] @@ -16,7 +17,7 @@ def test_scanorama_integrate(): and makes sure it has the same dimensions as the original PCA table. """ adata = pbmc68k_reduced() - sc.tl.pca(adata) + sc.pp.pca(adata) adata.obs["batch"] = 350 * ["a"] + 350 * ["b"] sce.pp.scanorama_integrate(adata, "batch", approx=False) assert adata.obsm["X_scanorama"].shape == adata.obsm["X_pca"].shape diff --git a/scanpy/tests/external/test_scrublet.py b/scanpy/tests/external/test_scrublet.py index 12c3949b67..701a352395 100644 --- a/scanpy/tests/external/test_scrublet.py +++ b/scanpy/tests/external/test_scrublet.py @@ -1,18 +1,18 @@ -import pytest +from __future__ import annotations -import scanpy as sc -import scanpy.external as sce -from anndata.tests.helpers import assert_equal -import pandas as pd import anndata as ad import numpy as np -import scanpy.preprocessing as pp +import pandas as pd +import pytest import scipy.sparse as sparse +from anndata.tests.helpers import assert_equal +import scanpy as sc +import scanpy.external as sce +import scanpy.preprocessing as pp from scanpy.testing._helpers.data import paul15, pbmc3k from scanpy.testing._pytest.marks import needs - pytestmark = [needs.scrublet] diff --git a/scanpy/tests/external/test_wishbone.py b/scanpy/tests/external/test_wishbone.py index e2a492fc37..89df828f64 100644 --- a/scanpy/tests/external/test_wishbone.py +++ b/scanpy/tests/external/test_wishbone.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import scanpy as sc import scanpy.external as sce from scanpy.testing._helpers.data import pbmc3k from scanpy.testing._pytest.marks import needs - pytestmark = [needs.wishbone] diff --git a/scanpy/tests/notebooks/test_paga_paul15_subsampled.py b/scanpy/tests/notebooks/test_paga_paul15_subsampled.py index 7e8185279a..1aa683d731 100644 --- a/scanpy/tests/notebooks/test_paga_paul15_subsampled.py +++ b/scanpy/tests/notebooks/test_paga_paul15_subsampled.py @@ -2,19 +2,19 @@ # Hematopoiesis: trace myeloid and erythroid differentiation for data of [Paul *et al.* (2015)](https://doi.org/10.1016/j.cell.2015.11.013). # # This is the subsampled notebook for testing. +from __future__ import annotations + from functools import partial from pathlib import Path import numpy as np from matplotlib.testing import setup -import pytest setup() import scanpy as sc -from scanpy.testing._pytest.marks import needs from scanpy.testing._helpers.data import paul15 - +from scanpy.testing._pytest.marks import needs HERE: Path = Path(__file__).parent ROOT = HERE / "_images_paga_paul15_subsampled" @@ -32,7 +32,7 @@ def test_paga_paul15_subsampled(image_comparer, plt): # Preprocessing and Visualization sc.pp.recipe_zheng17(adata) - sc.tl.pca(adata, svd_solver="arpack") + sc.pp.pca(adata, svd_solver="arpack") sc.pp.neighbors(adata, n_neighbors=4, n_pcs=20) sc.tl.draw_graph(adata) sc.pl.draw_graph(adata, color="paul15_clusters", legend_loc="on data") @@ -140,7 +140,7 @@ def test_paga_paul15_subsampled(image_comparer, plt): show_colorbar=False, color_map="Greys", color_maps_annotations={"distance": "viridis"}, - title="{} path".format(descr), + title=f"{descr} path", return_data=True, show=False, ) diff --git a/scanpy/tests/notebooks/test_pbmc3k.py b/scanpy/tests/notebooks/test_pbmc3k.py index 0e24b98748..5e2c71e617 100644 --- a/scanpy/tests/notebooks/test_pbmc3k.py +++ b/scanpy/tests/notebooks/test_pbmc3k.py @@ -9,11 +9,12 @@ # The data consists in *3k PBMCs from a Healthy Donor* and is freely available from 10x Genomics # ([here](https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/pbmc3k_filtered_gene_bc_matrices.tar.gz) # from this [webpage](https://support.10xgenomics.com/single-cell-gene-expression/datasets/1.1.0/pbmc3k)). +from __future__ import annotations + from functools import partial from pathlib import Path import numpy as np - from matplotlib.testing import setup setup() @@ -21,7 +22,6 @@ import scanpy as sc from scanpy.testing._pytest.marks import needs - HERE: Path = Path(__file__).parent ROOT = HERE / "_images_pbmc3k" @@ -88,7 +88,7 @@ def test_pbmc3k(image_comparer): # PCA - sc.tl.pca(adata, svd_solver="arpack") + sc.pp.pca(adata, svd_solver="arpack") sc.pl.pca(adata, color="CST3", show=False) save_and_compare_images("pca") diff --git a/scanpy/tests/test_binary.py b/scanpy/tests/test_binary.py index a30da8fd2a..6b70c824f3 100644 --- a/scanpy/tests/test_binary.py +++ b/scanpy/tests/test_binary.py @@ -1,16 +1,19 @@ +from __future__ import annotations + import os import re from pathlib import Path from subprocess import PIPE -from typing import List +from typing import TYPE_CHECKING import pytest -from _pytest.capture import CaptureFixture -from _pytest.monkeypatch import MonkeyPatch import scanpy from scanpy.cli import main +if TYPE_CHECKING: + from _pytest.capture import CaptureFixture + from _pytest.monkeypatch import MonkeyPatch HERE = Path(__file__).parent @@ -27,7 +30,7 @@ def test_builtin_settings(capsys: CaptureFixture): @pytest.mark.parametrize("args", [[], ["-h"]]) -def test_help_displayed(args: List[str], capsys: CaptureFixture): +def test_help_displayed(args: list[str], capsys: CaptureFixture): try: # -h raises it, no args doesn’t. Maybe not ideal but meh. main(args) except SystemExit as se: diff --git a/scanpy/tests/test_clustering.py b/scanpy/tests/test_clustering.py index 9f7a538d89..f42a4a4b2a 100644 --- a/scanpy/tests/test_clustering.py +++ b/scanpy/tests/test_clustering.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import pytest + import scanpy as sc from scanpy.testing._helpers.data import pbmc68k_reduced from scanpy.testing._pytest.marks import needs diff --git a/scanpy/tests/test_combat.py b/scanpy/tests/test_combat.py index 196b46e198..586eb38c52 100644 --- a/scanpy/tests/test_combat.py +++ b/scanpy/tests/test_combat.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import numpy as np import pandas as pd -from sklearn.metrics import silhouette_score - from anndata.tests.helpers import assert_equal +from sklearn.metrics import silhouette_score import scanpy as sc -from scanpy.preprocessing._combat import _standardize_data, _design_matrix +from scanpy.preprocessing._combat import _design_matrix, _standardize_data def test_norm(): @@ -82,7 +83,7 @@ def test_silhouette(): sc.pp.combat(adata, "blobs") # compute pca - sc.tl.pca(adata) + sc.pp.pca(adata) X_pca = adata.obsm["X_pca"] # compute silhouette coefficient in pca diff --git a/scanpy/tests/test_datasets.py b/scanpy/tests/test_datasets.py index a06dd58ebe..868355dca7 100644 --- a/scanpy/tests/test_datasets.py +++ b/scanpy/tests/test_datasets.py @@ -1,13 +1,16 @@ """ Tests to make sure the example datasets load. """ +from __future__ import annotations + +import subprocess +from pathlib import Path -import scanpy as sc import numpy as np import pytest -from pathlib import Path from anndata.tests.helpers import assert_adata_equal -import subprocess + +import scanpy as sc @pytest.fixture(scope="module") diff --git a/scanpy/tests/test_dendrogram_key_added.py b/scanpy/tests/test_dendrogram_key_added.py index a9cd922a39..586612962a 100644 --- a/scanpy/tests/test_dendrogram_key_added.py +++ b/scanpy/tests/test_dendrogram_key_added.py @@ -1,9 +1,10 @@ -import scanpy as sc +from __future__ import annotations + import pytest +import scanpy as sc from scanpy.testing._helpers.data import pbmc68k_reduced - n_neighbors = 5 key = "test" diff --git a/scanpy/tests/test_deprecations.py b/scanpy/tests/test_deprecations.py index d7edeff557..a2010ba90a 100644 --- a/scanpy/tests/test_deprecations.py +++ b/scanpy/tests/test_deprecations.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import scanpy as sc diff --git a/scanpy/tests/test_embedding.py b/scanpy/tests/test_embedding.py index 65ad31e8bb..530f013304 100644 --- a/scanpy/tests/test_embedding.py +++ b/scanpy/tests/test_embedding.py @@ -1,4 +1,5 @@ -from unittest.mock import patch +from __future__ import annotations + import numpy as np import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises diff --git a/scanpy/tests/test_embedding_density.py b/scanpy/tests/test_embedding_density.py index 59fefcf410..bf40bf30da 100644 --- a/scanpy/tests/test_embedding_density.py +++ b/scanpy/tests/test_embedding_density.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import numpy as np from anndata import AnnData + import scanpy as sc from scanpy.testing._helpers.data import pbmc68k_reduced diff --git a/scanpy/tests/test_embedding_plots.py b/scanpy/tests/test_embedding_plots.py index f0381bf133..b2fc458c76 100644 --- a/scanpy/tests/test_embedding_plots.py +++ b/scanpy/tests/test_embedding_plots.py @@ -1,19 +1,20 @@ +from __future__ import annotations + from functools import partial from pathlib import Path import matplotlib as mpl import matplotlib.pyplot as plt -from matplotlib.colors import Normalize -from matplotlib.testing.compare import compare_images import numpy as np import pandas as pd import pytest import seaborn as sns +from matplotlib.colors import Normalize +from matplotlib.testing.compare import compare_images import scanpy as sc from scanpy.testing._helpers.data import pbmc3k_processed - HERE: Path = Path(__file__).parent ROOT = HERE / "_images" @@ -29,8 +30,8 @@ def check_images(pth1, pth2, *, tol): def adata(): """A bit cute.""" from matplotlib.image import imread - from sklearn.datasets import make_blobs from sklearn.cluster import DBSCAN + from sklearn.datasets import make_blobs empty_pixel = np.array([1.0, 1.0, 1.0, 0]).reshape(1, 1, -1) image = imread(HERE.parent.parent / "docs/_static/img/Scanpy_Logo_RGB.png") diff --git a/scanpy/tests/test_filter_rank_genes_groups.py b/scanpy/tests/test_filter_rank_genes_groups.py index 755ed07321..6153e29703 100644 --- a/scanpy/tests/test_filter_rank_genes_groups.py +++ b/scanpy/tests/test_filter_rank_genes_groups.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import numpy as np -from scanpy.testing._helpers.data import pbmc68k_reduced -from scanpy.tools import rank_genes_groups, filter_rank_genes_groups +from scanpy.testing._helpers.data import pbmc68k_reduced +from scanpy.tools import filter_rank_genes_groups, rank_genes_groups names_no_reference = np.array( [ diff --git a/scanpy/tests/test_get.py b/scanpy/tests/test_get.py index 6893c59167..31fb1bae09 100644 --- a/scanpy/tests/test_get.py +++ b/scanpy/tests/test_get.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from functools import partial -from itertools import repeat, chain +from itertools import chain, repeat import numpy as np import pandas as pd @@ -11,7 +13,6 @@ from scanpy.datasets._utils import filter_oldformatwarning from scanpy.testing._helpers.data import pbmc68k_reduced - TRANSPOSE_PARAMS = pytest.mark.parametrize( "dim,transform,func", [ diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index e6404a389e..d0e817df54 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from pathlib import Path -import pytest -import pandas as pd import numpy as np -import scanpy as sc +import pandas as pd +import pytest +import scanpy as sc from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs diff --git a/scanpy/tests/test_ingest.py b/scanpy/tests/test_ingest.py index 67bc2515c5..09bec96e7e 100644 --- a/scanpy/tests/test_ingest.py +++ b/scanpy/tests/test_ingest.py @@ -1,6 +1,7 @@ -import pytest -import numpy as np +from __future__ import annotations +import numpy as np +import pytest from sklearn.neighbors import KDTree from umap import UMAP @@ -9,7 +10,6 @@ from scanpy._compat import pkg_version from scanpy.testing._helpers.data import pbmc68k_reduced - X = np.array( [ [1.0, 2.5, 3.0, 5.0, 8.7], diff --git a/scanpy/tests/test_logging.py b/scanpy/tests/test_logging.py index d59716b0c4..45f0b62b51 100644 --- a/scanpy/tests/test_logging.py +++ b/scanpy/tests/test_logging.py @@ -1,13 +1,20 @@ +from __future__ import annotations + +import sys from contextlib import redirect_stdout from datetime import datetime from io import StringIO -from pathlib import Path -import sys +from typing import TYPE_CHECKING import pytest -from scanpy import Verbosity, settings as s, logging as log import scanpy as sc +from scanpy import Verbosity +from scanpy import logging as log +from scanpy import settings as s + +if TYPE_CHECKING: + from pathlib import Path def test_defaults(): diff --git a/scanpy/tests/test_marker_gene_overlap.py b/scanpy/tests/test_marker_gene_overlap.py index ad1860c0f0..cb1ef16a88 100644 --- a/scanpy/tests/test_marker_gene_overlap.py +++ b/scanpy/tests/test_marker_gene_overlap.py @@ -1,7 +1,9 @@ +from __future__ import annotations + +import numpy as np from anndata import AnnData import scanpy as sc -import numpy as np def generate_test_data(): diff --git a/scanpy/tests/test_metrics.py b/scanpy/tests/test_metrics.py index f7b0f2f390..3e2e0e7f4b 100644 --- a/scanpy/tests/test_metrics.py +++ b/scanpy/tests/test_metrics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from functools import partial from operator import eq @@ -5,16 +7,14 @@ import numpy as np import pandas as pd -import scanpy as sc -from scipy import sparse - import pytest +from scipy import sparse +import scanpy as sc from scanpy._compat import DaskArray from scanpy.testing._helpers.data import pbmc68k_reduced from scanpy.testing._pytest.params import ARRAY_TYPES - mark_flaky = pytest.mark.xfail( strict=False, reason="This used to work reliably, but doesn’t anymore", diff --git a/scanpy/tests/test_neighbors.py b/scanpy/tests/test_neighbors.py index 8fa85c4a47..dbefe5a006 100644 --- a/scanpy/tests/test_neighbors.py +++ b/scanpy/tests/test_neighbors.py @@ -1,16 +1,20 @@ -from typing import Literal +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING, Literal import numpy as np import pytest from anndata import AnnData -from pytest_mock import MockerFixture from scipy.sparse import csr_matrix, issparse from sklearn.neighbors import KNeighborsTransformer import scanpy as sc from scanpy import Neighbors +if TYPE_CHECKING: + from pytest_mock import MockerFixture + # the input data X = [[1, 0], [3, 0], [5, 6], [0, 4]] n_neighbors = 3 # includes data points themselves diff --git a/scanpy/tests/test_neighbors_common.py b/scanpy/tests/test_neighbors_common.py index 834143fcdb..8874ccae03 100644 --- a/scanpy/tests/test_neighbors_common.py +++ b/scanpy/tests/test_neighbors_common.py @@ -1,20 +1,23 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Literal +from typing import TYPE_CHECKING, Literal -import pytest import numpy as np -from scipy import sparse +import pytest from sklearn.neighbors import KNeighborsTransformer from scanpy._utils.compute.is_constant import is_constant from scanpy.neighbors._common import ( + _get_sparse_matrix_from_indices_distances, _has_self_column, _ind_dist_shortcut, - _get_sparse_matrix_from_indices_distances, ) +if TYPE_CHECKING: + from collections.abc import Callable + + from scipy import sparse + def mk_knn_matrix( n_obs: int, diff --git a/scanpy/tests/test_neighbors_key_added.py b/scanpy/tests/test_neighbors_key_added.py index d4e9fd795c..f0ddb939eb 100644 --- a/scanpy/tests/test_neighbors_key_added.py +++ b/scanpy/tests/test_neighbors_key_added.py @@ -1,9 +1,11 @@ -import scanpy as sc +from __future__ import annotations + import numpy as np import pytest -from scanpy.testing._pytest.marks import needs +import scanpy as sc from scanpy.testing._helpers.data import pbmc68k_reduced +from scanpy.testing._pytest.marks import needs n_neighbors = 5 key = "test" diff --git a/scanpy/tests/test_normalization.py b/scanpy/tests/test_normalization.py index facec68e41..c395b41c80 100644 --- a/scanpy/tests/test_normalization.py +++ b/scanpy/tests/test_normalization.py @@ -1,25 +1,26 @@ from __future__ import annotations -from typing import Any -from collections.abc import Callable +from typing import TYPE_CHECKING, Any -import pytest import numpy as np +import pytest from anndata import AnnData -from scipy.sparse import csr_matrix -from scipy import sparse from anndata.tests.helpers import assert_equal +from scipy import sparse +from scipy.sparse import csr_matrix import scanpy as sc from scanpy.testing._helpers import ( + _check_check_values_warnings, check_rep_mutation, check_rep_results, - _check_check_values_warnings, ) # TODO: Add support for sparse-in-dask from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED +if TYPE_CHECKING: + from collections.abc import Callable X_total = np.array([[1, 0], [3, 0], [5, 6]]) X_frac = np.array([[1, 0, 1], [3, 0, 1], [5, 6, 1]]) diff --git a/scanpy/tests/test_package_structure.py b/scanpy/tests/test_package_structure.py index 8b11d86d94..58afdcb06a 100644 --- a/scanpy/tests/test_package_structure.py +++ b/scanpy/tests/test_package_structure.py @@ -1,16 +1,15 @@ -import email +from __future__ import annotations + import inspect import os -from importlib.util import find_spec -from types import FunctionType from pathlib import Path +from types import FunctionType import pytest -from scanpy._utils import descend_classes_and_funcs # CLI is locally not imported by default but on travis it is? import scanpy.cli - +from scanpy._utils import descend_classes_and_funcs mod_dir = Path(scanpy.__file__).parent proj_dir = mod_dir.parent diff --git a/scanpy/tests/test_paga.py b/scanpy/tests/test_paga.py index 04657fa12e..1ae59cfc0f 100644 --- a/scanpy/tests/test_paga.py +++ b/scanpy/tests/test_paga.py @@ -1,15 +1,16 @@ +from __future__ import annotations + from functools import partial from pathlib import Path -import pytest import numpy as np +import pytest from matplotlib import cm import scanpy as sc from scanpy.testing._helpers.data import pbmc3k_processed, pbmc68k_reduced from scanpy.testing._pytest.marks import needs - HERE: Path = Path(__file__).parent ROOT = HERE / "_images" diff --git a/scanpy/tests/test_pca.py b/scanpy/tests/test_pca.py index 9a63767c8f..6e529b4320 100644 --- a/scanpy/tests/test_pca.py +++ b/scanpy/tests/test_pca.py @@ -1,13 +1,16 @@ +from __future__ import annotations + +import warnings from typing import Literal + import numpy as np import pytest -import warnings from anndata import AnnData from anndata.tests.helpers import ( as_dense_dask_array, as_sparse_dask_array, - assert_equal, asarray, + assert_equal, ) from scipy import sparse from sklearn.utils import issparse diff --git a/scanpy/tests/test_performance.py b/scanpy/tests/test_performance.py index edf653aeb3..26efe44cf7 100644 --- a/scanpy/tests/test_performance.py +++ b/scanpy/tests/test_performance.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import sys from subprocess import run diff --git a/scanpy/tests/test_plotting.py b/scanpy/tests/test_plotting.py index 6cedc1431a..5e911dfd4f 100644 --- a/scanpy/tests/test_plotting.py +++ b/scanpy/tests/test_plotting.py @@ -1,9 +1,8 @@ from __future__ import annotations from functools import partial +from itertools import chain, combinations, repeat from pathlib import Path -from itertools import repeat, chain, combinations -from collections.abc import Callable import pytest from matplotlib.testing import setup @@ -11,25 +10,28 @@ setup() +from typing import TYPE_CHECKING + import matplotlib as mpl import matplotlib.pyplot as plt -from matplotlib.axes import Axes -import seaborn as sns import numpy as np import pandas as pd -from matplotlib.testing.compare import compare_images +import seaborn as sns from anndata import AnnData +from matplotlib.testing.compare import compare_images import scanpy as sc from scanpy._compat import pkg_version -from scanpy.testing._pytest.marks import needs from scanpy.testing._helpers.data import ( + krumsiek11, pbmc3k, pbmc3k_processed, - krumsiek11, pbmc68k_reduced, ) +from scanpy.testing._pytest.marks import needs +if TYPE_CHECKING: + from collections.abc import Callable HERE: Path = Path(__file__).parent ROOT = HERE / "_images" @@ -947,8 +949,8 @@ def test_genes_symbols(image_comparer, id, fn): adata = krumsiek11() # add a 'symbols' column - adata.var["symbols"] = adata.var.index.map(lambda x: "symbol_{}".format(x)) - symbols = ["symbol_{}".format(x) for x in adata.var_names] + adata.var["symbols"] = adata.var.index.map(lambda x: f"symbol_{x}") + symbols = [f"symbol_{x}" for x in adata.var_names] fn(adata, symbols, "cell_type", dendrogram=True, gene_symbols="symbols", show=False) save_and_compare_images(f"{id}_gene_symbols") diff --git a/scanpy/tests/test_plotting_utils.py b/scanpy/tests/test_plotting_utils.py index c1deff53eb..6b53cd5b50 100644 --- a/scanpy/tests/test_plotting_utils.py +++ b/scanpy/tests/test_plotting_utils.py @@ -1,14 +1,15 @@ +from __future__ import annotations + from typing import cast + import numpy as np import pytest - from anndata import AnnData from matplotlib import colormaps from matplotlib.colors import ListedColormap from scanpy.plotting._utils import _validate_palette - viridis = cast(ListedColormap, colormaps["viridis"]) diff --git a/scanpy/tests/test_preprocessing.py b/scanpy/tests/test_preprocessing.py index 7c6fb1bafc..7f4547b242 100644 --- a/scanpy/tests/test_preprocessing.py +++ b/scanpy/tests/test_preprocessing.py @@ -1,14 +1,16 @@ +from __future__ import annotations + from itertools import product import numpy as np import pandas as pd -from scipy import sparse as sp -import scanpy as sc -from numpy.testing import assert_allclose import pytest from anndata import AnnData -from anndata.tests.helpers import assert_equal, asarray +from anndata.tests.helpers import asarray, assert_equal +from numpy.testing import assert_allclose +from scipy import sparse as sp +import scanpy as sc from scanpy.testing._helpers import check_rep_mutation, check_rep_results from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED @@ -251,8 +253,8 @@ def test_regress_out_view(): def test_regress_out_categorical(): - from scipy.sparse import random import pandas as pd + from scipy.sparse import random adata = AnnData(random(1000, 100, density=0.6, format="csr")) # create a categorical column diff --git a/scanpy/tests/test_preprocessing_distributed.py b/scanpy/tests/test_preprocessing_distributed.py index ecff2755ad..7f5dce7db0 100644 --- a/scanpy/tests/test_preprocessing_distributed.py +++ b/scanpy/tests/test_preprocessing_distributed.py @@ -1,15 +1,21 @@ +from __future__ import annotations + from pathlib import Path import anndata as ad import numpy.testing as npt import pytest -from scanpy.preprocessing import normalize_total, filter_genes -from scanpy.preprocessing import log1p, normalize_per_cell, filter_cells +from scanpy.preprocessing import ( + filter_cells, + filter_genes, + log1p, + normalize_per_cell, + normalize_total, +) from scanpy.preprocessing._distributed import materialize_as_ndarray from scanpy.testing._pytest.marks import needs - HERE = Path(__file__).parent / Path("_data/") input_file = str(Path(HERE, "10x-10k-subset.zarr")) diff --git a/scanpy/tests/test_qc_metrics.py b/scanpy/tests/test_qc_metrics.py index 71f6e728e0..06a4d0ceae 100644 --- a/scanpy/tests/test_qc_metrics.py +++ b/scanpy/tests/test_qc_metrics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pandas as pd import pytest @@ -6,10 +8,10 @@ import scanpy as sc from scanpy.preprocessing._qc import ( + describe_obs, + describe_var, top_proportions, top_segment_proportions, - describe_var, - describe_obs, ) diff --git a/scanpy/tests/test_queries.py b/scanpy/tests/test_queries.py index 8c1d778f90..cee1425d7f 100644 --- a/scanpy/tests/test_queries.py +++ b/scanpy/tests/test_queries.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import pandas as pd import pytest + import scanpy as sc from scanpy.testing._helpers.data import pbmc68k_reduced diff --git a/scanpy/tests/test_rank_genes_groups.py b/scanpy/tests/test_rank_genes_groups.py index c08e0789c6..77ce1dbce9 100644 --- a/scanpy/tests/test_rank_genes_groups.py +++ b/scanpy/tests/test_rank_genes_groups.py @@ -1,28 +1,31 @@ from __future__ import annotations import pickle -from pathlib import Path -from collections.abc import Callable -from typing import Any from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any -import pytest import numpy as np import pandas as pd +import pytest import scipy -from numpy.typing import NDArray from anndata import AnnData +from numpy.random import binomial, negative_binomial, seed from packaging import version from scipy.stats import mannwhitneyu -from numpy.random import negative_binomial, binomial, seed import scanpy as sc +from scanpy._utils import elem_mul, select_groups +from scanpy.get import rank_genes_groups_df from scanpy.testing._helpers.data import pbmc68k_reduced from scanpy.testing._pytest.params import ARRAY_TYPES, ARRAY_TYPES_MEM from scanpy.tools import rank_genes_groups from scanpy.tools._rank_genes_groups import _RankGenes -from scanpy.get import rank_genes_groups_df -from scanpy._utils import select_groups, elem_mul + +if TYPE_CHECKING: + from collections.abc import Callable + + from numpy.typing import NDArray HERE = Path(__file__).parent DATA_PATH = HERE / "_data" diff --git a/scanpy/tests/test_rank_genes_groups_logreg.py b/scanpy/tests/test_rank_genes_groups_logreg.py index 4352f2e21f..3cc294487e 100644 --- a/scanpy/tests/test_rank_genes_groups_logreg.py +++ b/scanpy/tests/test_rank_genes_groups_logreg.py @@ -1,8 +1,10 @@ -import pytest +from __future__ import annotations import numpy as np -import scanpy as sc import pandas as pd +import pytest + +import scanpy as sc @pytest.mark.parametrize("method", ["t-test", "logreg"]) diff --git a/scanpy/tests/test_read_10x.py b/scanpy/tests/test_read_10x.py index 730424a10b..1fe7d20315 100644 --- a/scanpy/tests/test_read_10x.py +++ b/scanpy/tests/test_read_10x.py @@ -1,12 +1,14 @@ -from unittest.mock import patch +from __future__ import annotations + +import shutil from pathlib import Path +from unittest.mock import patch import h5py import numpy as np import pytest -import scanpy as sc -import shutil +import scanpy as sc ROOT = Path(__file__).parent ROOT = ROOT / "_data" / "10x_data" diff --git a/scanpy/tests/test_readwrite.py b/scanpy/tests/test_readwrite.py index f572d5f663..e9c2b11f36 100644 --- a/scanpy/tests/test_readwrite.py +++ b/scanpy/tests/test_readwrite.py @@ -1,4 +1,6 @@ -from pathlib import PureWindowsPath, PurePosixPath +from __future__ import annotations + +from pathlib import PurePosixPath, PureWindowsPath import pytest diff --git a/scanpy/tests/test_scaling.py b/scanpy/tests/test_scaling.py index a0308c3d7f..b0687af0b3 100644 --- a/scanpy/tests/test_scaling.py +++ b/scanpy/tests/test_scaling.py @@ -1,5 +1,7 @@ -import pytest +from __future__ import annotations + import numpy as np +import pytest from anndata import AnnData from scipy.sparse import csr_matrix diff --git a/scanpy/tests/test_score_genes.py b/scanpy/tests/test_score_genes.py index 9601fe8150..cb44e46dc9 100644 --- a/scanpy/tests/test_score_genes.py +++ b/scanpy/tests/test_score_genes.py @@ -1,14 +1,16 @@ +from __future__ import annotations + +import pickle +from pathlib import Path + import numpy as np -import scanpy as sc +import pytest from anndata import AnnData from scipy.sparse import csr_matrix -import pytest -import pickle -from pathlib import Path +import scanpy as sc from scanpy.testing._helpers.data import paul15 - HERE = Path(__file__).parent / Path("_data/") diff --git a/scanpy/tests/test_sim.py b/scanpy/tests/test_sim.py index 317694766c..49c299d37d 100644 --- a/scanpy/tests/test_sim.py +++ b/scanpy/tests/test_sim.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import scanpy as sc diff --git a/scanpy/tests/test_utils.py b/scanpy/tests/test_utils.py index d94be7772e..e7bdb24f7d 100644 --- a/scanpy/tests/test_utils.py +++ b/scanpy/tests/test_utils.py @@ -1,18 +1,20 @@ +from __future__ import annotations + from types import ModuleType -import pytest import numpy as np -from scipy.sparse import csr_matrix +import pytest from anndata.tests.helpers import asarray +from scipy.sparse import csr_matrix +from scanpy._compat import DaskArray from scanpy._utils import ( - descend_classes_and_funcs, check_nonnegative_integers, + descend_classes_and_funcs, elem_mul, is_constant, ) from scanpy.testing._pytest.marks import needs -from scanpy._compat import DaskArray from scanpy.testing._pytest.params import ARRAY_TYPES, ARRAY_TYPES_SUPPORTED diff --git a/scanpy/tools/__init__.py b/scanpy/tools/__init__.py index 12d03427e6..b791a48a35 100644 --- a/scanpy/tools/__init__.py +++ b/scanpy/tools/__init__.py @@ -1,22 +1,57 @@ -from ..preprocessing import pca -from ._tsne import tsne -from ._umap import umap +from __future__ import annotations + +from typing import Any + +from ._dendrogram import dendrogram from ._diffmap import diffmap +from ._dpt import dpt from ._draw_graph import draw_graph - +from ._embedding_density import embedding_density +from ._ingest import Ingest, ingest +from ._leiden import leiden +from ._louvain import louvain +from ._marker_gene_overlap import marker_gene_overlap from ._paga import ( paga, + paga_compare_paths, paga_degrees, paga_expression_entropies, - paga_compare_paths, ) -from ._rank_genes_groups import rank_genes_groups, filter_rank_genes_groups -from ._dpt import dpt -from ._leiden import leiden -from ._louvain import louvain -from ._sim import sim +from ._rank_genes_groups import filter_rank_genes_groups, rank_genes_groups from ._score_genes import score_genes, score_genes_cell_cycle -from ._dendrogram import dendrogram -from ._embedding_density import embedding_density -from ._marker_gene_overlap import marker_gene_overlap -from ._ingest import ingest, Ingest +from ._sim import sim +from ._tsne import tsne +from ._umap import umap + + +def __getattr__(name: str) -> Any: + if name == "pca": + from ..preprocessing import pca + + return pca + raise AttributeError(name) + + +__all__ = [ + "dendrogram", + "diffmap", + "dpt", + "draw_graph", + "embedding_density", + "Ingest", + "ingest", + "leiden", + "louvain", + "marker_gene_overlap", + "paga", + "paga_compare_paths", + "paga_degrees", + "paga_expression_entropies", + "filter_rank_genes_groups", + "rank_genes_groups", + "score_genes", + "score_genes_cell_cycle", + "sim", + "tsne", + "umap", +] diff --git a/scanpy/tools/_dendrogram.py b/scanpy/tools/_dendrogram.py index c85d563a79..ccae494748 100644 --- a/scanpy/tools/_dendrogram.py +++ b/scanpy/tools/_dendrogram.py @@ -4,31 +4,36 @@ from __future__ import annotations -from typing import Optional, Sequence, Dict, Any +from typing import TYPE_CHECKING, Any import pandas as pd -from anndata import AnnData from pandas.api.types import CategoricalDtype from .. import logging as logg from .._utils import _doc_params -from ..tools._utils import _choose_representation, doc_use_rep, doc_n_pcs +from ..neighbors._doc import doc_n_pcs, doc_use_rep +from ._utils import _choose_representation + +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) def dendrogram( adata: AnnData, groupby: str | Sequence[str], - n_pcs: Optional[int] = None, - use_rep: Optional[str] = None, - var_names: Optional[Sequence[str]] = None, - use_raw: Optional[bool] = None, + n_pcs: int | None = None, + use_rep: str | None = None, + var_names: Sequence[str] | None = None, + use_raw: bool | None = None, cor_method: str = "pearson", linkage_method: str = "complete", optimal_ordering: bool = False, - key_added: Optional[str] = None, + key_added: str | None = None, inplace: bool = True, -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """\ Computes a hierarchical clustering for the given `groupby` categories. diff --git a/scanpy/tools/_diffmap.py b/scanpy/tools/_diffmap.py index e6b5123d88..8b1a09eb16 100644 --- a/scanpy/tools/_diffmap.py +++ b/scanpy/tools/_diffmap.py @@ -1,14 +1,19 @@ -from anndata import AnnData -from typing import Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from ._dpt import _diffmap -from .._utils import AnyRandom + +if TYPE_CHECKING: + from anndata import AnnData + + from .._utils import AnyRandom def diffmap( adata: AnnData, n_comps: int = 15, - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, random_state: AnyRandom = 0, copy: bool = False, ): diff --git a/scanpy/tools/_dpt.py b/scanpy/tools/_dpt.py index 6502db592a..864211d9bc 100644 --- a/scanpy/tools/_dpt.py +++ b/scanpy/tools/_dpt.py @@ -1,14 +1,20 @@ -from typing import Tuple, Optional, Sequence, List +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np import pandas as pd import scipy as sp -from anndata import AnnData from natsort import natsorted from .. import logging as logg from ..neighbors import Neighbors, OnFlySymMatrix +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): start = logg.info(f"computing Diffusion Maps using n_comps={n_comps}(=n_dcs)") @@ -34,9 +40,9 @@ def dpt( n_branchings: int = 0, min_group_size: float = 0.01, allow_kendall_tau_shift: bool = True, - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Infer progression of cells through geodesic distance along the graph [Haghverdi16]_ [Wolf19]_. @@ -368,7 +374,7 @@ def check_adjacency(self): # print(self.segs_adjacency) # self.segs_adjacency.eliminate_zeros() - def select_segment(self, segs, segs_tips, segs_undecided) -> Tuple[int, int]: + def select_segment(self, segs, segs_tips, segs_undecided) -> tuple[int, int]: """\ Out of a list of line segments, choose segment that has the most distant second data point. @@ -743,11 +749,11 @@ def _detect_branching( Dseg: np.ndarray, tips: np.ndarray, seg_reference=None, - ) -> Tuple[ - List[np.ndarray], - List[np.ndarray], - List[List[int]], - List[List[int]], + ) -> tuple[ + list[np.ndarray], + list[np.ndarray], + list[list[int]], + list[list[int]], int, ]: """\ @@ -1099,7 +1105,7 @@ def _kendall_tau_subtract(self, len_old: int, diff_neg: int, tau_old: float): """ return 2.0 / (len_old - 2) * (-float(diff_neg) / (len_old - 1) + tau_old) - def _kendall_tau_diff(self, a: np.ndarray, b: np.ndarray, i) -> Tuple[int, int]: + def _kendall_tau_diff(self, a: np.ndarray, b: np.ndarray, i) -> tuple[int, int]: """Compute difference in concordance of pairs in split sequences. Consider splitting a and b at index i. diff --git a/scanpy/tools/_draw_graph.py b/scanpy/tools/_draw_graph.py index 723417748f..5a88597245 100644 --- a/scanpy/tools/_draw_graph.py +++ b/scanpy/tools/_draw_graph.py @@ -1,15 +1,18 @@ -from typing import Union, Optional, Literal +from __future__ import annotations -import numpy as np import random -from anndata import AnnData -from scipy.sparse import spmatrix +from typing import TYPE_CHECKING, Literal + +import numpy as np from .. import _utils from .. import logging as logg -from ._utils import get_init_pos_from_paga from .._utils import AnyRandom, _choose_graph +from ._utils import get_init_pos_from_paga +if TYPE_CHECKING: + from anndata import AnnData + from scipy.sparse import spmatrix _LAYOUTS = ("fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa") _Layout = Literal[_LAYOUTS] @@ -18,14 +21,14 @@ def draw_graph( adata: AnnData, layout: _Layout = "fa", - init_pos: Union[str, bool, None] = None, - root: Optional[int] = None, + init_pos: str | bool | None = None, + root: int | None = None, random_state: AnyRandom = 0, - n_jobs: Optional[int] = None, - adjacency: Optional[spmatrix] = None, - key_added_ext: Optional[str] = None, - neighbors_key: Optional[str] = None, - obsp: Optional[str] = None, + n_jobs: int | None = None, + adjacency: spmatrix | None = None, + key_added_ext: str | None = None, + neighbors_key: str | None = None, + obsp: str | None = None, copy: bool = False, **kwds, ): diff --git a/scanpy/tools/_embedding_density.py b/scanpy/tools/_embedding_density.py index 6d28a83318..44296556f0 100644 --- a/scanpy/tools/_embedding_density.py +++ b/scanpy/tools/_embedding_density.py @@ -1,14 +1,20 @@ """\ Calculate density of cells in embeddings """ +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np -from anndata import AnnData -from typing import Union, Optional, Sequence from .. import logging as logg from .._utils import sanitize_anndata +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + def _calc_density(x: np.ndarray, y: np.ndarray): """\ @@ -33,9 +39,9 @@ def embedding_density( adata: AnnData, # there is no asterisk here for backward compat (previously, there was) basis: str = "umap", # was positional before 1.4.5 - groupby: Optional[str] = None, - key_added: Optional[str] = None, - components: Union[str, Sequence[str]] = None, + groupby: str | None = None, + key_added: str | None = None, + components: str | Sequence[str] | None = None, ) -> None: """\ Calculate the density of cells in an embedding (per condition). diff --git a/scanpy/tools/_ingest.py b/scanpy/tools/_ingest.py index 1994bf4433..e4d7b61ec0 100644 --- a/scanpy/tools/_ingest.py +++ b/scanpy/tools/_ingest.py @@ -1,23 +1,23 @@ from __future__ import annotations -import doctest -from collections.abc import Iterable, MutableMapping, Generator -from typing import Union, Optional +from collections.abc import Generator, Iterable, MutableMapping +from typing import TYPE_CHECKING -import pandas as pd import numpy as np +import pandas as pd from packaging import version -from sklearn.utils import check_random_state from scipy.sparse import issparse -from anndata import AnnData +from sklearn.utils import check_random_state from .. import logging as logg +from .._compat import pkg_version from .._settings import settings -from ..neighbors import FlatTree, RPForestDict from .._utils import NeighborsView -from .._compat import pkg_version +from ..neighbors import FlatTree, RPForestDict from ..testing._doctests import doctest_skip +if TYPE_CHECKING: + from anndata import AnnData ANNDATA_MIN_VERSION = version.parse("0.7rc1") @@ -26,10 +26,10 @@ def ingest( adata: AnnData, adata_ref: AnnData, - obs: Optional[Union[str, Iterable[str]]] = None, - embedding_method: Union[str, Iterable[str]] = ("umap", "pca"), + obs: str | Iterable[str] | None = None, + embedding_method: str | Iterable[str] = ("umap", "pca"), labeling_method: str = "knn", - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, inplace: bool = True, **kwargs, ): @@ -259,8 +259,9 @@ def _init_umap(self, adata): def _init_dist_search(self, dist_args): from functools import partial - from umap.nndescent import initialise_search + from umap.distances import named_distances + from umap.nndescent import initialise_search self._random_init = None self._tree_init = None diff --git a/scanpy/tools/_leiden.py b/scanpy/tools/_leiden.py index 96869fe8f9..d0d903dc4d 100644 --- a/scanpy/tools/_leiden.py +++ b/scanpy/tools/_leiden.py @@ -1,16 +1,21 @@ -from typing import Optional, Tuple, Sequence, Type +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np import pandas as pd from natsort import natsorted -from anndata import AnnData -from scipy import sparse from .. import _utils from .. import logging as logg - from ._utils_clustering import rename_groups, restrict_adjacency +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + from scipy import sparse + try: from leidenalg.VertexPartition import MutableVertexPartition except ImportError: @@ -25,19 +30,19 @@ def leiden( adata: AnnData, resolution: float = 1, *, - restrict_to: Optional[Tuple[str, Sequence[str]]] = None, + restrict_to: tuple[str, Sequence[str]] | None = None, random_state: _utils.AnyRandom = 0, key_added: str = "leiden", - adjacency: Optional[sparse.spmatrix] = None, + adjacency: sparse.spmatrix | None = None, directed: bool = True, use_weights: bool = True, n_iterations: int = -1, - partition_type: Optional[Type[MutableVertexPartition]] = None, - neighbors_key: Optional[str] = None, - obsp: Optional[str] = None, + partition_type: type[MutableVertexPartition] | None = None, + neighbors_key: str | None = None, + obsp: str | None = None, copy: bool = False, **partition_kwargs, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Cluster cells into subgroups [Traag18]_. diff --git a/scanpy/tools/_louvain.py b/scanpy/tools/_louvain.py index 9424611144..5fc9a9ac9e 100644 --- a/scanpy/tools/_louvain.py +++ b/scanpy/tools/_louvain.py @@ -1,17 +1,24 @@ -from types import MappingProxyType -from typing import Optional, Tuple, Sequence, Type, Mapping, Any, Literal +from __future__ import annotations + import warnings +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Literal import numpy as np import pandas as pd -from anndata import AnnData from natsort import natsorted -from scipy.sparse import spmatrix from packaging import version -from ._utils_clustering import rename_groups, restrict_adjacency -from .. import _utils, logging as logg +from .. import _utils +from .. import logging as logg from .._utils import _choose_graph +from ._utils_clustering import rename_groups, restrict_adjacency + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from anndata import AnnData + from scipy.sparse import spmatrix try: from louvain.VertexPartition import MutableVertexPartition @@ -25,20 +32,20 @@ class MutableVertexPartition: def louvain( adata: AnnData, - resolution: Optional[float] = None, + resolution: float | None = None, random_state: _utils.AnyRandom = 0, - restrict_to: Optional[Tuple[str, Sequence[str]]] = None, + restrict_to: tuple[str, Sequence[str]] | None = None, key_added: str = "louvain", - adjacency: Optional[spmatrix] = None, + adjacency: spmatrix | None = None, flavor: Literal["vtraag", "igraph", "rapids"] = "vtraag", directed: bool = True, use_weights: bool = False, - partition_type: Optional[Type[MutableVertexPartition]] = None, + partition_type: type[MutableVertexPartition] | None = None, partition_kwargs: Mapping[str, Any] = MappingProxyType({}), - neighbors_key: Optional[str] = None, - obsp: Optional[str] = None, + neighbors_key: str | None = None, + obsp: str | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Cluster cells into subgroups [Blondel08]_ [Levine15]_ [Traag17]_. @@ -207,8 +214,8 @@ def louvain( ) elif flavor == "taynaud": # this is deprecated - import networkx as nx import community + import networkx as nx g = nx.Graph(adjacency) partition = community.best_partition(g) diff --git a/scanpy/tools/_marker_gene_overlap.py b/scanpy/tools/_marker_gene_overlap.py index eafdf4df8c..6d0f6ef54e 100644 --- a/scanpy/tools/_marker_gene_overlap.py +++ b/scanpy/tools/_marker_gene_overlap.py @@ -1,16 +1,20 @@ """\ Calculate overlaps of rank_genes_groups marker genes with marker gene dictionaries """ +from __future__ import annotations + import collections.abc as cabc -from typing import Union, Optional, Dict, Literal +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd -from anndata import AnnData from .. import logging as logg from ..testing._doctests import doctest_needs +if TYPE_CHECKING: + from anndata import AnnData + def _calc_overlap_count(markers1: dict, markers2: dict): """\ @@ -74,13 +78,13 @@ def _calc_jaccard(markers1: dict, markers2: dict): @doctest_needs("leidenalg") def marker_gene_overlap( adata: AnnData, - reference_markers: Union[Dict[str, set], Dict[str, list]], + reference_markers: dict[str, set] | dict[str, list], *, key: str = "rank_genes_groups", method: _Method = "overlap_count", - normalize: Optional[Literal["reference", "data"]] = None, - top_n_markers: Optional[int] = None, - adj_pval_threshold: Optional[float] = None, + normalize: Literal["reference", "data"] | None = None, + top_n_markers: int | None = None, + adj_pval_threshold: float | None = None, key_added: str = "marker_gene_overlap", inplace: bool = False, ): diff --git a/scanpy/tools/_paga.py b/scanpy/tools/_paga.py index 0d87a1208a..c4f9ca334b 100644 --- a/scanpy/tools/_paga.py +++ b/scanpy/tools/_paga.py @@ -1,23 +1,27 @@ -from typing import List, Optional, NamedTuple, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, NamedTuple import numpy as np import scipy as sp -from anndata import AnnData from scipy.sparse.csgraph import minimum_spanning_tree from .. import _utils from .. import logging as logg from ..neighbors import Neighbors +if TYPE_CHECKING: + from anndata import AnnData + _AVAIL_MODELS = {"v1.0", "v1.2"} def paga( adata: AnnData, - groups: Optional[str] = None, + groups: str | None = None, use_rna_velocity: bool = False, model: Literal["v1.2", "v1.0"] = "v1.2", - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, copy: bool = False, ): """\ @@ -382,7 +386,7 @@ def compute_transitions_old(self): self.transitions_confidence = transitions_confidence.T -def paga_degrees(adata: AnnData) -> List[int]: +def paga_degrees(adata: AnnData) -> list[int]: """Compute the degree of each node in the abstracted graph. Parameters @@ -401,7 +405,7 @@ def paga_degrees(adata: AnnData) -> List[int]: return degrees -def paga_expression_entropies(adata) -> List[float]: +def paga_expression_entropies(adata) -> list[float]: """Compute the median expression entropy for each node-group. Parameters @@ -440,7 +444,7 @@ def paga_compare_paths( adata1: AnnData, adata2: AnnData, adjacency_key: str = "connectivities", - adjacency_key2: Optional[str] = None, + adjacency_key2: str | None = None, ) -> PAGAComparePathsResult: """Compare paths in abstracted graphs in two datasets. @@ -506,8 +510,8 @@ def paga_compare_paths( # loop over all pairs of leaf nodes in the reference adata1 for r, s in itertools.combinations(leaf_nodes1, r=2): r2, s2 = asso_groups1[r][0], asso_groups1[s][0] - on1_g1, on2_g1 = [orig_names1[int(i)] for i in [r, s]] - on1_g2, on2_g2 = [orig_names2[int(i)] for i in [r2, s2]] + on1_g1, on2_g1 = (orig_names1[int(i)] for i in [r, s]) + on1_g2, on2_g2 = (orig_names2[int(i)] for i in [r2, s2]) logg.debug( f"compare shortest paths between leafs ({on1_g1}, {on2_g1}) " f"in graph1 and ({on1_g2}, {on2_g2}) in graph2:" diff --git a/scanpy/tools/_pca.py b/scanpy/tools/_pca.py deleted file mode 100644 index e110bb51c6..0000000000 --- a/scanpy/tools/_pca.py +++ /dev/null @@ -1 +0,0 @@ -from ..preprocessing import pca diff --git a/scanpy/tools/_rank_genes_groups.py b/scanpy/tools/_rank_genes_groups.py index 6c2d6cc4d6..bf8922f003 100644 --- a/scanpy/tools/_rank_genes_groups.py +++ b/scanpy/tools/_rank_genes_groups.py @@ -3,21 +3,23 @@ from __future__ import annotations from math import floor -from typing import Literal, get_args -from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, Literal, get_args import numpy as np import pandas as pd -from numpy.typing import NDArray -from anndata import AnnData from scipy.sparse import issparse, vstack from .. import _utils from .. import logging as logg -from ..preprocessing._simple import _get_mean_var -from ..get import _check_mask from .._utils import check_nonnegative_integers +from ..get import _check_mask +from ..preprocessing._simple import _get_mean_var + +if TYPE_CHECKING: + from collections.abc import Generator, Iterable + from anndata import AnnData + from numpy.typing import NDArray _Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] _CorrMethod = Literal["benjamini-hochberg", "bonferroni"] diff --git a/scanpy/tools/_score_genes.py b/scanpy/tools/_score_genes.py index d7305c00df..303609e80a 100644 --- a/scanpy/tools/_score_genes.py +++ b/scanpy/tools/_score_genes.py @@ -1,16 +1,24 @@ """Calculate scores based on the expression of gene lists. """ -from typing import Sequence, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from anndata import AnnData from scipy.sparse import issparse -from .. import logging as logg -from .._utils import AnyRandom from scanpy._utils import _check_use_raw +from .. import logging as logg + +if TYPE_CHECKING: + from collections.abc import Sequence + + from anndata import AnnData + + from .._utils import AnyRandom + def _sparse_nanmean(X, axis): """ @@ -41,13 +49,13 @@ def score_genes( adata: AnnData, gene_list: Sequence[str], ctrl_size: int = 50, - gene_pool: Optional[Sequence[str]] = None, + gene_pool: Sequence[str] | None = None, n_bins: int = 25, score_name: str = "score", random_state: AnyRandom = 0, copy: bool = False, - use_raw: Optional[bool] = None, -) -> Optional[AnnData]: + use_raw: bool | None = None, +) -> AnnData | None: """\ Score a set of genes [Satija15]_. @@ -196,7 +204,7 @@ def score_genes_cell_cycle( g2m_genes: Sequence[str], copy: bool = False, **kwargs, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Score cell cycle genes [Satija15]_. diff --git a/scanpy/tools/_sim.py b/scanpy/tools/_sim.py index 2d8d2b5815..8d24172a61 100644 --- a/scanpy/tools/_sim.py +++ b/scanpy/tools/_sim.py @@ -8,33 +8,39 @@ ---- Beta Version. The code will be reorganized soon. """ +from __future__ import annotations import itertools import shutil import sys from pathlib import Path from types import MappingProxyType -from typing import Optional, Union, List, Tuple, Mapping, Literal +from typing import TYPE_CHECKING, Literal import numpy as np import scipy as sp -from anndata import AnnData -from .. import _utils, readwrite, logging as logg +from .. import _utils, readwrite +from .. import logging as logg from .._settings import settings +if TYPE_CHECKING: + from collections.abc import Mapping + + from anndata import AnnData + def sim( model: Literal["krumsiek11", "toggleswitch"], params_file: bool = True, - tmax: Optional[int] = None, - branching: Optional[bool] = None, - nrRealizations: Optional[int] = None, - noiseObs: Optional[float] = None, - noiseDyn: Optional[float] = None, - step: Optional[int] = None, - seed: Optional[int] = None, - writedir: Optional[Union[str, Path]] = None, + tmax: int | None = None, + branching: bool | None = None, + nrRealizations: int | None = None, + noiseObs: float | None = None, + noiseDyn: float | None = None, + step: int | None = None, + seed: int | None = None, + writedir: str | Path | None = None, ) -> AnnData: """\ Simulate dynamic gene expression data [Wittmann09]_ [Wolf18]_. @@ -294,7 +300,7 @@ def write_data( else: id = 0 with filename.open("w") as f: - id = "{:0>6}".format(id) + id = f"{id:0>6}" f.write(str(id)) # dimension dim = X.shape[1] @@ -896,7 +902,7 @@ def write_data( def _check_branching( X: np.ndarray, Xsamples: np.ndarray, restart: int, threshold: float = 0.25 -) -> Tuple[bool, List[np.ndarray]]: +) -> tuple[bool, list[np.ndarray]]: """\ Check whether time series branches. @@ -975,7 +981,7 @@ def check_nocycles(Adj: np.ndarray, verbosity: int = 2) -> bool: def sample_coupling_matrix( dim: int = 3, connectivity: float = 0.5 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: """\ Sample coupling matrix. diff --git a/scanpy/tools/_top_genes.py b/scanpy/tools/_top_genes.py index 64cb1768dd..96bee3efcf 100644 --- a/scanpy/tools/_top_genes.py +++ b/scanpy/tools/_top_genes.py @@ -3,26 +3,32 @@ """\ This modules provides all non-visualization tools for advanced gene ranking and exploration of genes """ -from typing import Optional, Collection, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal import pandas as pd -from anndata import AnnData -from sklearn import metrics from scipy.sparse import issparse +from sklearn import metrics from .. import logging as logg from .._utils import select_groups +if TYPE_CHECKING: + from collections.abc import Collection + + from anndata import AnnData + def correlation_matrix( adata: AnnData, - name_list: Optional[Collection[str]] = None, - groupby: Optional[str] = None, - group: Optional[int] = None, + name_list: Collection[str] | None = None, + groupby: str | None = None, + group: int | None = None, n_genes: int = 20, data: Literal["Complete", "Group", "Rest"] = "Complete", method: Literal["pearson", "kendall", "spearman"] = "pearson", - annotation_key: Optional[str] = None, + annotation_key: str | None = None, ) -> None: """\ Calculate correlation matrix. @@ -120,7 +126,7 @@ def correlation_matrix( def ROC_AUC_analysis( adata: AnnData, groupby: str, - group: Optional[str] = None, + group: str | None = None, n_genes: int = 100, ): """\ diff --git a/scanpy/tools/_tsne.py b/scanpy/tools/_tsne.py index c08b20a244..9db096af65 100644 --- a/scanpy/tools/_tsne.py +++ b/scanpy/tools/_tsne.py @@ -1,30 +1,35 @@ -from packaging import version -from typing import Optional, Union +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING -from anndata import AnnData +from packaging import version -from .._utils import _doc_params, AnyRandom -from ..tools._utils import _choose_representation, doc_use_rep, doc_n_pcs -from .._settings import settings from .. import logging as logg +from .._settings import settings +from .._utils import AnyRandom, _doc_params +from ..neighbors._doc import doc_n_pcs, doc_use_rep +from ._utils import _choose_representation + +if TYPE_CHECKING: + from anndata import AnnData @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( adata: AnnData, - n_pcs: Optional[int] = None, - use_rep: Optional[str] = None, - perplexity: Union[float, int] = 30, - early_exaggeration: Union[float, int] = 12, - learning_rate: Union[float, int] = 1000, + n_pcs: int | None = None, + use_rep: str | None = None, + perplexity: float | int = 30, + early_exaggeration: float | int = 12, + learning_rate: float | int = 1000, random_state: AnyRandom = 0, use_fast_tsne: bool = False, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, copy: bool = False, *, metric: str = "euclidean", -) -> Optional[AnnData]: +) -> AnnData | None: """\ t-SNE [Maaten08]_ [Amir13]_ [Pedregosa11]_. diff --git a/scanpy/tools/_umap.py b/scanpy/tools/_umap.py index 0ec0a1d506..5a97c4fc3c 100644 --- a/scanpy/tools/_umap.py +++ b/scanpy/tools/_umap.py @@ -1,16 +1,19 @@ -from typing import Optional, Union, Literal +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING, Literal import numpy as np from packaging import version -from anndata import AnnData -from sklearn.utils import check_random_state, check_array +from sklearn.utils import check_array, check_random_state -from ._utils import get_init_pos_from_paga, _choose_representation from .. import logging as logg from .._settings import settings from .._utils import AnyRandom, NeighborsView +from ._utils import _choose_representation, get_init_pos_from_paga +if TYPE_CHECKING: + from anndata import AnnData _InitPos = Literal["paga", "spectral", "random"] @@ -20,18 +23,18 @@ def umap( min_dist: float = 0.5, spread: float = 1.0, n_components: int = 2, - maxiter: Optional[int] = None, + maxiter: int | None = None, alpha: float = 1.0, gamma: float = 1.0, negative_sample_rate: int = 5, - init_pos: Union[_InitPos, np.ndarray, None] = "spectral", + init_pos: _InitPos | np.ndarray | None = "spectral", random_state: AnyRandom = 0, - a: Optional[float] = None, - b: Optional[float] = None, + a: float | None = None, + b: float | None = None, copy: bool = False, method: Literal["umap", "rapids"] = "umap", - neighbors_key: Optional[str] = None, -) -> Optional[AnnData]: + neighbors_key: str | None = None, +) -> AnnData | None: """\ Embed the neighborhood graph using UMAP [McInnes18]_. diff --git a/scanpy/tools/_utils.py b/scanpy/tools/_utils.py index 679986d841..317948caf6 100644 --- a/scanpy/tools/_utils.py +++ b/scanpy/tools/_utils.py @@ -1,28 +1,16 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING import numpy as np -from scipy.sparse import csr_matrix -from anndata import AnnData from .. import logging as logg -from ._pca import pca from .._settings import settings from .._utils import _choose_graph -doc_use_rep = """\ -use_rep - Use the indicated representation. `'X'` or any key for `.obsm` is valid. - If `None`, the representation is chosen automatically: - For `.n_vars` < :attr:`~scanpy._settings.ScanpyConfig.N_PCS` (default: 50), `.X` is used, otherwise 'X_pca' is used. - If 'X_pca' is not present, it’s computed with default parameters or `n_pcs` if present.\ -""" - -doc_n_pcs = """\ -n_pcs - Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.\ -""" +if TYPE_CHECKING: + from anndata import AnnData + from scipy.sparse import csr_matrix def _choose_representation( @@ -32,6 +20,8 @@ def _choose_representation( n_pcs: int | None = None, silent: bool = False, ) -> np.ndarray | csr_matrix: # TODO: what else? + from ..preprocessing import pca + verbosity = settings.verbosity if silent and settings.verbosity > 1: settings.verbosity = 1 @@ -80,7 +70,7 @@ def _choose_representation( return X -def preprocess_with_pca(adata, n_pcs: Optional[int] = None, random_state=0): +def preprocess_with_pca(adata, n_pcs: int | None = None, random_state=0): """ Parameters ---------- @@ -89,6 +79,8 @@ def preprocess_with_pca(adata, n_pcs: Optional[int] = None, random_state=0): If `None` and there is a PCA version of the data, use this. If an integer, compute the PCA. """ + from ..preprocessing import pca + if n_pcs == 0: logg.info(" using data matrix X directly (no PCA)") return adata.X diff --git a/scanpy/tools/_utils_clustering.py b/scanpy/tools/_utils_clustering.py index 0c21d17927..d741d481d0 100644 --- a/scanpy/tools/_utils_clustering.py +++ b/scanpy/tools/_utils_clustering.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + def rename_groups( adata, key_added, restrict_key, restrict_categories, restrict_indices, groups ): @@ -16,9 +19,7 @@ def restrict_adjacency(adata, restrict_key, restrict_categories, adjacency): ) for c in restrict_categories: if c not in adata.obs[restrict_key].cat.categories: - raise ValueError( - "'{}' is not a valid category for '{}'".format(c, restrict_key) - ) + raise ValueError(f"'{c}' is not a valid category for '{restrict_key}'") restrict_indices = adata.obs[restrict_key].isin(restrict_categories).values adjacency = adjacency[restrict_indices, :] adjacency = adjacency[:, restrict_indices]