Skip to content

Commit

Permalink
🏗️ Turn multiprocessing_imap into a generator (#19)
Browse files Browse the repository at this point in the history
* 🏗️ Turn multiprocessing_imap into a generator

* 🚨 Make lint
  • Loading branch information
ddelange authored Apr 6, 2022
1 parent e57761d commit c879199
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
exclude: .*/tests/.*

- repo: https://github.com/psf/black
-rev: 20.8b1
+rev: 22.3.0
hooks:
- id: black

Expand Down
2 changes: 1 addition & 1 deletion requirements/ci.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
detect-secrets~=0.14.2
mypy~=0.782
-pre-commit~=2.6.0
+pre-commit~=2.17
pytest-cov~=2.10.1
pytest-env~=0.6.2
pytest-sugar~=0.9.4
Expand Down
12 changes: 7 additions & 5 deletions src/mapply/mapply.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ def run_apply(func, df_or_series, args=(), **kwargs):
if not isseries:
kwargs["axis"] = axis

-results = multiprocessing_imap(
-partial(run_apply, func, args=args, **kwargs),
-dfs,
-n_workers=n_workers,
-progressbar=progressbar,
+results = list(
+multiprocessing_imap(
+partial(run_apply, func, args=args, **kwargs),
+dfs,
+n_workers=n_workers,
+progressbar=progressbar,
+)
)

if (
Expand Down
22 changes: 12 additions & 10 deletions src/mapply/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@
def some_heavy_computation(x, power):
return pow(x, power)
-multicore_list = multiprocessing_imap(
-some_heavy_computation,
-range(100),
-power=2.5,
-progressbar=False,
-n_workers=-1
+multicore_list = list(
+multiprocessing_imap(
+some_heavy_computation,
+range(100),
+power=2.5,
+progressbar=False,
+n_workers=-1,
+)
)
"""
import logging
from functools import partial
-from typing import Any, Callable, Iterable, List, Optional
+from typing import Any, Callable, Iterable, Optional

import psutil
from pathos.multiprocessing import ProcessPool
Expand Down Expand Up @@ -63,7 +65,7 @@ def multiprocessing_imap(
progressbar: bool = True,
args=(),
**kwargs
-) -> List[Any]:
+) -> Iterable[Any]:
"""Execute func on each element in iterable on n_workers, ensuring order.
Args:
Expand All @@ -74,7 +76,7 @@ def multiprocessing_imap(
args: Additional positional arguments to pass to func.
kwargs: Additional keyword arguments to pass to func.
-Returns:
+Yields:
Results in same order as input iterable.
Raises:
Expand All @@ -100,7 +102,7 @@ def multiprocessing_imap(
stage = tqdm(stage, total=n_chunks)

try:
-return list(stage)
+yield from stage
except (Exception, KeyboardInterrupt):
if pool:
logger.debug("Terminating ProcessPool")
Expand Down
18 changes: 9 additions & 9 deletions tests/test_mapply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def test_df_mapply():

# same output along both axes
pd.testing.assert_frame_equal(
-df.apply(lambda x: x ** 2),
-df.mapply(lambda x: x ** 2),
+df.apply(lambda x: x**2),
+df.mapply(lambda x: x**2),
)
pd.testing.assert_frame_equal(
-df.mapply(lambda x: x ** 2, axis=0),
-df.mapply(lambda x: x ** 2, axis=1),
+df.mapply(lambda x: x**2, axis=0),
+df.mapply(lambda x: x**2, axis=1),
)

# vectorized
Expand All @@ -46,15 +46,15 @@ def test_df_mapply():
# max_chunks_per_worker=0
mapply.init(progressbar=False, chunk_size=1, max_chunks_per_worker=0)
pd.testing.assert_frame_equal(
-df.apply(lambda x: x ** 2),
-df.mapply(lambda x: x ** 2),
+df.apply(lambda x: x**2),
+df.mapply(lambda x: x**2),
)

# n_workers=1
mapply.init(progressbar=False, chunk_size=1, n_workers=1)
pd.testing.assert_frame_equal(
-df.apply(lambda x: x ** 2),
-df.mapply(lambda x: x ** 2),
+df.apply(lambda x: x**2),
+df.mapply(lambda x: x**2),
)

# concat for only one result
Expand All @@ -70,7 +70,7 @@ def test_series_mapply():
# chunk_size>1
mapply.init(progressbar=False, chunk_size=5)

-fn = lambda x: x ** 2 # noqa:E731
+fn = lambda x: x**2 # noqa:E731
series = pd.Series(range(100))

with pytest.raises(ValueError, match="Passing axis=1 is not allowed for Series"):
Expand Down
30 changes: 20 additions & 10 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,36 @@ def foo(x, power):


def test_multiprocessing_imap(size=100, power=1.1):
-multicore_list1 = multiprocessing_imap(
-foo, range(size), power=power, progressbar=False, n_workers=size
+multicore_list1 = list(
+multiprocessing_imap(
+foo, range(size), power=power, progressbar=False, n_workers=size
+)
)
-multicore_list2 = multiprocessing_imap(
-foo, range(size), power=power, progressbar=True, n_workers=1
+multicore_list2 = list(
+multiprocessing_imap(
+foo, range(size), power=power, progressbar=True, n_workers=1
+)
)
-multicore_list3 = multiprocessing_imap( # generator with unknown length
-foo, (i for i in range(size)), power=power, progressbar=False, n_workers=2
+multicore_list3 = list(
+multiprocessing_imap( # generator with unknown length
+foo, (i for i in range(size)), power=power, progressbar=False, n_workers=2
+)
)

assert multicore_list1 == multicore_list2
assert multicore_list1 == multicore_list3
assert multicore_list1 == [foo(x, power=power) for x in range(size)]
with pytest.raises(ValueError, match="reraise"):
# hit with ProcessPool
-multiprocessing_imap(
-foo, range(size), power=None, progressbar=False, n_workers=2
+list(
+multiprocessing_imap(
+foo, range(size), power=None, progressbar=False, n_workers=2
+)
)
with pytest.raises(ValueError, match="reraise"):
# hit without ProcessPool
-multiprocessing_imap(
-foo, range(size), power=None, progressbar=False, n_workers=1
+list(
+multiprocessing_imap(
+foo, range(size), power=None, progressbar=False, n_workers=1
+)
)

0 comments on commit c879199

Please sign in to comment.