extend parallel_config() to all methods
dawe committed Aug 8, 2023
1 parent 6323eac commit d0bbc96
Showing 4 changed files with 52 additions and 30 deletions.
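The same refactor is applied in each of the four modules: instead of passing a backend hint to Parallel() through prefer=, the backend, the number of jobs and the memmapping threshold are now set once with joblib's parallel_config() context manager, and the Parallel() call inside the block inherits those settings. A minimal sketch of the two styles, assuming joblib >= 1.3 (which introduced parallel_config); square() is only a stand-in for the per-seed minimization functions shown in the diff below:

    from joblib import Parallel, delayed, parallel_config

    def square(x):
        return x * x

    # Old style: backend hint passed directly to Parallel via ``prefer``
    old = Parallel(n_jobs=2, prefer="processes")(delayed(square)(i) for i in range(8))

    # New style: parallel_config() sets backend, n_jobs and the memmapping
    # threshold (max_nbytes=None disables memmapping of large arrays) once;
    # any Parallel() created inside the block picks these settings up.
    with parallel_config(backend="loky", max_nbytes=None, n_jobs=2):
        new = Parallel()(delayed(square)(i) for i in range(8))

    assert old == new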
schist/inference/_flat_model.py (19 changes: 12 additions & 7 deletions)
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config
 from natsort import natsorted
 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -120,9 +120,11 @@ def flat_model(

     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)

-    # if dispatch_backend == 'threads':
-    #     logg.warning('We noticed a large performance degradation with this backend\n'
-    #                  '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'

     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -193,9 +195,12 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
     # perform a mcmc sweep on each
     # no list comprehension as I need to collect stats

-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )

     pmode = gt.PartitionModeState([x.get_blocks().a for x in states], converge=True)
     bs = pmode.get_max(g)
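The small compatibility block that now precedes each call exists because the two APIs name backends differently: Parallel(prefer=...) accepts the soft hints 'threads' and 'processes', while parallel_config(backend=...) expects a registered backend name such as 'threading' or 'loky'. A hedged sketch of that mapping; the _COMPAT dictionary and run_parallel() are illustrative names, not part of the patch:

    from joblib import Parallel, delayed, parallel_config

    # 'threads'/'processes' are soft hints understood by Parallel(prefer=...);
    # parallel_config(backend=...) takes registered backend names instead,
    # hence the small translation each function applies before the context.
    _COMPAT = {"threads": "threading", "processes": "loky"}

    def run_parallel(fn, items, dispatch_backend="processes", n_jobs=2):
        backend = _COMPAT.get(dispatch_backend, dispatch_backend)
        with parallel_config(backend=backend, max_nbytes=None, n_jobs=n_jobs):
            return Parallel()(delayed(fn)(item) for item in items)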
schist/inference/_multi_nested.py (20 changes: 13 additions & 7 deletions)
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config

 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -136,9 +136,11 @@ def nested_model_multi(

     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)

-    # if dispatch_backend == 'threads':
-    #     logg.warning('We noticed a large performance degradation with this backend\n'
-    #                  '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'

     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -261,9 +263,13 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
             n += 1
         return state

-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )
+
     logg.info('    minimization step done', time=start)
     pmode = gt.PartitionModeState([x.get_bs() for x in states], converge=True, nested=True)
     bs = pmode.get_max_nested()
schist/inference/_planted_model.py (20 changes: 13 additions & 7 deletions)
@@ -4,7 +4,7 @@
 import pandas as pd
 from anndata import AnnData
 from scipy import sparse
-from joblib import delayed, Parallel
+from joblib import delayed, Parallel, parallel_config
 from natsort import natsorted
 from scanpy import logging as logg
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency
@@ -121,9 +121,11 @@ def planted_model(

     seeds = np.random.choice(range(n_init**2), size=n_init, replace=False)

-    # if dispatch_backend == 'threads':
-    #     logg.warning('We noticed a large performance degradation with this backend\n'
-    #                  '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'

     if collect_marginals and not refine_model:
         if n_init < 100:
@@ -189,9 +191,13 @@ def fast_min(state, beta, n_sweep, fast_tol, max_iter=max_iter, seed=None):
     # perform a mcmc sweep on each
     # no list comprehension as I need to collect stats

-    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
-    )
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        states = Parallel()(
+            delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x]) for x in range(n_init)
+        )
+
     logg.info('    minimization step done', time=start)
     pmode = gt.PartitionModeState([x.get_blocks().a for x in states], converge=True)

schist/inference/_pmleiden.py (23 changes: 14 additions & 9 deletions)
@@ -13,7 +13,7 @@
 from scanpy.tools._utils_clustering import rename_groups, restrict_adjacency

 from scanpy._utils import get_igraph_from_adjacency, _choose_graph
-from joblib import Parallel, delayed
+from joblib import delayed, Parallel, parallel_config

 try:
     from leidenalg.VertexPartition import MutableVertexPartition
@@ -143,9 +143,11 @@ def leiden(
     )
     partition_kwargs = dict(partition_kwargs)

-    # if dispatch_backend == 'threads':
-    #     logg.warning('We noticed a large performance degradation with this backend\n'
-    #                  '``dispatch_backend=processes`` should be preferred')
+    # the following lines are for compatibility
+    if dispatch_backend == 'threads':
+        dispatch_backend = 'threading'
+    elif dispatch_backend == 'processes':
+        dispatch_backend = 'loky'


     start = logg.info('running Leiden clustering')
@@ -185,11 +187,14 @@ def membership(g, partition_type, seed, **partition_kwargs):
     def membership(g, partition_type, seed, **partition_kwargs):
         return leidenalg.find_partition(g, partition_type,
                                         seed=seed, **partition_kwargs).membership
-
-    parts = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
-        delayed(membership)(g, partition_type,
-                            seeds[i], **partition_kwargs)
-        for i in range(n_init))
+
+    with parallel_config(backend=dispatch_backend,
+                         max_nbytes=None,
+                         n_jobs=n_jobs):
+        parts = Parallel()(
+            delayed(membership)(g, partition_type, seeds[i], **partition_kwargs) for i in range(n_init)
+        )
+

     pmode = gt.PartitionModeState(parts, converge=True)

