diff --git a/neo4j-app/neo4j_app/app/dependencies.py b/neo4j-app/neo4j_app/app/dependencies.py index 670eed11..7f887539 100644 --- a/neo4j-app/neo4j_app/app/dependencies.py +++ b/neo4j-app/neo4j_app/app/dependencies.py @@ -3,6 +3,7 @@ import logging import multiprocessing import os +import platform import sys import tempfile import traceback @@ -37,6 +38,7 @@ _TEST_DB_FILE: Optional[Path] = None _TEST_LOCK: Optional[multiprocessing.Lock] = None _PROCESS_EXECUTOR: Optional[concurrent.futures.ProcessPoolExecutor] = None +_MP_CONTEXT = None class DependencyInjectionError(RuntimeError): @@ -63,6 +65,24 @@ def lifespan_config() -> AppConfig: return cast(AppConfig, _CONFIG) +def mp_context_enter(**__): + global _MP_CONTEXT + platform_system = platform.system() + if platform_system == "Darwin": + ctx = "spawn" + elif platform_system == "Linux": + ctx = "fork" + else: + raise ValueError(f"Unsupported OS: {platform_system}") + _MP_CONTEXT = multiprocessing.get_context(ctx) + + +def lifespan_mp_context(): + if _MP_CONTEXT is None: + raise DependencyInjectionError("multiprocessing context") + return _MP_CONTEXT + + async def neo4j_driver_enter(**__): global _NEO4J_DRIVER _NEO4J_DRIVER = lifespan_config().to_neo4j_driver() @@ -126,7 +146,7 @@ def _lifespan_test_db_path() -> Path: def test_process_manager_enter(**_): global _PROCESS_MANAGER - _PROCESS_MANAGER = multiprocessing.Manager() + _PROCESS_MANAGER = lifespan_mp_context().Manager() def test_process_manager_exit(exc_type, exc_value, trace): @@ -164,7 +184,7 @@ def process_executor_enter(**_): worker_ids = [f"worker-{process_id}-{i}" for i in range(n_workers)] _PROCESS_EXECUTOR = concurrent.futures.ProcessPoolExecutor( # pylint: disable=unnecessary-dunder-call max_workers=n_workers, - mp_context=multiprocessing.get_context("spawn"), + mp_context=lifespan_mp_context(), ).__enter__() kwargs = dict() worker_cls = config.to_worker_cls() @@ -261,6 +281,7 @@ def lifespan_event_publisher() -> EventPublisher: ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), ("neo4j project registry creation", create_project_registry_db_enter, None), ("ES client creation", es_client_enter, es_client_exit), + (None, mp_context_enter, None), (None, test_process_manager_enter, test_process_manager_exit), (None, test_db_path_enter, test_db_path_exit), (None, _test_lock_enter, None),