diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c4e7fc9..8d5cae2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,5 +44,18 @@ jobs: with: python-version: '3.11' cache: 'pip' - - run: pip install -r requirements.txt - - run: python -m pytest tests \ No newline at end of file + + - name: Install deps + run: pip install -r requirements.txt + + - name: Tests with numba + run: coverage run --source=tsbrowse -m pytest -x tests + + - name: Tests with coverage (no numba) + run: TSBROWSE_DISABLE_NUMBA=1 coverage run --source=tsbrowse -m pytest -x tests + + - name: Upload coverage + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + coveralls --service=github \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 905dce6..ebc1592 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ appdirs click +coverage +coveralls daiquiri -# It's not clear why we need Dask here, but tests are failing -# otherwise. Pinning to arbitrary working version -dask[dataframe]==2024.4.1 datashader diskcache hvplot diff --git a/tsbrowse/jit.py b/tsbrowse/jit.py new file mode 100644 index 0000000..7ec9c6e --- /dev/null +++ b/tsbrowse/jit.py @@ -0,0 +1,56 @@ +import functools +import logging +import os + +import numba + +logger = logging.getLogger(__name__) + +_DISABLE_NUMBA = os.environ.get("TSBROWSE_DISABLE_NUMBA", "0") + +try: + ENABLE_NUMBA = {"0": True, "1": False}[_DISABLE_NUMBA] +except KeyError as e: # pragma: no cover + raise KeyError( + "Environment variable 'TSBROWSE_DISABLE_NUMBA' must be '0' or '1'" + ) from e + +# We will mostly be using disable numba for debugging and running tests for +# coverage, so raise a loud warning in case this is being used accidentally. + +if not ENABLE_NUMBA: + logger.warning( + "numba globally disabled for tsbrowse; performance will be drastically" + " reduced." + ) + + +DEFAULT_NUMBA_ARGS = { + "nopython": True, + "cache": True, +} + + +def numba_njit(**numba_kwargs): + def _numba_njit(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) # pragma: no cover + + if ENABLE_NUMBA: # pragma: no cover + combined_kwargs = {**DEFAULT_NUMBA_ARGS, **numba_kwargs} + return numba.jit(**combined_kwargs)(func) + else: + return func + + return _numba_njit + + +def numba_jitclass(spec): + def _numba_jitclass(cls): + if ENABLE_NUMBA: # pragma: no cover + return numba.experimental.jitclass(spec)(cls) + else: + return cls + + return _numba_jitclass diff --git a/tsbrowse/model.py b/tsbrowse/model.py index 283c9e3..eb273a4 100644 --- a/tsbrowse/model.py +++ b/tsbrowse/model.py @@ -8,9 +8,9 @@ import pandas as pd import tskit +from . import jit from .cache import disk_cache - logger = daiquiri.getLogger("tsbrowse") spec = [ @@ -28,7 +28,7 @@ ] -@numba.experimental.jitclass(spec) +@jit.numba_jitclass(spec) class TreePosition: def __init__( self, @@ -91,7 +91,7 @@ def alloc_tree_position(ts): ) -@numba.njit +@jit.numba_njit() def _compute_per_tree_stats( tree_pos, num_trees, num_nodes, nodes_time, edges_parent, edges_child ): @@ -163,7 +163,7 @@ def compute_per_tree_stats(ts): ) -@numba.njit +@jit.numba_njit() def _compute_mutation_parent_counts(mutations_parent): N = mutations_parent.shape[0] num_parents = np.zeros(N, dtype=np.int32) @@ -176,7 +176,7 @@ def _compute_mutation_parent_counts(mutations_parent): return num_parents -@numba.njit +@jit.numba_njit() def _compute_mutation_inheritance_counts( tree_pos, num_nodes, @@ -259,7 +259,7 @@ def compute_mutation_counts(ts): return MutationCounts(num_parents, num_inheritors, num_descendants) -@numba.njit +@jit.numba_njit() def _compute_population_mutation_counts( tree_pos, num_nodes, @@ -398,7 +398,7 @@ def _repr_html_(self): return self.summary_df._repr_html_() @staticmethod - @numba.njit + @jit.numba_njit() def child_bounds(num_nodes, edges_left, edges_right, edges_child): num_edges = edges_left.shape[0] child_left = np.zeros(num_nodes, dtype=np.float64) + np.inf