diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6be9680..5acfd73 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: additional_dependencies: ['gibberish-detector'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.1.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -41,6 +41,7 @@ repos: rev: v0.23.1 hooks: - id: toml-sort-fix + args: ['--trailing-comma-inline-array'] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/pyproject.toml b/pyproject.toml index 3ec9a50..40f1c92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,15 +57,14 @@ ignore = [ "TRY003", # there is EM102 "D203", # there is D211 "D213", # there is D212 - "FIX002" # there is TD002,TD003 + "FIX002", # there is TD002,TD003 ] [tool.ruff.extend-per-file-ignores] -"__init__.py" = ["E401", "E402"] "**/tests/**/*.py" = [ "S101", # assert is fine in tests "D100", # tests is not a package - "D104" # tests modules don't need docstrings + "D104", # tests modules don't need docstrings ] [tool.ruff.isort] diff --git a/requirements/prod.txt b/requirements/prod.txt index b8b70ca..670eece 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -1,3 +1,4 @@ -pathos>=0.2.0 +pathos>=0.3.1 # https://github.com/uqfoundation/pathos/pull/252 +multiprocess psutil tqdm>=4.27 # from tqdm.auto import tqdm diff --git a/src/mapply/parallel.py b/src/mapply/parallel.py index 4d620e7..3cc58fa 100644 --- a/src/mapply/parallel.py +++ b/src/mapply/parallel.py @@ -25,6 +25,7 @@ def some_heavy_computation(x, power): from functools import partial from typing import Any, Callable, Iterable, Iterator +import multiprocess import psutil from pathos.multiprocessing import ProcessPool from tqdm.auto import tqdm as _tqdm @@ -41,6 +42,7 @@ def sensible_cpu_count() -> int: N_CORES = sensible_cpu_count() MAX_TASKS_PER_CHILD = int(os.environ.get("MAPPLY_MAX_TASKS_PER_CHILD", 4)) +CONTEXT = multiprocess.get_context(os.environ.get("MAPPLY_START_METHOD")) def _choose_n_workers(n_chunks: int | None, n_workers: int) -> int: @@ -98,7 +100,11 @@ def multiprocessing_imap( stage = map(func, iterable) else: logger.debug("Starting ProcessPool with %d workers", n_workers) - pool = ProcessPool(n_workers, maxtasksperchild=MAX_TASKS_PER_CHILD) + pool = ProcessPool( + n_workers, + maxtasksperchild=MAX_TASKS_PER_CHILD, + context=CONTEXT, + ) stage = pool.imap(func, iterable)