diff --git a/schist/inference/_flat_model.py b/schist/inference/_flat_model.py
index 87dc540..bacff73 100644
--- a/schist/inference/_flat_model.py
+++ b/schist/inference/_flat_model.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config
 from natsort import natsorted
 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -120,9 +120,11 @@ def flat_model(
     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)
 
-#    if dispatch_backend == 'threads':
-#        logg.warning('We noticed a large performance degradation with this backend\n'
-#                     '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'
 
     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -193,9 +195,12 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
 
     # perform a mcmc sweep on each
     # no list comprehension as I need to collect stats
-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )
 
     pmode = gt.PartitionModeState([x.get_blocks().a for x in states], converge=True)
     bs = pmode.get_max(g)
diff --git a/schist/inference/_multi_nested.py b/schist/inference/_multi_nested.py
index a61c54d..f648fd0 100644
--- a/schist/inference/_multi_nested.py
+++ b/schist/inference/_multi_nested.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config
 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -136,9 +136,11 @@ def nested_model_multi(
     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)
 
-#    if dispatch_backend == 'threads':
-#        logg.warning('We noticed a large performance degradation with this backend\n'
-#                     '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'
 
     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -261,9 +263,13 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
                 n += 1
         return state
 
-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )
+    logg.info('    minimization step done', time=start)
 
     pmode = gt.PartitionModeState([x.get_bs() for x in states], converge=True, nested=True)
     bs = pmode.get_max_nested()
diff --git a/schist/inference/_planted_model.py b/schist/inference/_planted_model.py
index 028c0de..5b3f4e1 100644
--- a/schist/inference/_planted_model.py
+++ b/schist/inference/_planted_model.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config
 from natsort import natsorted
 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -121,9 +121,11 @@ def planted_model(
     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)
 
-#    if dispatch_backend == 'threads':
-#        logg.warning('We noticed a large performance degradation with this backend\n'
-#                     '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'
 
     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -189,9 +191,13 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
 
     # perform a mcmc sweep on each
     # no list comprehension as I need to collect stats
-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )
+    logg.info('    minimization step done', time=start)
 
     pmode = gt.PartitionModeState([x.get_blocks().a for x in states], converge=True)
diff --git a/schist/inference/_pmleiden.py b/schist/inference/_pmleiden.py
index 16a5234..ccdcad4 100644
--- a/schist/inference/_pmleiden.py
+++ b/schist/inference/_pmleiden.py
@@ -13,7 +13,7 @@
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
 from scanpy._utils import get_igraph_from_adjacency, _choose_graph
 
-from joblib import Parallel, delayed
+from joblib import delayed, Parallel, parallel_config
 
 try:
     from leidenalg.VertexPartition import MutableVertexPartition
@@ -143,9 +143,11 @@
     )
     partition_kwargs = dict(partition_kwargs)
 
-#    if dispatch_backend == 'threads':
-#        logg.warning('We noticed a large performance degradation with this backend\n'
-#                     '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'
 
     start = logg.info('running Leiden clustering')
@@ -185,11 +187,14 @@
     def membership(g, partition_type, seed, **partition_kwargs):
         return leidenalg.find_partition(g, partition_type, seed=seed, **partition_kwargs).membership
-
-    parts = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(membership)(g, partition_type,
-                            seeds[i], **partition_kwargs)
-        for i in range(n_init))
+
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        parts = Parallel()(
+            delayed(membership)(g, partition_type, seeds[i], **partition_kwargs) for i in range(n_init)
+        )
+
     pmode = gt.PartitionModeState(parts, converge=True)
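
Note: all four files replace the per-call form Parallel(n_jobs=n_jobs, prefer=dispatch_backend) with joblib's parallel_config context manager (available since joblib 1.3), which sets the backend, worker count, and memmapping policy for every Parallel() created inside the with block; max_nbytes=None disables joblib's automatic memory-mapping of large array arguments. The compatibility shim is needed because prefer= accepted the soft hints 'threads'/'processes', while backend= requires joblib's actual backend names 'threading' and 'loky'. A minimal runnable sketch of the pattern, with a hypothetical square worker standing in for fast_min/membership:

    # Sketch of the parallel_config pattern adopted in this diff.
    # `square` is a placeholder worker, not part of schist.
    from joblib import Parallel, delayed, parallel_config

    def square(x):
        return x * x

    # Every Parallel() opened inside the context inherits these settings,
    # so the call sites no longer pass n_jobs/prefer themselves.
    # backend='loky' runs worker processes; 'threading' runs threads.
    with parallel_config(backend='loky', max_nbytes=None, n_jobs=4):
        results = Parallel()(delayed(square)(i) for i in range(8))

    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]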