Skip to content

Commit

Permalink
fix: set multiprocessing context according to OS
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Dec 21, 2023
1 parent bec29f2 commit 257f685
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions neo4j-app/neo4j_app/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import multiprocessing
import os
import platform
import sys
import tempfile
import traceback
Expand Down Expand Up @@ -37,6 +38,7 @@
_TEST_DB_FILE: Optional[Path] = None
_TEST_LOCK: Optional[multiprocessing.Lock] = None
_PROCESS_EXECUTOR: Optional[concurrent.futures.ProcessPoolExecutor] = None
_MP_CONTEXT = None


class DependencyInjectionError(RuntimeError):
Expand All @@ -63,6 +65,24 @@ def lifespan_config() -> AppConfig:
return cast(AppConfig, _CONFIG)


def mp_context_enter(**__):
global _MP_CONTEXT
platform_system = platform.system()
if platform_system == "Darwin":
ctx = "spawn"
elif platform_system == "Linux":
ctx = "fork"
else:
raise ValueError(f"Unsupported OS: {platform_system}")
_MP_CONTEXT = multiprocessing.get_context(ctx)


def lifespan_mp_context():
if _MP_CONTEXT is None:
raise DependencyInjectionError("multiprocessing context")
return _MP_CONTEXT


async def neo4j_driver_enter(**__):
global _NEO4J_DRIVER
_NEO4J_DRIVER = lifespan_config().to_neo4j_driver()
Expand Down Expand Up @@ -126,7 +146,7 @@ def _lifespan_test_db_path() -> Path:

def test_process_manager_enter(**_):
global _PROCESS_MANAGER
_PROCESS_MANAGER = multiprocessing.Manager()
_PROCESS_MANAGER = lifespan_mp_context().Manager()


def test_process_manager_exit(exc_type, exc_value, trace):
Expand Down Expand Up @@ -164,7 +184,7 @@ def process_executor_enter(**_):
worker_ids = [f"worker-{process_id}-{i}" for i in range(n_workers)]
_PROCESS_EXECUTOR = concurrent.futures.ProcessPoolExecutor( # pylint: disable=unnecessary-dunder-call
max_workers=n_workers,
mp_context=multiprocessing.get_context("spawn"),
mp_context=lifespan_mp_context(),
).__enter__()
kwargs = dict()
worker_cls = config.to_worker_cls()
Expand Down Expand Up @@ -261,6 +281,7 @@ def lifespan_event_publisher() -> EventPublisher:
("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit),
("neo4j project registry creation", create_project_registry_db_enter, None),
("ES client creation", es_client_enter, es_client_exit),
(None, mp_context_enter, None),
(None, test_process_manager_enter, test_process_manager_exit),
(None, test_db_path_enter, test_db_path_exit),
(None, _test_lock_enter, None),
Expand Down

0 comments on commit 257f685

Please sign in to comment.