From 5cbdf7e12791a80ad5a249e3da86ab9f28b67aed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Mon, 29 Jan 2024 18:32:05 +0100 Subject: [PATCH] chore: isolate the async app from the HTTP app as well as create the `WorkerBackend.MULTIPROCESSING` async backend --- neo4j-app/neo4j_app/__init__.py | 2 + neo4j-app/neo4j_app/app/__init__.py | 1 + neo4j-app/neo4j_app/app/admin.py | 6 +- neo4j-app/neo4j_app/app/config.py | 81 +++++ neo4j-app/neo4j_app/app/dependencies.py | 293 ++++++------------ neo4j-app/neo4j_app/app/documents.py | 7 +- neo4j-app/neo4j_app/app/graphs.py | 6 +- neo4j-app/neo4j_app/app/main.py | 15 +- neo4j-app/neo4j_app/app/named_entities.py | 7 +- neo4j-app/neo4j_app/app/utils.py | 15 +- neo4j-app/neo4j_app/{core => }/config.py | 105 +------ neo4j-app/neo4j_app/core/__init__.py | 1 - neo4j-app/neo4j_app/core/imports.py | 3 +- neo4j-app/neo4j_app/core/utils/logging.py | 13 +- neo4j-app/neo4j_app/icij_worker/__init__.py | 6 +- neo4j-app/neo4j_app/icij_worker/app.py | 46 ++- .../neo4j_app/icij_worker/backend/__init__.py | 1 + .../neo4j_app/icij_worker/backend/backend.py | 98 ++++++ neo4j-app/neo4j_app/icij_worker/backend/mp.py | 112 +++++++ neo4j-app/neo4j_app/icij_worker/exceptions.py | 8 +- neo4j-app/neo4j_app/icij_worker/typing_.py | 33 ++ .../neo4j_app/icij_worker/utils/__init__.py | 3 + .../icij_worker/utils/dependencies.py | 78 +++++ .../neo4j_app/icij_worker/worker/__init__.py | 13 +- .../neo4j_app/icij_worker/worker/config.py | 7 +- .../neo4j_app/icij_worker/worker/neo4j.py | 83 +++-- .../neo4j_app/icij_worker/worker/process.py | 38 ++- .../neo4j_app/icij_worker/worker/worker.py | 98 +++--- neo4j-app/neo4j_app/run/run.py | 12 +- neo4j-app/neo4j_app/tasks/__init__.py | 2 +- neo4j-app/neo4j_app/tasks/app.py | 24 +- neo4j-app/neo4j_app/tasks/dependencies.py | 130 ++++++++ neo4j-app/neo4j_app/tasks/imports.py | 6 +- .../tests/{core => app}/test_config.py | 50 ++- neo4j-app/neo4j_app/tests/app/test_main.py | 4 +- neo4j-app/neo4j_app/tests/app/test_tasks.py | 23 +- neo4j-app/neo4j_app/tests/conftest.py | 23 +- .../neo4j_app/tests/icij_worker/conftest.py | 47 +-- .../tests/icij_worker/worker/conftest.py | 14 +- .../tests/icij_worker/worker/test_neo4j.py | 72 +++-- .../tests/icij_worker/worker/test_process.py | 6 +- .../tests/icij_worker/worker/test_worker.py | 4 +- .../tests/icij_worker/worker/worker_main.py | 51 --- 43 files changed, 989 insertions(+), 658 deletions(-) create mode 100644 neo4j-app/neo4j_app/app/config.py rename neo4j-app/neo4j_app/{core => }/config.py (70%) create mode 100644 neo4j-app/neo4j_app/icij_worker/backend/__init__.py create mode 100644 neo4j-app/neo4j_app/icij_worker/backend/backend.py create mode 100644 neo4j-app/neo4j_app/icij_worker/backend/mp.py create mode 100644 neo4j-app/neo4j_app/icij_worker/typing_.py create mode 100644 neo4j-app/neo4j_app/icij_worker/utils/dependencies.py create mode 100644 neo4j-app/neo4j_app/tasks/dependencies.py rename neo4j-app/neo4j_app/tests/{core => app}/test_config.py (63%) delete mode 100644 neo4j-app/neo4j_app/tests/icij_worker/worker/worker_main.py diff --git a/neo4j-app/neo4j_app/__init__.py b/neo4j-app/neo4j_app/__init__.py index caceb2e4..56916fd6 100644 --- a/neo4j-app/neo4j_app/__init__.py +++ b/neo4j-app/neo4j_app/__init__.py @@ -1,3 +1,5 @@ from pathlib import Path +from neo4j_app.config import AppConfig + ROOT_DIR = Path(__file__).parent diff --git a/neo4j-app/neo4j_app/app/__init__.py b/neo4j-app/neo4j_app/app/__init__.py index e69de29b..48e90f7d 100644 --- 
a/neo4j-app/neo4j_app/app/__init__.py +++ b/neo4j-app/neo4j_app/app/__init__.py @@ -0,0 +1 @@ +from .config import ServiceConfig diff --git a/neo4j-app/neo4j_app/app/admin.py b/neo4j-app/neo4j_app/app/admin.py index 8eb8323d..b7cac6e4 100644 --- a/neo4j-app/neo4j_app/app/admin.py +++ b/neo4j-app/neo4j_app/app/admin.py @@ -6,7 +6,6 @@ from fastapi import APIRouter, Request from neo4j_app.app.dependencies import ( - lifespan_es_client, lifespan_neo4j_driver, ) from neo4j_app.app.doc import ( @@ -14,13 +13,14 @@ DOC_NEO4J_CSV, DOC_NEO4J_CSV_DESC, ) -from neo4j_app.core import AppConfig +from neo4j_app.app import ServiceConfig from neo4j_app.core.imports import to_neo4j_csvs from neo4j_app.core.objects import ( Neo4jCSVRequest, Neo4jCSVResponse, ) from neo4j_app.core.utils.logging import log_elapsed_time_cm +from neo4j_app.tasks.dependencies import lifespan_es_client logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ def admin_router() -> APIRouter: async def _neo4j_csv( project: str, payload: Neo4jCSVRequest, request: Request ) -> Neo4jCSVResponse: - config: AppConfig = request.app.state.config + config: ServiceConfig = request.app.state.config with log_elapsed_time_cm( logger, logging.INFO, "Exported ES to CSV in {elapsed_time} !" diff --git a/neo4j-app/neo4j_app/app/config.py b/neo4j-app/neo4j_app/app/config.py new file mode 100644 index 00000000..bb2a2ff5 --- /dev/null +++ b/neo4j-app/neo4j_app/app/config.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import configparser +import functools +import io +from copy import copy +from typing import Optional, TextIO + +from pydantic import Field + +from neo4j_app import AppConfig +from neo4j_app.core.utils.pydantic import ( + BaseICIJModel, +) +from neo4j_app.icij_worker import WorkerConfig, WorkerType + +_SHARED_WITH_NEO4J_WORKER_CONFIG = [ + "neo4j_connection_timeout", + "neo4j_host", + "neo4j_password", + "neo4j_port", + "neo4j_uri_scheme", + "neo4j_user", +] + +_SHARED_WITH_NEO4J_WORKER_CONFIG_PREFIXED = [ + "cancelled_tasks_refresh_interval_s", + "task_queue_poll_interval_s", + "log_level", +] + + +class ServiceConfig(AppConfig): + neo4j_app_async_dependencies: Optional[str] = "neo4j_app.tasks.WORKER_LIFESPAN_DEPS" + neo4j_app_async_app: str = "neo4j_app.tasks.app" + neo4j_app_gunicorn_workers: int = 1 + neo4j_app_host: str = "127.0.0.1" + neo4j_app_n_async_workers: int = 1 + neo4j_app_name: str = "🕸 neo4j app" + neo4j_app_port: int = 8080 + neo4j_app_task_queue_size: int = 2 + neo4j_app_worker_type: WorkerType = WorkerType.neo4j + test: bool = False + + @functools.cached_property + def doc_app_name(self) -> str: + return self.neo4j_app_name + + def to_worker_config(self, **kwargs) -> WorkerConfig: + kwargs = copy(kwargs) + for suffix in _SHARED_WITH_NEO4J_WORKER_CONFIG_PREFIXED: + kwargs[suffix] = getattr(self, f"neo4j_app_{suffix}") + + if self.test: + from neo4j_app.tests.icij_worker.conftest import MockWorkerConfig + + return MockWorkerConfig(**kwargs) + from neo4j_app.icij_worker.worker.neo4j import Neo4jWorkerConfig + + for k in _SHARED_WITH_NEO4J_WORKER_CONFIG: + if k in kwargs: + continue + kwargs[k] = getattr(self, k) + return Neo4jWorkerConfig(**kwargs) + + def write_java_properties(self, file: TextIO): + parser = self._get_config_parser() + parser[configparser.DEFAULTSECT] = dict( + sorted(self.dict(exclude_unset=True, by_alias=True).items()) + ) + config_str_io = io.StringIO() + parser.write(config_str_io, space_around_delimiters=False) + config_str = config_str_io.getvalue() + # Remove the mandatory default 
section + config_str = config_str.replace(f"[{configparser.DEFAULTSECT}]\n", "") + file.write(config_str) + + +class UviCornModel(BaseICIJModel): + host: str = Field(default="127.0.0.1", const=True) + port: int diff --git a/neo4j-app/neo4j_app/app/dependencies.py b/neo4j-app/neo4j_app/app/dependencies.py index 8c062ab1..03db2b7b 100644 --- a/neo4j-app/neo4j_app/app/dependencies.py +++ b/neo4j-app/neo4j_app/app/dependencies.py @@ -1,54 +1,75 @@ -import inspect import logging import multiprocessing import os -import sys import tempfile -import traceback -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from multiprocessing.managers import SyncManager from pathlib import Path -from typing import AsyncGenerator, List, Optional, cast +from typing import Optional, cast import neo4j from fastapi import FastAPI -from neo4j_app.core import AppConfig +from neo4j_app.app.config import ServiceConfig from neo4j_app.core.elasticsearch import ESClientABC -from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schemas -from neo4j_app.core.neo4j.migrations import delete_all_migrations -from neo4j_app.core.neo4j.projects import create_project_registry_db from neo4j_app.icij_worker import ( EventPublisher, Neo4jEventPublisher, ) +from neo4j_app.icij_worker.backend.backend import WorkerBackend from neo4j_app.icij_worker.task_manager import TaskManager from neo4j_app.icij_worker.task_manager.neo4j import Neo4JTaskManager +from neo4j_app.icij_worker.utils import run_deps +from neo4j_app.icij_worker.utils.dependencies import DependencyInjectionError +from neo4j_app.tasks.dependencies import ( + config_enter, + create_project_registry_db_enter, + es_client_enter, + es_client_exit, + lifespan_config, + lifespan_neo4j_driver, + migrate_app_db_enter, + neo4j_driver_enter, + neo4j_driver_exit, +) logger = logging.getLogger(__name__) -_CONFIG: Optional[AppConfig] = None +_ASYNC_APP_CONFIG_PATH: Optional[Path] = None _ES_CLIENT: Optional[ESClientABC] = None _EVENT_PUBLISHER: Optional[EventPublisher] = None -_PROCESS_MANAGER: Optional[SyncManager] = None +_MP_CONTEXT = None _NEO4J_DRIVER: Optional[neo4j.AsyncDriver] = None +_PROCESS_MANAGER: Optional[SyncManager] = None _TASK_MANAGER: Optional[TaskManager] = None _TEST_DB_FILE: Optional[Path] = None _TEST_LOCK: Optional[multiprocessing.Lock] = None -_WORKER_POOL: Optional[multiprocessing.Pool] = None -_MP_CONTEXT = None +_WORKER_POOL_IS_RUNNING = False -class DependencyInjectionError(RuntimeError): - def __init__(self, name: str): - msg = f"{name} was not injected" - super().__init__(msg) +def write_async_app_config_enter(**_): + config = lifespan_config() + config = cast(ServiceConfig, config) + global _ASYNC_APP_CONFIG_PATH + _, _ASYNC_APP_CONFIG_PATH = tempfile.mkstemp( + prefix="neo4j-worker-config", suffix=".properties" + ) + _ASYNC_APP_CONFIG_PATH = Path(_ASYNC_APP_CONFIG_PATH) + with _ASYNC_APP_CONFIG_PATH.open("w") as f: + config.write_java_properties(file=f) + logger.info("Loaded config %s", config.json(indent=2)) -def config_enter(config: AppConfig, **_): - global _CONFIG - _CONFIG = config - logger.info("Loaded config %s", config.json(indent=2)) +def write_async_app_config_exit(*_, **__): + path = _lifespan_async_app_config_path() + if path.exists(): + os.remove(path) + + +def _lifespan_async_app_config_path() -> Path: + if _ASYNC_APP_CONFIG_PATH is None: + raise DependencyInjectionError("async app config path") + return _ASYNC_APP_CONFIG_PATH def loggers_enter(**_): @@ -57,12 +78,6 @@ def 
loggers_enter(**_): logger.info("Loggers ready to log 💬") -def lifespan_config() -> AppConfig: - if _CONFIG is None: - raise DependencyInjectionError("config") - return cast(AppConfig, _CONFIG) - - def mp_context_enter(**__): global _MP_CONTEXT _MP_CONTEXT = multiprocessing.get_context("spawn") @@ -74,51 +89,11 @@ def lifespan_mp_context(): return _MP_CONTEXT -async def neo4j_driver_enter(**__): - global _NEO4J_DRIVER - _NEO4J_DRIVER = lifespan_config().to_neo4j_driver() - await _NEO4J_DRIVER.__aenter__() # pylint: disable=unnecessary-dunder-call - - logger.debug("pinging neo4j...") - async with _NEO4J_DRIVER.session(database=neo4j.SYSTEM_DATABASE) as sess: - await sess.run("CALL db.ping()") - logger.debug("neo4j driver is ready") - - -async def neo4j_driver_exit(exc_type, exc_value, trace): - already_closed = False - try: - await _NEO4J_DRIVER.verify_connectivity() - except: # pylint: disable=bare-except - already_closed = True - if not already_closed: - await _NEO4J_DRIVER.__aexit__(exc_type, exc_value, trace) - - -def lifespan_neo4j_driver() -> neo4j.AsyncDriver: - if _NEO4J_DRIVER is None: - raise DependencyInjectionError("neo4j driver") - return cast(neo4j.AsyncDriver, _NEO4J_DRIVER) - - -async def es_client_enter(**_): - global _ES_CLIENT - _ES_CLIENT = lifespan_config().to_es_client() - await _ES_CLIENT.__aenter__() # pylint: disable=unnecessary-dunder-call - - -async def es_client_exit(exc_type, exc_value, trace): - await _ES_CLIENT.__aexit__(exc_type, exc_value, trace) - - -def lifespan_es_client() -> ESClientABC: - if _ES_CLIENT is None: - raise DependencyInjectionError("es client") - return cast(ESClientABC, _ES_CLIENT) - - def test_db_path_enter(**_): - config = lifespan_config() + config = cast( + ServiceConfig, + lifespan_config(), + ) if config.test: # pylint: disable=consider-using-with from neo4j_app.tests.icij_worker.conftest import DBMixin @@ -157,7 +132,10 @@ def lifespan_test_process_manager() -> SyncManager: def _test_lock_enter(**_): - config = lifespan_config() + config = cast( + ServiceConfig, + lifespan_config(), + ) if config.test: global _TEST_LOCK _TEST_LOCK = lifespan_test_process_manager().Lock() @@ -169,50 +147,16 @@ def _lifespan_test_lock() -> multiprocessing.Lock: return cast(multiprocessing.Lock, _TEST_LOCK) -def worker_pool_enter(**_): - # pylint: disable=consider-using-with - global _WORKER_POOL - process_id = os.getpid() - config = lifespan_config() - n_workers = min(config.neo4j_app_n_async_workers, config.neo4j_app_task_queue_size) - n_workers = max(1, n_workers) - # TODO: let the process choose they ID and set it with the worker process ID, - # this will help debugging - worker_ids = [f"worker-{process_id}-{i}" for i in range(n_workers)] - _WORKER_POOL = lifespan_mp_context().Pool( - processes=config.neo4j_app_n_async_workers, maxtasksperchild=1 - ) - _WORKER_POOL.__enter__() # pylint: disable=unnecessary-dunder-call - kwargs = dict() - worker_cls = config.to_worker_cls() - if worker_cls.__name__ == "MockWorker": - kwargs = {"db_path": _lifespan_test_db_path(), "lock": _lifespan_test_lock()} - - kwargs["config"] = config - for w_id in worker_ids: - kwargs.update({"worker_id": w_id}) - logger.info("starting worker %s", w_id) - _WORKER_POOL.apply_async(worker_cls.work_forever_from_config, kwds=kwargs) - - logger.info("worker pool ready !") - - -def worker_pool_exit(exc_type, exc_value, trace): - # pylint: disable=unused-argument - pool = lifespan_worker_pool() - pool.__exit__(exc_type, exc_value, trace) - logger.debug("async worker pool has shut down 
!") - - -def lifespan_worker_pool() -> multiprocessing.Pool: - if _WORKER_POOL is None: - raise DependencyInjectionError("worker pool") - return cast(multiprocessing.Pool, _WORKER_POOL) +def lifespan_worker_pool_is_running() -> bool: + return _WORKER_POOL_IS_RUNNING def task_manager_enter(**_): global _TASK_MANAGER - config = lifespan_config() + config = cast( + ServiceConfig, + lifespan_config(), + ) if config.test: from neo4j_app.tests.icij_worker.conftest import MockManager @@ -235,7 +179,10 @@ def lifespan_task_manager() -> TaskManager: def event_publisher_enter(**_): global _EVENT_PUBLISHER - config = lifespan_config() + config = cast( + ServiceConfig, + lifespan_config(), + ) if config.test: from neo4j_app.tests.icij_worker.conftest import MockEventPublisher @@ -246,35 +193,49 @@ def event_publisher_enter(**_): _EVENT_PUBLISHER = Neo4jEventPublisher(lifespan_neo4j_driver()) -async def create_project_registry_db_enter(**_): - driver = lifespan_neo4j_driver() - await create_project_registry_db(driver) - - -async def migrate_app_db_enter(config: AppConfig): - logger.info("Running schema migrations...") - driver = lifespan_neo4j_driver() - if config.force_migrations: - # TODO: improve this as is could lead to race conditions... - logger.info("Deleting all previous migrations...") - await delete_all_migrations(driver) - await migrate_db_schemas( - driver, - registry=MIGRATIONS, - timeout_s=config.neo4j_app_migration_timeout_s, - throttle_s=config.neo4j_app_migration_throttle_s, - ) - - def lifespan_event_publisher() -> EventPublisher: if _EVENT_PUBLISHER is None: raise DependencyInjectionError("event publisher") return cast(EventPublisher, _EVENT_PUBLISHER) +@asynccontextmanager +async def run_app_deps(app: FastAPI): + config = app.state.config + n_workers = config.neo4j_app_n_async_workers + async with run_deps( + dependencies=FASTAPI_LIFESPAN_DEPS, ctx="FastAPI HTTP server", config=config + ): + app.state.config = await config.with_neo4j_support() + worker_extras = {"teardown_dependencies": config.test} + config_extra = dict() + # Forward the past of the app config to load to the async app + async_app_extras = {"config_path": _lifespan_async_app_config_path()} + if config.test: + config_extra["db_path"] = _lifespan_test_db_path() + worker_extras["lock"] = _lifespan_test_lock() + worker_config = config.to_worker_config(**config_extra) + with WorkerBackend.MULTIPROCESSING.run_cm( + config.neo4j_app_async_app, + n_workers=n_workers, + config=worker_config, + worker_extras=worker_extras, + app_deps_extras=async_app_extras, + ): + global _WORKER_POOL_IS_RUNNING + _WORKER_POOL_IS_RUNNING = True + yield + _WORKER_POOL_IS_RUNNING = False + + FASTAPI_LIFESPAN_DEPS = [ ("configuration reading", config_enter, None), ("loggers setup", loggers_enter, None), + ( + "write async config for workers", + write_async_app_config_enter, + write_async_app_config_exit, + ), ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), ("neo4j project registry creation", create_project_registry_db_enter, None), ("ES client creation", es_client_enter, es_client_exit), @@ -284,71 +245,5 @@ def lifespan_event_publisher() -> EventPublisher: (None, _test_lock_enter, None), ("task manager creation", task_manager_enter, None), ("event publisher creation", event_publisher_enter, None), - ("async worker pool creation", worker_pool_enter, worker_pool_exit), ("neo4j DB migration", migrate_app_db_enter, None), ] - - -@contextmanager -def _log_exceptions(): - try: - yield - except Exception as exc: - from 
neo4j_app.app.utils import INTERNAL_SERVER_ERROR - - title = INTERNAL_SERVER_ERROR - detail = f"{type(exc).__name__}: {exc}" - trace = "".join(traceback.format_exc()) - logger.error("%s\nDetail: %s\nTrace: %s", title, detail, trace) - - raise exc - - -@asynccontextmanager -async def run_app_deps(app: FastAPI, dependencies: List) -> AsyncGenerator[None, None]: - async with run_deps(dependencies, config=app.state.config): - app.state.config = await app.state.config.with_neo4j_support() - yield - - -@asynccontextmanager -async def run_deps(dependencies: List, **kwargs) -> AsyncGenerator[None, None]: - to_close = [] - original_ex = None - try: - with _log_exceptions(): - logger.info("applying dependencies...") - for name, enter_fn, exit_fn in dependencies: - if enter_fn is not None: - if name is not None: - logger.debug("applying: %s", name) - if inspect.iscoroutinefunction(enter_fn): - await enter_fn(**kwargs) - else: - enter_fn(**kwargs) - to_close.append((name, exit_fn)) - yield - except Exception as e: # pylint: disable=broad-exception-caught - original_ex = e - finally: - to_raise = [] - if original_ex is not None: - to_raise.append(original_ex) - logger.info("rolling back dependencies...") - for name, exit_fn in to_close[::-1]: - if exit_fn is None: - continue - try: - if name is not None: - logger.debug("rolling back %s", name) - exc_info = sys.exc_info() - with _log_exceptions(): - if inspect.iscoroutinefunction(exit_fn): - await exit_fn(*exc_info) - else: - exit_fn(*exc_info) - except Exception as e: # pylint: disable=broad-exception-caught - to_raise.append(e) - logger.debug("rolled back all dependencies !") - if to_raise: - raise RuntimeError(to_raise) diff --git a/neo4j-app/neo4j_app/app/documents.py b/neo4j-app/neo4j_app/app/documents.py index 1b840c62..daacd47b 100644 --- a/neo4j-app/neo4j_app/app/documents.py +++ b/neo4j-app/neo4j_app/app/documents.py @@ -2,12 +2,13 @@ from fastapi import APIRouter, Request -from neo4j_app.app.dependencies import lifespan_es_client, lifespan_neo4j_driver +from neo4j_app.app import ServiceConfig +from neo4j_app.app.dependencies import lifespan_neo4j_driver from neo4j_app.app.doc import DOCUMENT_TAG, DOC_IMPORT_DESC, DOC_IMPORT_SUM -from neo4j_app.core import AppConfig from neo4j_app.core.imports import import_documents from neo4j_app.core.objects import IncrementalImportRequest, IncrementalImportResponse from neo4j_app.core.utils.logging import log_elapsed_time_cm +from neo4j_app.tasks.dependencies import lifespan_es_client logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ async def _import_documents( payload: IncrementalImportRequest, request: Request, ) -> IncrementalImportResponse: - config: AppConfig = request.app.state.config + config: ServiceConfig = request.app.state.config with log_elapsed_time_cm( logger, logging.INFO, "Imported documents in {elapsed_time} !" 
): diff --git a/neo4j-app/neo4j_app/app/graphs.py b/neo4j-app/neo4j_app/app/graphs.py index 1aab6d2f..1ac98bbb 100644 --- a/neo4j-app/neo4j_app/app/graphs.py +++ b/neo4j-app/neo4j_app/app/graphs.py @@ -3,9 +3,9 @@ from fastapi import APIRouter, Request from starlette.responses import StreamingResponse +from neo4j_app.app import ServiceConfig from neo4j_app.app.dependencies import lifespan_neo4j_driver from neo4j_app.app.doc import DOC_GRAPH_DUMP, DOC_GRAPH_DUMP_DESC, GRAPH_TAG -from neo4j_app.core import AppConfig from neo4j_app.core.neo4j.graphs import count_documents_and_named_entities, dump_graph from neo4j_app.core.objects import DumpRequest, GraphCounts from neo4j_app.core.utils.logging import log_elapsed_time_cm @@ -27,7 +27,7 @@ async def _graph_dump( payload: DumpRequest, request: Request, ) -> StreamingResponse: - config: AppConfig = request.app.state.config + config: ServiceConfig = request.app.state.config if config.supports_neo4j_parallel_runtime is None: msg = ( "parallel support has not been set, config has not been properly" @@ -53,7 +53,7 @@ async def _graph_dump( async def _count_documents_and_named_entities( project: str, request: Request ) -> GraphCounts: - config: AppConfig = request.app.state.config + config: ServiceConfig = request.app.state.config if config.supports_neo4j_parallel_runtime is None: msg = ( "parallel support has not been set, config has not been properly" diff --git a/neo4j-app/neo4j_app/app/main.py b/neo4j-app/neo4j_app/app/main.py index 77b6c641..457b4daa 100644 --- a/neo4j-app/neo4j_app/app/main.py +++ b/neo4j-app/neo4j_app/app/main.py @@ -4,15 +4,14 @@ from neo4j.exceptions import DriverError from starlette.requests import Request +from neo4j_app.app import ServiceConfig from neo4j_app.app.dependencies import ( DependencyInjectionError, - lifespan_es_client, - lifespan_neo4j_driver, lifespan_task_manager, - lifespan_worker_pool, + lifespan_worker_pool_is_running, ) from neo4j_app.app.doc import OTHER_TAG -from neo4j_app.core import AppConfig +from neo4j_app.tasks.dependencies import lifespan_es_client, lifespan_neo4j_driver def main_router() -> APIRouter: @@ -25,13 +24,15 @@ async def ping() -> str: await driver.verify_connectivity() lifespan_es_client() lifespan_task_manager() - lifespan_worker_pool() + lifespan_worker_pool_is_running() except (DriverError, DependencyInjectionError) as e: raise HTTPException(503, detail="Service Unavailable") from e return "pong" - @router.get("/config", response_model=AppConfig, response_model_exclude_unset=True) - async def config(request: Request) -> AppConfig: + @router.get( + "/config", response_model=ServiceConfig, response_model_exclude_unset=True + ) + async def config(request: Request) -> ServiceConfig: if ( request.app.state.config.supports_neo4j_enterprise is None or request.app.state.config.supports_neo4j_parallel_runtime is None diff --git a/neo4j-app/neo4j_app/app/named_entities.py b/neo4j-app/neo4j_app/app/named_entities.py index 62312bd4..5129d6f8 100644 --- a/neo4j-app/neo4j_app/app/named_entities.py +++ b/neo4j-app/neo4j_app/app/named_entities.py @@ -2,12 +2,13 @@ from fastapi import APIRouter, Request -from neo4j_app.app.dependencies import lifespan_es_client, lifespan_neo4j_driver +from neo4j_app.app import ServiceConfig +from neo4j_app.app.dependencies import lifespan_neo4j_driver from neo4j_app.app.doc import NE_IMPORT_DESC, NE_IMPORT_SUM, NE_TAG -from neo4j_app.core import AppConfig from neo4j_app.core.imports import import_named_entities from neo4j_app.core.objects import 
IncrementalImportRequest, IncrementalImportResponse from neo4j_app.core.utils.logging import log_elapsed_time_cm +from neo4j_app.tasks.dependencies import lifespan_es_client logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ async def _import_named_entities( payload: IncrementalImportRequest, request: Request, ) -> IncrementalImportResponse: - config: AppConfig = request.app.state.config + config: ServiceConfig = request.app.state.config with log_elapsed_time_cm( logger, logging.INFO, "Imported named entities in {elapsed_time} !" ): diff --git a/neo4j-app/neo4j_app/app/utils.py b/neo4j-app/neo4j_app/app/utils.py index 106ae942..aca9271c 100644 --- a/neo4j-app/neo4j_app/app/utils.py +++ b/neo4j-app/neo4j_app/app/utils.py @@ -1,4 +1,3 @@ -import functools import logging import traceback from typing import Dict, Iterable, List, Optional @@ -12,8 +11,9 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.responses import JSONResponse, Response +from neo4j_app.app import ServiceConfig from neo4j_app.app.admin import admin_router -from neo4j_app.app.dependencies import FASTAPI_LIFESPAN_DEPS, run_app_deps +from neo4j_app.app.dependencies import run_app_deps from neo4j_app.app.doc import DOCUMENT_TAG, NE_TAG, OTHER_TAG from neo4j_app.app.documents import documents_router from neo4j_app.app.graphs import graphs_router @@ -21,8 +21,7 @@ from neo4j_app.app.named_entities import named_entities_router from neo4j_app.app.projects import projects_router from neo4j_app.app.tasks import tasks_router -from neo4j_app.core import AppConfig -from neo4j_app.icij_worker import ICIJApp +from neo4j_app.icij_worker import AsyncApp INTERNAL_SERVER_ERROR = "Internal Server Error" _REQUEST_VALIDATION_ERROR = "Request Validation Error" @@ -84,18 +83,14 @@ def _debug(): logger.info("im here") -def create_app(config: AppConfig, async_app: Optional[ICIJApp] = None) -> FastAPI: +def create_app(config: ServiceConfig, async_app: Optional[AsyncApp] = None) -> FastAPI: app = FastAPI( title=config.doc_app_name, openapi_tags=_make_open_api_tags([DOCUMENT_TAG, NE_TAG, OTHER_TAG]), - lifespan=functools.partial(run_app_deps, dependencies=FASTAPI_LIFESPAN_DEPS), + lifespan=run_app_deps, ) app.state.config = config if async_app is not None: - if async_app.config is not None and async_app.config is not config: - msg = f"HTTP app async app must share the same {AppConfig.__name__}" - raise ValueError(msg) - async_app.config = config app.state.async_app = async_app app.add_exception_handler(RequestValidationError, request_validation_error_handler) app.add_exception_handler(StarletteHTTPException, http_exception_handler) diff --git a/neo4j-app/neo4j_app/core/config.py b/neo4j-app/neo4j_app/config.py similarity index 70% rename from neo4j-app/neo4j_app/core/config.py rename to neo4j-app/neo4j_app/config.py index d3646a66..c554f17e 100644 --- a/neo4j-app/neo4j_app/core/config.py +++ b/neo4j-app/neo4j_app/config.py @@ -1,13 +1,10 @@ from __future__ import annotations import configparser -import importlib import logging import sys from configparser import ConfigParser -from enum import Enum, unique -from logging.handlers import SysLogHandler -from typing import Callable, Dict, List, Optional, TextIO, Tuple, Type, Union +from typing import Dict, List, Optional, TextIO, Union import neo4j from pydantic import Field, validator @@ -23,33 +20,11 @@ WorkerIdFilter, ) from neo4j_app.core.utils.pydantic import ( - BaseICIJModel, IgnoreExtraModel, LowerCamelCaseModel, safe_copy, ) -_SYSLOG_MODEL_SPLIT_CHAR = 
"@" -_SYSLOG_FMT = f"%(name)s{_SYSLOG_MODEL_SPLIT_CHAR}%(message)s" - - -@unique -class WorkerType(str, Enum): - MOCK = "MOCK" - NEO4J = "NEO4J" - - @property - def as_worker_cls(self) -> Type["Worker"]: - if self is WorkerType.NEO4J: - from neo4j_app.icij_worker import Neo4jAsyncWorker - - return Neo4jAsyncWorker - if self is WorkerType.MOCK: - from neo4j_app.tests.icij_worker.conftest import MockWorker - - return MockWorker - raise NotImplementedError(f"as_worker_cls not implemented for {self}") - def _es_version() -> str: import elasticsearch @@ -58,7 +33,6 @@ def _es_version() -> str: class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): - doc_app_name: str = "🕸 neo4j app" elasticsearch_address: str = "http://127.0.0.1:9200" elasticsearch_version: str = Field(default_factory=_es_version, const=True) es_doc_type_field: str = Field(alias="docTypeField", default="type") @@ -69,11 +43,8 @@ class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): es_timeout_s: Union[int, float] = 60 * 5 es_keep_alive: str = "1m" force_migrations: bool = False - neo4j_app_async_app: str = "neo4j_app.tasks.app" - neo4j_app_async_dependencies: Optional[str] = "neo4j_app.tasks.WORKER_LIFESPAN_DEPS" - neo4j_app_host: str = "127.0.0.1" - neo4j_app_gunicorn_workers: int = 1 neo4j_app_log_level: str = "INFO" + neo4j_app_cancelled_task_refresh_interval_s: int = 2 neo4j_app_log_in_json: bool = False neo4j_app_max_dumped_documents: Optional[int] = None neo4j_app_max_records_in_memory: int = int(1e6) @@ -82,12 +53,8 @@ class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): neo4j_app_n_async_workers: int = 1 neo4j_app_name: str = "neo4j app" neo4j_app_port: int = 8080 - neo4j_app_syslog_facility: Optional[str] = None - neo4j_app_task_queue_size: int = 2 neo4j_app_task_queue_poll_interval_s: int = 1.0 - neo4j_app_cancelled_task_refresh_interval_s: int = 2 neo4j_app_uses_opensearch: bool = False - neo4j_app_worker_type: WorkerType = WorkerType.NEO4J neo4j_concurrency: int = 2 neo4j_connection_timeout: float = 5.0 neo4j_host: str = "127.0.0.1" @@ -101,11 +68,6 @@ class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): neo4j_uri_scheme: str = "bolt" supports_neo4j_enterprise: Optional[bool] = None supports_neo4j_parallel_runtime: Optional[bool] = None - test: bool = False - - # Ugly but hard to do differently if we want to avoid to retrieve the config on a - # per request basis using FastApi dependencies... 
- _global: Optional[AppConfig] = None @validator("neo4j_import_batch_size") def neo4j_import_batch_size_must_be_less_than_max_records_in_memory( @@ -129,7 +91,7 @@ def neo4j_user_and_password_xor(cls, v, values): # pylint: disable=no-self-argu return v @classmethod - def from_java_properties(cls, file: TextIO, **kwargs) -> AppConfig: + def _get_config_parser(cls) -> ConfigParser: parser = ConfigParser( allow_no_value=True, strict=True, @@ -139,24 +101,23 @@ def from_java_properties(cls, file: TextIO, **kwargs) -> AppConfig: ) # Let's avoid lower-casing the keys parser.optionxform = str + return parser + + @classmethod + def from_java_properties(cls, file: TextIO, **kwargs) -> AppConfig: + parser = cls._get_config_parser() # Config need a section, let's fake one section_name = configparser.DEFAULTSECT section_str = f"""[{section_name}] - {file.read()} - """ + {file.read()} + """ parser.read_string(section_str) config_dict = dict(parser[section_name].items()) config_dict.update(kwargs) config_dict = _sanitize_values(config_dict) - config = AppConfig.parse_obj(config_dict.items()) + config = cls.parse_obj(config_dict.items()) return config - @classmethod - def set_config_globally(cls, value: AppConfig): - if cls._global is not None: - raise ValueError("Can't set config globally twice") - cls._global = value - @property def neo4j_uri(self) -> str: return f"{self.neo4j_uri_scheme}://{self.neo4j_host}:{self.neo4j_port}" @@ -190,9 +151,6 @@ def to_es_client(self) -> ESClientABC: ) return client - def to_worker_cls(self) -> Type["Worker"]: - return WorkerType[self.neo4j_app_worker_type].as_worker_cls - async def with_neo4j_support(self) -> AppConfig: async with self.to_neo4j_driver() as neo4j_driver: # pylint: disable=not-async-context-manager enterprise_support = await is_enterprise(neo4j_driver) @@ -248,53 +206,12 @@ def _handlers( fmt = logging.Formatter(fmt, DATE_FMT) stream_handler.setFormatter(fmt) handlers = [stream_handler] - if self.neo4j_app_syslog_facility is not None: - syslog_handler = SysLogHandler( - facility=self._neo4j_app_syslog_facility_int, - ) - syslog_handler.setFormatter(logging.Formatter(_SYSLOG_FMT)) - handlers.append(syslog_handler) for handler in handlers: if worker_id_filter is not None: handler.addFilter(worker_id_filter) handler.setLevel(self.neo4j_app_log_level) return handlers - @property - def _neo4j_app_syslog_facility_int(self) -> int: - try: - return getattr( - SysLogHandler, f"LOG_{self.neo4j_app_syslog_facility.upper()}" - ) - except AttributeError as e: - msg = f"Invalid syslog facility {self.neo4j_app_syslog_facility}" - raise ValueError(msg) from e - - def to_async_app(self): - app_path = self.neo4j_app_async_app.split(".") - module, app_name = app_path[:-1], app_path[-1] - module = ".".join(module) - module = importlib.import_module(module) - app = getattr(module, app_name) - app.config = self - return app - - def to_async_deps(self) -> List[Tuple[Callable, Callable]]: - deps_path = self.neo4j_app_async_dependencies - if deps_path is None: - return [] - deps_path = deps_path.split(".") - module, app_name = deps_path[:-1], deps_path[-1] - module = ".".join(module) - module = importlib.import_module(module) - deps = getattr(module, app_name) - return deps - - -class UviCornModel(BaseICIJModel): - host: str = Field(default="127.0.0.1", const=True) - port: int - def _sanitize_values(java_config: Dict[str, str]) -> Dict[str, str]: return { diff --git a/neo4j-app/neo4j_app/core/__init__.py b/neo4j-app/neo4j_app/core/__init__.py index e2d5ab00..e69de29b 100644 
--- a/neo4j-app/neo4j_app/core/__init__.py +++ b/neo4j-app/neo4j_app/core/__init__.py @@ -1 +0,0 @@ -from .config import AppConfig, UviCornModel diff --git a/neo4j-app/neo4j_app/core/imports.py b/neo4j-app/neo4j_app/core/imports.py index f1931f64..97e3a9c5 100644 --- a/neo4j-app/neo4j_app/core/imports.py +++ b/neo4j-app/neo4j_app/core/imports.py @@ -28,7 +28,6 @@ import neo4j from datrie import BaseTrie -from neo4j_app import ROOT_DIR from neo4j_app.constants import ( DOC_COLUMNS, DOC_CREATED_AT, @@ -878,6 +877,8 @@ def _ne_trie_key(ne: Dict) -> str: def _compress_csvs_destructively( export_dir: Path, metadata: Neo4jCSVs, *, targz_path: Path ): + from neo4j_app import ROOT_DIR + with tarfile.open(targz_path, "w:gz") as tar: # Index json_index = json.dumps(metadata.dict(by_alias=True)).encode() diff --git a/neo4j-app/neo4j_app/core/utils/logging.py b/neo4j-app/neo4j_app/core/utils/logging.py index 5110d882..a9ff90fe 100644 --- a/neo4j-app/neo4j_app/core/utils/logging.py +++ b/neo4j-app/neo4j_app/core/utils/logging.py @@ -2,7 +2,7 @@ import contextlib import logging -from abc import ABC, abstractmethod +from abc import ABC from datetime import datetime from functools import wraps from typing import Optional, final @@ -26,15 +26,8 @@ def __str__(self): class LogWithNameMixin(ABC): - @property - @abstractmethod - def _logger(self) -> logging.Logger: - pass - - @property - @abstractmethod - def logged_named(self) -> str: - pass + def __init__(self, logger: logging.Logger): + self._logger = logger @final def info(self, msg, *args, **kwargs): diff --git a/neo4j-app/neo4j_app/icij_worker/__init__.py b/neo4j-app/neo4j_app/icij_worker/__init__.py index da2abf48..548ae00a 100644 --- a/neo4j-app/neo4j_app/icij_worker/__init__.py +++ b/neo4j-app/neo4j_app/icij_worker/__init__.py @@ -1,4 +1,6 @@ -from .app import ICIJApp +from .app import AsyncApp from .task import Task, TaskError, TaskEvent, TaskResult, TaskStatus -from .worker import Worker, Neo4jAsyncWorker, WorkerConfig +from .worker import Worker, WorkerConfig, WorkerType +from .worker.neo4j import Neo4jWorker +from .backend import WorkerBackend from .event_publisher import EventPublisher, Neo4jEventPublisher diff --git a/neo4j-app/neo4j_app/icij_worker/app.py b/neo4j-app/neo4j_app/icij_worker/app.py index d8db074c..138d4c5e 100644 --- a/neo4j-app/neo4j_app/icij_worker/app.py +++ b/neo4j-app/neo4j_app/icij_worker/app.py @@ -1,11 +1,16 @@ +from __future__ import annotations import functools +import importlib +from contextlib import asynccontextmanager from functools import cached_property -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, final from pydantic import Field -from neo4j_app.core.config import AppConfig from neo4j_app.core.utils.pydantic import BaseICIJModel +from neo4j_app.icij_worker.exceptions import UnknownApp +from neo4j_app.icij_worker.typing_ import Dependency +from neo4j_app.icij_worker.utils.dependencies import run_deps class RegisteredTask(BaseICIJModel): @@ -15,24 +20,18 @@ class RegisteredTask(BaseICIJModel): max_retries: Optional[int] = Field(const=True, default=None) -class ICIJApp: - def __init__(self, name: str, config: Optional[AppConfig] = None): +class AsyncApp: + def __init__(self, name: str, dependencies: Optional[List[Dependency]] = None): self._name = name - self._config = config self._registry = dict() + if dependencies is None: + dependencies = [] + self._dependencies = dependencies @cached_property def registry(self) -> Dict[str, 
RegisteredTask]: return self._registry - @property - def config(self) -> Optional[AppConfig]: - return self._config - - @config.setter - def config(self, value: AppConfig): - self._config = value - @property def name(self) -> str: return self._name @@ -53,6 +52,13 @@ def task( max_retries=max_retries, ) + @final + @asynccontextmanager + async def lifetime_dependencies(self, **kwargs): + ctx = f"{self.name} async app" + async with run_deps(self._dependencies, ctx=ctx, **kwargs): + yield + def _register_task( self, f: Callable, @@ -75,3 +81,17 @@ def wrapped(*args, **kwargs): return f(*args, **kwargs) return wrapped + + @classmethod + def load(cls, app_path: str) -> AsyncApp: + app_path = app_path.split(".") + module, app_name = app_path[:-1], app_path[-1] + module = ".".join(module) + try: + module = importlib.import_module(module) + except ModuleNotFoundError as e: + msg = f'Expected app_path to be the fully qualified path to a \ + {AsyncApp.__name__} instance "my_module.my_app_instance"' + raise UnknownApp(msg) from e + app = getattr(module, app_name) + return app diff --git a/neo4j-app/neo4j_app/icij_worker/backend/__init__.py b/neo4j-app/neo4j_app/icij_worker/backend/__init__.py new file mode 100644 index 00000000..570ffc8a --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/backend/__init__.py @@ -0,0 +1 @@ +from .backend import start_workers, WorkerBackend diff --git a/neo4j-app/neo4j_app/icij_worker/backend/backend.py b/neo4j-app/neo4j_app/icij_worker/backend/backend.py new file mode 100644 index 00000000..6edddf45 --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/backend/backend.py @@ -0,0 +1,98 @@ +from contextlib import contextmanager +from enum import Enum +from pathlib import Path +from typing import Dict, Optional + +from neo4j_app.icij_worker import WorkerConfig +from neo4j_app.icij_worker.backend.mp import run_workers_with_multiprocessing + + +class WorkerBackend(str, Enum): + # pylint: disable=invalid-name + + # We could support more backend types, and for instance support asyncio/thread backed + # workers for IO based tasks + MULTIPROCESSING = "multiprocessing" + + def run( + self, + app: str, + n_workers: int, + config: WorkerConfig, + worker_extras: Optional[Dict] = None, + app_deps_extras: Optional[Dict] = None, + ): + # This function is meant to be run as the main function of a Python command, + # in this case we want the main process to handle signals + with self._run_cm( + app, + n_workers, + config, + handle_signals=True, + worker_extras=worker_extras, + app_deps_extras=app_deps_extras, + ): + pass + + # TODO: remove this when the HTTP server doesn't + # TODO: also refactor underlying functions to be simple functions rather than + # context managers + @contextmanager + def run_cm( + self, + app: str, + n_workers: int, + config: WorkerConfig, + worker_extras: Optional[Dict] = None, + app_deps_extras: Optional[Dict] = None, + ): + # This usage is meant for when a backend is run from another process which + # handles signals by itself + with self._run_cm( + app, + n_workers, + config, + handle_signals=False, + worker_extras=worker_extras, + app_deps_extras=app_deps_extras, + ): + yield + + @contextmanager + def _run_cm( + self, + app: str, + n_workers: int, + config: WorkerConfig, + *, + handle_signals: bool = False, + worker_extras: Optional[Dict] = None, + app_deps_extras: Optional[Dict] = None, + ): + if self is WorkerBackend.MULTIPROCESSING: + with run_workers_with_multiprocessing( + app, + n_workers, + config, + handle_signals=handle_signals, 
worker_extras=worker_extras, + app_deps_extras=app_deps_extras, + ): + yield + else: + raise NotImplementedError(f"Can't start workers with backend: {self}") + + +def start_workers( + app: str, + n_workers: int, + config_path: Optional[Path], + backend: WorkerBackend, +): + if n_workers < 1: + raise ValueError("n_workers must be >= 1") + if config_path is not None: + config = WorkerConfig.parse_file(config_path) + else: + config = WorkerConfig() + backend.run(app, n_workers=n_workers, config=config) diff --git a/neo4j-app/neo4j_app/icij_worker/backend/mp.py b/neo4j-app/neo4j_app/icij_worker/backend/mp.py new file mode 100644 index 00000000..4a0298c9 --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/backend/mp.py @@ -0,0 +1,112 @@ +import functools +import logging +import multiprocessing +import os +import signal +import sys +from contextlib import contextmanager +from typing import Dict, Optional + +from neo4j_app.icij_worker import AsyncApp, Worker, WorkerConfig + +logger = logging.getLogger(__name__) + +_HANDLED_SIGNALS = [signal.SIGTERM, signal.SIGINT] +if sys.platform == "win32": + _HANDLED_SIGNALS += [signal.CTRL_C_EVENT, signal.CTRL_BREAK_EVENT] + + +def _mp_work_forever( + app: str, + config: WorkerConfig, + worker_id: str, + *, + worker_extras: Optional[Dict] = None, + app_deps_extras: Optional[Dict] = None, +): + if app_deps_extras is None: + app_deps_extras = dict() + if worker_extras is None: + worker_extras = dict() + # For multiprocessing, lifespan dependencies need to be run once per process + app = AsyncApp.load(app) + deps_cm = app.lifetime_dependencies(worker_id=worker_id, **app_deps_extras) + worker = Worker.from_config(config, app=app, worker_id=worker_id, **worker_extras) + # This is ugly, but we have to work around the fact that we can't use asyncio code + # here + worker.loop.run_until_complete( + deps_cm.__aenter__() # pylint: disable=unnecessary-dunder-call + ) + try: + worker.work_forever() + finally: + worker.info("worker stopped working, tearing down %s dependencies", app.name) + worker.loop.run_until_complete(deps_cm.__aexit__(*sys.exc_info())) + + +def signal_handler(sig_num, *_, pool: multiprocessing.Pool): + logger.error( + "received %s, triggering process pool worker shutdown !", + signal.Signals(sig_num).name, + ) + logger.info("Sending termination signal to workers (SIGTERM)...") + pool.terminate() + pool.join() + + +def setup_main_process_signal_handlers(pool: multiprocessing.Pool): + handler = functools.partial(signal_handler, pool=pool) + for s in _HANDLED_SIGNALS: + signal.signal(s, handler) + + +@contextmanager +def run_workers_with_multiprocessing( + app: str, + n_workers: int, + config: WorkerConfig, + *, + handle_signals: bool = True, + worker_extras: Optional[Dict] = None, + app_deps_extras: Optional[Dict] = None, +): + logger.info("Creating multiprocessing worker pool with %s workers", n_workers) + # Here we set maxtasksperchild to 1. Each worker has a single never ending task + # which consists in working forever. 
Additionally, in some cases using + # maxtasksperchild=1 seems to help to terminate the worker pool + # (cpython bug: https://github.com/python/cpython/pull/8009) + mp_ctx = multiprocessing.get_context("spawn") + main_process_id = os.getpid() + # TODO: make this a bit more informative by, for instance, adding the child process ID + worker_ids = [f"worker-{main_process_id}-{i}" for i in range(n_workers)] + kwds = {"app": app, "config": config} + kwds["worker_extras"] = worker_extras + kwds["app_deps_extras"] = app_deps_extras + pool = mp_ctx.Pool(n_workers, maxtasksperchild=1) + logger.debug("Setting up signal handlers...") + tasks = [] + if handle_signals: + setup_main_process_signal_handlers(pool) + try: + for w_id in worker_ids: + kwds.update({"worker_id": w_id}) + logger.info("starting worker %s", w_id) + tasks.append(pool.apply_async(_mp_work_forever, kwds=kwds)) + yield + except KeyboardInterrupt as e: + if not handle_signals: + logger.info( + "received %s, triggering process pool worker shutdown !", + KeyboardInterrupt.__name__, + ) + else: + msg = ( + f"Received {KeyboardInterrupt.__name__} while SIGINT was expected to" + f" be handled" + ) + raise RuntimeError(msg) from e + finally: + logger.info("Sending termination signal to workers (SIGTERM)...") + pool.terminate() + pool.join() # Wait forever + logger.info("Terminated worker pool !") diff --git a/neo4j-app/neo4j_app/icij_worker/exceptions.py b/neo4j-app/neo4j_app/icij_worker/exceptions.py index e01c0f86..7d073e00 100644 --- a/neo4j-app/neo4j_app/icij_worker/exceptions.py +++ b/neo4j-app/neo4j_app/icij_worker/exceptions.py @@ -1,8 +1,12 @@ -import abc +from abc import ABC from typing import Optional, Sequence -class ICIJWorkerError(metaclass=abc.ABCMeta): +class ICIJWorkerError(ABC): + ... + + +class UnknownApp(ICIJWorkerError, ValueError, ABC): ... diff --git a/neo4j-app/neo4j_app/icij_worker/typing_.py b/neo4j-app/neo4j_app/icij_worker/typing_.py new file mode 100644 index 00000000..8b3df785 --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/typing_.py @@ -0,0 +1,33 @@ +from types import TracebackType +from typing import Callable, Coroutine, Optional, Protocol, Tuple, Type, Union + +DependencyLabel = Optional[str] +DependencySetup = Callable[..., None] +DependencyAsyncSetup = Callable[..., Coroutine[None, None, None]] + + +class DependencyTeardown(Protocol): + def __call__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + traceback: Optional[TracebackType], + ) -> None: + ... + + +class DependencyAsyncTeardown(Protocol): + async def __call__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + traceback: Optional[TracebackType], + ) -> None: + ... 
+ + +Dependency = Tuple[ + DependencyLabel, + Union[DependencySetup, DependencyAsyncSetup], + Optional[Union[DependencyTeardown, DependencyAsyncTeardown]], +] diff --git a/neo4j-app/neo4j_app/icij_worker/utils/__init__.py b/neo4j-app/neo4j_app/icij_worker/utils/__init__.py index e69de29b..c335d998 100644 --- a/neo4j-app/neo4j_app/icij_worker/utils/__init__.py +++ b/neo4j-app/neo4j_app/icij_worker/utils/__init__.py @@ -0,0 +1,3 @@ +from .dependencies import run_deps +from .from_config import FromConfig +from .registrable import Registrable, RegistrableConfig diff --git a/neo4j-app/neo4j_app/icij_worker/utils/dependencies.py b/neo4j-app/neo4j_app/icij_worker/utils/dependencies.py new file mode 100644 index 00000000..72017a68 --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/utils/dependencies.py @@ -0,0 +1,78 @@ +import inspect +import logging +import sys +import traceback +from contextlib import asynccontextmanager, contextmanager +from typing import AsyncGenerator, List + +from neo4j_app.icij_worker.typing_ import Dependency + +logger = logging.getLogger(__name__) + + +class DependencyInjectionError(RuntimeError): + def __init__(self, name: str): + msg = f"{name} was not injected" + super().__init__(msg) + + +@contextmanager +def _log_exception_and_continue(): + try: + yield + except Exception as exc: + from neo4j_app.app.utils import INTERNAL_SERVER_ERROR + + title = INTERNAL_SERVER_ERROR + detail = f"{type(exc).__name__}: {exc}" + trace = "".join(traceback.format_exc()) + logger.error("%s\nDetail: %s\nTrace: %s", title, detail, trace) + + raise exc + + +@asynccontextmanager +async def run_deps( + dependencies: List[Dependency], ctx: str, **kwargs +) -> AsyncGenerator[None, None]: + to_close = [] + original_ex = None + try: + with _log_exception_and_continue(): + logger.info("Setting up dependencies for %s...", ctx) + for name, enter_fn, exit_fn in dependencies: + if enter_fn is not None: + if name is not None: + logger.debug("applying: %s", name) + if inspect.iscoroutinefunction(enter_fn): + await enter_fn(**kwargs) + else: + enter_fn(**kwargs) + to_close.append((name, exit_fn)) + yield + except Exception as e: # pylint: disable=broad-exception-caught + original_ex = e + finally: + to_raise = [] + if original_ex is not None: + to_raise.append(original_ex) + logger.info("Rolling back dependencies for %s...", ctx) + for name, exit_fn in to_close[::-1]: + if exit_fn is None: + continue + try: + if name is not None: + logger.debug("rolling back %s", name) + exc_info = sys.exc_info() + with _log_exception_and_continue(): + if inspect.iscoroutinefunction(exit_fn): + await exit_fn(*exc_info) + else: + exit_fn(*exc_info) + except Exception as e: # pylint: disable=broad-exception-caught + to_raise.append(e) + logger.debug("Rolled back all dependencies for %s!", ctx) + if to_raise: + for e in to_raise: + logger.error("Error while handling dependencies %s!", e) + raise RuntimeError(to_raise) diff --git a/neo4j-app/neo4j_app/icij_worker/worker/__init__.py b/neo4j-app/neo4j_app/icij_worker/worker/__init__.py index 6e81fa19..5d6838e5 100644 --- a/neo4j-app/neo4j_app/icij_worker/worker/__init__.py +++ b/neo4j-app/neo4j_app/icij_worker/worker/__init__.py @@ -1,4 +1,11 @@ -from .worker import Worker -from .neo4j import Neo4jAsyncWorker +from enum import Enum, unique + from .config import WorkerConfig -from .process import ProcessWorkerMixin +from .worker import Worker + + +@unique +class WorkerType(str, Enum): + # pylint: disable=invalid-name + mock = "mock" + neo4j = "neo4j" diff --git 
a/neo4j-app/neo4j_app/icij_worker/worker/config.py b/neo4j-app/neo4j_app/icij_worker/worker/config.py index 82bc3bda..0004391b 100644 --- a/neo4j-app/neo4j_app/icij_worker/worker/config.py +++ b/neo4j-app/neo4j_app/icij_worker/worker/config.py @@ -1,15 +1,14 @@ -from abc import ABC - from pydantic import Field from neo4j_app.icij_worker.utils.registrable import RegistrableConfig -class WorkerConfig(RegistrableConfig, ABC): +class WorkerConfig(RegistrableConfig): registry_key: str = Field(const=True, default="type") + cancelled_tasks_refresh_interval_s: int = 2 + task_queue_poll_interval_s: int = 1 log_level: str = "INFO" type: str class Config: env_prefix = "ICIJ_WORKER_" - case_sensitive = False diff --git a/neo4j-app/neo4j_app/icij_worker/worker/neo4j.py b/neo4j-app/neo4j_app/icij_worker/worker/neo4j.py index c5174c7f..38f3797b 100644 --- a/neo4j-app/neo4j_app/icij_worker/worker/neo4j.py +++ b/neo4j-app/neo4j_app/icij_worker/worker/neo4j.py @@ -1,14 +1,15 @@ +from __future__ import annotations + import asyncio import json -import logging from contextlib import asynccontextmanager from datetime import datetime -from functools import cached_property from typing import AsyncGenerator, Dict, List, Optional, Tuple import neo4j from fastapi.encoders import jsonable_encoder from neo4j.exceptions import ConstraintError, ResultNotSingleError +from pydantic import Field from neo4j_app.constants import ( TASK_ERROR_NODE, @@ -27,38 +28,67 @@ from neo4j_app.core.neo4j.migrations.migrate import retrieve_projects from neo4j_app.core.neo4j.projects import project_db_session from neo4j_app.icij_worker import ( - ICIJApp, + AsyncApp, Task, TaskError, TaskResult, TaskStatus, + Worker, + WorkerConfig, + WorkerType, ) from neo4j_app.icij_worker.event_publisher.neo4j import Neo4jEventPublisher from neo4j_app.icij_worker.exceptions import TaskAlreadyReserved, UnknownTask -from neo4j_app.icij_worker.worker.process import ProcessWorkerMixin _TASK_MANDATORY_FIELDS_BY_ALIAS = { f for f in Task.schema(by_alias=True)["required"] if f != "id" } -class Neo4jAsyncWorker(ProcessWorkerMixin, Neo4jEventPublisher): +class Neo4jWorkerConfig(WorkerConfig): + type: str = Field(const=True, default=WorkerType.neo4j) + + neo4j_connection_timeout: float = 5.0 + neo4j_host: str = "127.0.0.1" + neo4j_password: Optional[str] = None + neo4j_port: int = 7687 + neo4j_uri_scheme: str = "bolt" + neo4j_user: Optional[str] = None + + @property + def neo4j_uri(self) -> str: + return f"{self.neo4j_uri_scheme}://{self.neo4j_host}:{self.neo4j_port}" + + def to_neo4j_driver(self) -> neo4j.AsyncDriver: + auth = None + if self.neo4j_password: + # TODO: add support for expiring and auto renew auth: + # https://neo4j.com/docs/api/python-driver/current/api.html + # #neo4j.auth_management.AuthManagers.expiration_based + auth = neo4j.basic_auth(self.neo4j_user, self.neo4j_password) + driver = neo4j.AsyncGraphDatabase.driver( + self.neo4j_uri, + connection_timeout=self.neo4j_connection_timeout, + connection_acquisition_timeout=self.neo4j_connection_timeout, + max_transaction_retry_time=self.neo4j_connection_timeout, + auth=auth, + ) + return driver + + +@Worker.register(WorkerType.neo4j) +class Neo4jWorker(Worker, Neo4jEventPublisher): def __init__( - self, - app: ICIJApp, - worker_id: str, - driver: Optional[neo4j.AsyncDriver] = None, - logger: Optional[logging.Logger] = None, + self, app: AsyncApp, worker_id: str, driver: neo4j.AsyncDriver, **kwargs ): - super().__init__(app, worker_id) - self._inherited_driver = False - if driver is None: - 
self._inherited_driver = True
-            driver = app.config.to_neo4j_driver()
-        Neo4jEventPublisher.__init__(self, driver)
-        if logger is None:
-            logger = logging.getLogger(__name__)
-        self._logger_ = logger
+        super().__init__(app, worker_id, **kwargs)
+        self._driver = driver
+
+    @classmethod
+    def _from_config(cls, config: Neo4jWorkerConfig, **extras) -> Neo4jWorker:
+        worker = cls(driver=config.to_neo4j_driver(), **extras)
+        worker.set_config(config)
+        return worker
 
     async def _consume(self) -> Tuple[Task, str]:
         projects = []
@@ -75,7 +105,7 @@ async def _consume(self) -> Tuple[Task, str]:
             )
             if received is not None:
                 return received, p.name
-            await asyncio.sleep(self._app.config.neo4j_app_task_queue_poll_interval_s)
+            await asyncio.sleep(self.config.task_queue_poll_interval_s)
             refresh_projects_i += 1
 
     async def _negatively_acknowledge(
@@ -132,19 +162,8 @@ async def _project_session(
         async with project_db_session(self._driver, project) as sess:
             yield sess
 
-    @cached_property
-    def logged_named(self) -> str:
-        from neo4j_app.icij_worker import Worker
-
-        return Worker.logged_name(self)
-
-    @property
-    def _logger(self) -> logging.Logger:
-        return self._logger_
-
     async def _aexit__(self, exc_type, exc_val, exc_tb):
-        if not self._inherited_driver:
-            await self._driver.__aexit__(exc_type, exc_val, exc_tb)
+        await self._driver.__aexit__(exc_type, exc_val, exc_tb)
 
 
 async def _consume_task_tx(
diff --git a/neo4j-app/neo4j_app/icij_worker/worker/process.py b/neo4j-app/neo4j_app/icij_worker/worker/process.py
index d689116b..b84b1cb0 100644
--- a/neo4j-app/neo4j_app/icij_worker/worker/process.py
+++ b/neo4j-app/neo4j_app/icij_worker/worker/process.py
@@ -1,22 +1,29 @@
+import asyncio
 import functools
+import logging
 import signal
-import sys
 from abc import ABC
+from asyncio import AbstractEventLoop
+from typing import Optional
 
-from neo4j_app.icij_worker.worker.worker import Worker
+from neo4j_app.core.utils.logging import LogWithNameMixin
 
-_HANDLE_SIGNALS = [
-    signal.SIGINT,
-    signal.SIGTERM,
-]
-if sys.platform == "win32":
-    _HANDLE_SIGNALS += [signal.CTRL_C_EVENT, signal.CTRL_BREAK_EVENT]
+_HANDLE_SIGNALS = [signal.SIGTERM]
 
 
-class ProcessWorkerMixin(Worker, ABC):
+# TODO: rename this file to signals
+class HandleSignalsMixin(LogWithNameMixin, ABC):
+    _work_forever_task: Optional[asyncio.Task]
+    _loop: AbstractEventLoop
+
+    def __init__(self, logger: logging.Logger, handle_signals: bool = True):
+        super().__init__(logger)
+        self._handle_signals = handle_signals
+
     async def _aenter__(self):
-        await super()._aenter__()
-        self._setup_signal_handlers()
+        # TODO: define this one on the worker side
+        if self._handle_signals:
+            self._setup_child_process_signal_handlers()
 
     def _signal_handler(self, signal_name: signal.Signals, *, graceful: bool):
         self.error("received %s", signal_name)
@@ -25,9 +32,12 @@ def _signal_handler(self, signal_name: signal.Signals, *, graceful: bool):
             self.info("cancelling worker loop...")
             self._work_forever_task.cancel()
 
-    def _setup_signal_handlers(self):
-        # Let's always shutdown gracefully for now since when the server shutdown
-        # it will try to SIGTERM, we want to avoid loosing track of running tasks
+    def _setup_child_process_signal_handlers(self):
+        # We ignore SIGINT (graceful shutdown): this signal is handled by the process
+        # managing the pool, which will terminate the pool and send a SIGTERM, which
+        # is handled here
+
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
         for s in _HANDLE_SIGNALS:
             handler = functools.partial(self._signal_handler, s, graceful=True)
             self._loop.add_signal_handler(s, handler)
diff --git a/neo4j-app/neo4j_app/icij_worker/worker/worker.py b/neo4j-app/neo4j_app/icij_worker/worker/worker.py
index acf0522b..5fa5e107 100644
--- a/neo4j-app/neo4j_app/icij_worker/worker/worker.py
+++ b/neo4j-app/neo4j_app/icij_worker/worker/worker.py
@@ -5,7 +5,7 @@
 import inspect
 import logging
 import traceback
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from collections import defaultdict
 from contextlib import AbstractAsyncContextManager, asynccontextmanager
 from copy import deepcopy
@@ -19,13 +19,12 @@
     Optional,
     Tuple,
     Type,
+    TypeVar,
     final,
 )
 
-from neo4j_app.core import AppConfig
-from neo4j_app.core.utils.logging import LogWithNameMixin
 from neo4j_app.core.utils.progress import CheckCancelledProgress
-from neo4j_app.icij_worker.app import ICIJApp, RegisteredTask
+from neo4j_app.icij_worker.app import AsyncApp, RegisteredTask
 from neo4j_app.icij_worker.event_publisher import EventPublisher
 from neo4j_app.icij_worker.exceptions import (
     MaxRetriesExceeded,
@@ -41,27 +40,54 @@
     TaskResult,
     TaskStatus,
 )
+from neo4j_app.icij_worker.utils.registrable import Registrable
+from neo4j_app.icij_worker.worker.process import HandleSignalsMixin
 from neo4j_app.typing_ import PercentProgress
 
 logger = logging.getLogger(__name__)
 
 PROGRESS_HANDLER_ARG = "progress"
 
+C = TypeVar("C", bound="WorkerConfig")
 
-class Worker(EventPublisher, LogWithNameMixin, AbstractAsyncContextManager, ABC):
-    def __init__(self, app: ICIJApp, worker_id: str):
-        if app.config is None:
-            raise ValueError("worker requires a configured app, app config is missing")
+
+class Worker(
+    EventPublisher,
+    Registrable,
+    HandleSignalsMixin,
+    AbstractAsyncContextManager,
+):
+    def __init__(
+        self,
+        app: AsyncApp,
+        worker_id: str,
+        handle_signals: bool = True,
+        teardown_dependencies: bool = False,
+    ):
+        # If workers are run using a thread backend, signal handling might not be
+        # required; in that case the signal handling mixin will just do nothing
+        HandleSignalsMixin.__init__(self, logger, handle_signals=handle_signals)
         self._app = app
         self._id = worker_id
+        self._teardown_dependencies = teardown_dependencies
         self._graceful_shutdown = True
         self._loop = asyncio.get_event_loop()
         self._work_forever_task: Optional[asyncio.Task] = None
         self._already_exiting = False
-        self._config = app.config
-        self._cancelled_ = defaultdict(set)
-        self.__deps_cm = None
         self._current = None
+        self._cancelled_ = defaultdict(set)
+        self._config: Optional[C] = None
+
+    def set_config(self, config: C):
+        self._config = config
+
+    def _to_config(self) -> C:
+        if self._config is None:
+            raise ValueError(
+                "worker was initialized using from_config, "
+                "but the config was not attached using set_config"
+            )
+        return self._config
 
     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -75,21 +101,6 @@ def _cancelled(self) -> List[str]:
     def id(self) -> str:
         return self._id
 
-    @classmethod
-    def from_config(cls, config: AppConfig, worker_id: str, **kwargs) -> Worker:
-        worker_cls = config.to_worker_cls()
-        return worker_cls(app=config.to_async_app(), worker_id=worker_id, **kwargs)
-
-    @classmethod
-    @final
-    def work_forever_from_config(cls, config: AppConfig, worker_id: str, **kwargs):
-        """
-        Convenience function to ease multiprocessing serialization and avoid pickle
-        errors
-        """
-        worker = cls.from_config(config, worker_id, **kwargs)
-        worker.work_forever()
-
     @final
     def work_forever(self):
         with self:  # The graceful shutdown happens here
@@ -113,11 +124,6 @@ async def _work_forever(self):
         while True:
             await self._work_once()
 
-    @final
-    @functools.cached_property
-    def config(self) -> AppConfig:
-        return self._config
-
     @final
     def logged_name(self) -> str:
         return self.id
@@ -277,7 +283,7 @@ def parse_task(
     @final
     @functools.cached_property
     def _cancelled_task_refresh_interval_s(self) -> int:
-        return self._app.config.neo4j_app_cancelled_task_refresh_interval_s
+        return self.config.cancelled_tasks_refresh_interval_s
 
     @final
     async def check_cancelled(
@@ -315,27 +321,12 @@ def _make_progress(self, task: Task, project: str) -> PercentProgress:
         )
         return progress
 
-    @final
-    @asynccontextmanager
-    async def _deps_cm(self):
-        if self._config is not None:
-            from neo4j_app.app.dependencies import run_deps
-
-            async with run_deps(
-                self.config.to_async_deps(), config=self.config, worker_id=self.id
-            ):
-                yield
-        else:
-            yield
-
     @final
     def __enter__(self):
         self._loop.run_until_complete(self.__aenter__())
 
     @final
     async def __aenter__(self):
-        self.__deps_cm = self._deps_cm()
-        await self.__deps_cm.__aenter__()
         await self._aenter__()
 
     async def _aenter__(self):
@@ -352,12 +343,11 @@ async def __aexit__(self, exc_type, exc_value, tb):
             self._already_exiting = True
             # Let's try to shut down gracefully
             await self.shutdown()
-            # Then call any extra context manager exit which is not a global
-            # dependency
-            await self._aexit__(exc_type, exc_value, tb)
-            # Then close global dependencies
-            self.info("closing dependencies...")
-            await self.__deps_cm.__aexit__(None, None, None)
+            # Clean up worker dependencies only if needed; dependencies might be
+            # shared, in which case we don't want to tear them down
+            if self._teardown_dependencies:
+                self.info("cleaning worker dependencies...")
+                await self._aexit__(exc_type, exc_value, tb)
 
     async def _aexit__(self, exc_type, exc_val, exc_tb):
         pass
@@ -384,7 +374,7 @@ async def shutdown(self):
 
 def _retrieve_registered_task(
     task: Task,
-    app: ICIJApp,
+    app: AsyncApp,
 ) -> RegisteredTask:
     registered = app.registry.get(task.type)
     if registered is None:
diff --git a/neo4j-app/neo4j_app/run/run.py b/neo4j-app/neo4j_app/run/run.py
index 197fc88a..d9845958 100644
--- a/neo4j-app/neo4j_app/run/run.py
+++ b/neo4j-app/neo4j_app/run/run.py
@@ -11,13 +11,13 @@
 from gunicorn.app.base import BaseApplication
 
 import neo4j_app
+from neo4j_app.app import ServiceConfig
 from neo4j_app.app.utils import create_app
-from neo4j_app.core.config import AppConfig
 from neo4j_app.core.utils.logging import DATE_FMT, STREAM_HANDLER_FMT
 
 
 def debug_app():
-    config = AppConfig()
+    config = ServiceConfig()
     app = create_app(config)
     return app
 
@@ -32,7 +32,7 @@ def _start_app_(ns):
 
 
 class GunicornApp(BaseApplication):  # pylint: disable=abstract-method
-    def __init__(self, app: FastAPI, config: AppConfig, **kwargs):
+    def __init__(self, app: FastAPI, config: ServiceConfig, **kwargs):
         self.application = app
         self._app_config = config
         super().__init__(**kwargs)
@@ -47,7 +47,7 @@ def load(self):
         return self.application
 
     @classmethod
-    def from_config(cls, config: AppConfig) -> GunicornApp:
+    def from_config(cls, config: ServiceConfig) -> GunicornApp:
         fast_api = create_app(config)
         return cls(fast_api, config)
 
@@ -58,11 +58,11 @@ def _start_app(config_path: Optional[str] = None, force_migrations: bool = False
     if not config_path.exists():
         raise ValueError(f"Provided config path does not exists: {config_path}")
     with config_path.open() as f:
-        config = AppConfig.from_java_properties(
+        config = ServiceConfig.from_java_properties(
            f,
force_migrations=force_migrations ) else: - config = AppConfig() + config = ServiceConfig() app = GunicornApp.from_config(config) app.run() diff --git a/neo4j-app/neo4j_app/tasks/__init__.py b/neo4j-app/neo4j_app/tasks/__init__.py index ffbd73cc..8f7111df 100644 --- a/neo4j-app/neo4j_app/tasks/__init__.py +++ b/neo4j-app/neo4j_app/tasks/__init__.py @@ -1,2 +1,2 @@ -from .app import app, WORKER_LIFESPAN_DEPS +from .app import WORKER_LIFESPAN_DEPS, app from .imports import * diff --git a/neo4j-app/neo4j_app/tasks/app.py b/neo4j-app/neo4j_app/tasks/app.py index e3320b50..cc678ec6 100644 --- a/neo4j-app/neo4j_app/tasks/app.py +++ b/neo4j-app/neo4j_app/tasks/app.py @@ -1,27 +1,9 @@ import logging -from neo4j_app.app.dependencies import ( - config_enter, - es_client_enter, - es_client_exit, - neo4j_driver_enter, - neo4j_driver_exit, -) -from neo4j_app.core import AppConfig -from neo4j_app.icij_worker import ICIJApp +from neo4j_app.icij_worker import AsyncApp +from neo4j_app.tasks.dependencies import WORKER_LIFESPAN_DEPS logger = logging.getLogger(__name__) -app = ICIJApp(name="neo4j-app") -def loggers_enter(config: AppConfig, worker_id: str): - config.setup_loggers(worker_id=worker_id) - logger.info("worker loggers ready to log 💬") - - -WORKER_LIFESPAN_DEPS = [ - ("configuration loading", config_enter, None), - ("loggers setup", loggers_enter, None), - ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), - ("ES client creation", es_client_enter, es_client_exit), -] +app = AsyncApp(name="neo4j-app", dependencies=WORKER_LIFESPAN_DEPS) diff --git a/neo4j-app/neo4j_app/tasks/dependencies.py b/neo4j-app/neo4j_app/tasks/dependencies.py new file mode 100644 index 00000000..af3d867a --- /dev/null +++ b/neo4j-app/neo4j_app/tasks/dependencies.py @@ -0,0 +1,130 @@ +import logging +from pathlib import Path +from typing import Optional, cast + +import neo4j + +from neo4j_app.core.elasticsearch import ESClientABC +from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schemas +from neo4j_app.core.neo4j.migrations import delete_all_migrations +from neo4j_app.core.neo4j.projects import create_project_registry_db +from neo4j_app.icij_worker.utils.dependencies import DependencyInjectionError +from neo4j_app.config import AppConfig + +logger = logging.getLogger(__name__) + +_CONFIG: Optional[AppConfig] = None +_ASYNC_APP_CONFIG: Optional[AppConfig] = None +_ES_CLIENT: Optional[ESClientABC] = None +_ASYNC_APP_CONFIG_PATH: Optional[Path] = None +_NEO4J_DRIVER: Optional[neo4j.AsyncDriver] = None + + +def config_enter(config: AppConfig, **_): + global _CONFIG + _CONFIG = config + logger.info("Loaded config %s", config.json(indent=2)) + + +async def config_from_path_enter(config_path: Path, **_): + global _CONFIG + with config_path.open() as f: + config = AppConfig.from_java_properties(f) + config = await config.with_neo4j_support() + _CONFIG = config + logger.info("Loaded config %s", config.json(indent=2)) + + +async def config_neo4j_support_enter(**_): + global _CONFIG + config = lifespan_config() + _CONFIG = await config.with_neo4j_support() + + +def lifespan_config() -> AppConfig: + if _CONFIG is None: + raise DependencyInjectionError("config") + return _CONFIG + + +def loggers_enter(worker_id: str, **_): + config = lifespan_config() + config.setup_loggers(worker_id=worker_id) + logger.info("worker loggers ready to log 💬") + + +async def neo4j_driver_enter(**__): + global _NEO4J_DRIVER + _NEO4J_DRIVER = lifespan_config().to_neo4j_driver() + await _NEO4J_DRIVER.__aenter__() # pylint: 
disable=unnecessary-dunder-call + + logger.debug("pinging neo4j...") + async with _NEO4J_DRIVER.session(database=neo4j.SYSTEM_DATABASE) as sess: + await sess.run("CALL db.ping()") + logger.debug("neo4j driver is ready") + + +async def neo4j_driver_exit(exc_type, exc_value, trace): + already_closed = False + try: + await _NEO4J_DRIVER.verify_connectivity() + except: # pylint: disable=bare-except + already_closed = True + if not already_closed: + await _NEO4J_DRIVER.__aexit__(exc_type, exc_value, trace) + + +def lifespan_neo4j_driver() -> neo4j.AsyncDriver: + if _NEO4J_DRIVER is None: + raise DependencyInjectionError("neo4j driver") + return cast(neo4j.AsyncDriver, _NEO4J_DRIVER) + + +async def es_client_enter(**_): + global _ES_CLIENT + _ES_CLIENT = lifespan_config().to_es_client() + await _ES_CLIENT.__aenter__() # pylint: disable=unnecessary-dunder-call + + +async def es_client_exit(exc_type, exc_value, trace): + await _ES_CLIENT.__aexit__(exc_type, exc_value, trace) + + +def lifespan_es_client() -> ESClientABC: + if _ES_CLIENT is None: + raise DependencyInjectionError("es client") + return cast(ESClientABC, _ES_CLIENT) + + +async def create_project_registry_db_enter(**_): + driver = lifespan_neo4j_driver() + await create_project_registry_db(driver) + + +async def migrate_app_db_enter(**_): + logger.info("Running schema migrations...") + config = lifespan_config() + driver = lifespan_neo4j_driver() + if config.force_migrations: + # TODO: improve this as is could lead to race conditions... + logger.info("Deleting all previous migrations...") + await delete_all_migrations(driver) + await migrate_db_schemas( + driver, + registry=MIGRATIONS, + timeout_s=config.neo4j_app_migration_timeout_s, + throttle_s=config.neo4j_app_migration_throttle_s, + ) + + +WORKER_LIFESPAN_DEPS = [ + ("configuration loading", config_from_path_enter, None), + ("loggers setup", loggers_enter, None), + ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), + # This has to be done after the neo4j driver creation, once we know we can reach + # the neo4j server + ("add configuration neo4j support", config_neo4j_support_enter, None), + ("neo4j project registry creation", create_project_registry_db_enter, None), + ("neo4j DB migration", migrate_app_db_enter, None), + ("ES client creation", es_client_enter, es_client_exit), +] diff --git a/neo4j-app/neo4j_app/tasks/imports.py b/neo4j-app/neo4j_app/tasks/imports.py index af52b079..a8dd0dce 100644 --- a/neo4j-app/neo4j_app/tasks/imports.py +++ b/neo4j-app/neo4j_app/tasks/imports.py @@ -1,10 +1,5 @@ import logging -from neo4j_app.app.dependencies import ( - lifespan_config, - lifespan_es_client, - lifespan_neo4j_driver, -) from neo4j_app.core.imports import import_documents, import_named_entities from neo4j_app.core.objects import IncrementalImportResponse from neo4j_app.core.utils.logging import log_elapsed_time_cm @@ -12,6 +7,7 @@ from neo4j_app.core.utils.pydantic import LowerCamelCaseModel from neo4j_app.typing_ import PercentProgress from .app import app +from .dependencies import lifespan_config, lifespan_es_client, lifespan_neo4j_driver logger = logging.getLogger(__name__) diff --git a/neo4j-app/neo4j_app/tests/core/test_config.py b/neo4j-app/neo4j_app/tests/app/test_config.py similarity index 63% rename from neo4j-app/neo4j_app/tests/core/test_config.py rename to neo4j-app/neo4j_app/tests/app/test_config.py index da23d360..7b63f747 100644 --- a/neo4j-app/neo4j_app/tests/core/test_config.py +++ b/neo4j-app/neo4j_app/tests/app/test_config.py @@ -4,63 +4,61 @@ 
import pytest from pydantic import ValidationError -from neo4j_app.core import AppConfig +from neo4j_app.app import ServiceConfig from neo4j_app.tests.conftest import fail_if_exception def test_should_support_alias(): # When neo4j_app_name = "test_name" - config = AppConfig(neo4j_app_name=neo4j_app_name) + config = ServiceConfig(neo4j_app_name=neo4j_app_name) # Then assert config.neo4j_app_name == neo4j_app_name @pytest.mark.parametrize( - "config,expected_config", + "config_as_str,expected_config,expected_written_config", [ + ("someExtraInfo=useless", ServiceConfig(), ""), ( - """neo4jProject=test-project -neo4jImportDir=import-dir -""", - AppConfig( - neo4j_app_host="127.0.0.1", - neo4j_app_port=8080, - elasticsearch_address="http://127.0.0.1:9200", - ), - ), - ( - """neo4jProject=test-project -neo4jImportDir=import-dir + """elasticsearchAddress=http://elasticsearch:9222 neo4jAppHost=this-the-neo4j-app -neo4jAppPort=3333 -elasticsearchAddress=http://elasticsearch:9222 -someExtraInfo=useless -""", - AppConfig( +neo4jAppPort=3333""", + ServiceConfig( neo4j_app_host="this-the-neo4j-app", neo4j_app_port=3333, elasticsearch_address="http://elasticsearch:9222", ), + """elasticsearchAddress=http://elasticsearch:9222 +neo4jAppHost=this-the-neo4j-app +neo4jAppPort=3333 + +""", ), ], ) -def test_should_load_from_java(config: str, expected_config: AppConfig): +def test_should_load_from_java_and_write_to_java( + config_as_str: str, expected_config: ServiceConfig, expected_written_config: str +): # Given - config_stream = io.StringIO(config) + config_stream = io.StringIO(config_as_str) # When - loaded_config = AppConfig.from_java_properties(config_stream) + loaded_config = ServiceConfig.from_java_properties(config_stream) + config_io = io.StringIO() + loaded_config.write_java_properties(config_io) + written = config_io.getvalue() # Then assert loaded_config == expected_config + assert written == expected_written_config @pytest.mark.pull(id="62") def test_should_support_address_without_port(): # Given - config = AppConfig(elasticsearch_address="http://elasticsearch") + config = ServiceConfig(elasticsearch_address="http://elasticsearch") # Then with fail_if_exception("Failed to initialize ES client"): config.to_es_client() @@ -70,7 +68,7 @@ def test_should_support_address_without_port(): def test_should_forward_page_size_to_client(): # Given es_default_page_size = 666 - config = AppConfig( + config = ServiceConfig( elasticsearch_address="http://elasticsearch", es_default_page_size=es_default_page_size, ) @@ -87,7 +85,7 @@ def test_should_raise_for_missing_auth_part( # When/Then expected_msg = "neo4j authentication is missing user or password" with pytest.raises(ValidationError, match=expected_msg): - AppConfig( + ServiceConfig( elasticsearch_address="http://elasticsearch:9222", neo4j_user=user, neo4j_password=password, diff --git a/neo4j-app/neo4j_app/tests/app/test_main.py b/neo4j-app/neo4j_app/tests/app/test_main.py index 45177842..4f9a74e7 100644 --- a/neo4j-app/neo4j_app/tests/app/test_main.py +++ b/neo4j-app/neo4j_app/tests/app/test_main.py @@ -1,7 +1,7 @@ from starlette.testclient import TestClient from neo4j_app import ROOT_DIR -from neo4j_app.core import AppConfig +from neo4j_app.app import ServiceConfig try: import tomllib @@ -29,7 +29,7 @@ def test_config(test_client: TestClient): # Then assert res.status_code == 200, res.json() - config = AppConfig.parse_obj(res.json()) + config = ServiceConfig.parse_obj(res.json()) assert isinstance(config.supports_neo4j_enterprise, bool) diff --git 
a/neo4j-app/neo4j_app/tests/app/test_tasks.py b/neo4j-app/neo4j_app/tests/app/test_tasks.py index 752a5a38..c4bb22f8 100644 --- a/neo4j-app/neo4j_app/tests/app/test_tasks.py +++ b/neo4j-app/neo4j_app/tests/app/test_tasks.py @@ -8,27 +8,29 @@ from starlette.testclient import TestClient from neo4j_app.app.utils import create_app -from neo4j_app.core import AppConfig -from neo4j_app.core.config import WorkerType +from neo4j_app.app.config import ServiceConfig, WorkerType from neo4j_app.core.objects import TaskJob from neo4j_app.core.utils.logging import DifferedLoggingMessage from neo4j_app.core.utils.pydantic import safe_copy -from neo4j_app.icij_worker import ICIJApp, Task, TaskStatus +from neo4j_app.icij_worker import AsyncApp, Task, TaskStatus from neo4j_app.tests.conftest import TEST_PROJECT, test_error_router, true_after @pytest.fixture(scope="function") def test_client_prod( - test_config: AppConfig, - test_async_app: ICIJApp, + test_config: ServiceConfig, + test_async_app: AsyncApp, # Wipe neo4j and init project neo4j_app_driver: neo4j.AsyncSession, ) -> TestClient: # pylint: disable=unused-argument config = safe_copy( - test_config, update={"neo4j_app_worker_type": WorkerType.NEO4J, "test": False} + test_config, update={"neo4j_app_worker_type": WorkerType.neo4j, "test": False} + ) + new_async_app = AsyncApp( + name=test_async_app.name, + dependencies=test_async_app._dependencies, # pylint: disable=protected-access ) - new_async_app = ICIJApp(name=test_async_app.name, config=config) new_async_app._registry = ( # pylint: disable=protected-access test_async_app.registry ) @@ -147,10 +149,13 @@ def test_cancel_task(test_client: TestClient): @pytest.fixture(scope="function") def test_client_limited_queue( - test_config: AppConfig, test_async_app: ICIJApp + test_config: ServiceConfig, test_async_app: AsyncApp ) -> TestClient: config = safe_copy(test_config, update={"neo4j_app_task_queue_size": 0}) - new_async_app = ICIJApp(name=test_async_app.name, config=config) + new_async_app = AsyncApp( + name=test_async_app.name, + dependencies=test_async_app._dependencies, # pylint: disable=protected-access + ) new_async_app._registry = ( # pylint: disable=protected-access test_async_app.registry ) diff --git a/neo4j-app/neo4j_app/tests/conftest.py b/neo4j-app/neo4j_app/tests/conftest.py index 7b1ac310..60245ac1 100644 --- a/neo4j-app/neo4j_app/tests/conftest.py +++ b/neo4j-app/neo4j_app/tests/conftest.py @@ -30,20 +30,19 @@ from starlette.testclient import TestClient import neo4j_app +from neo4j_app.app import ServiceConfig from neo4j_app.app.dependencies import ( config_enter, loggers_enter, ) from neo4j_app.app.utils import create_app -from neo4j_app.core import AppConfig -from neo4j_app.core.config import WorkerType from neo4j_app.core.elasticsearch import ESClient, ESClientABC from neo4j_app.core.elasticsearch.client import PointInTime from neo4j_app.core.neo4j import MIGRATIONS from neo4j_app.core.neo4j.migrations.migrate import init_project from neo4j_app.core.neo4j.projects import NEO4J_COMMUNITY_DB from neo4j_app.core.utils.pydantic import BaseICIJModel -from neo4j_app.icij_worker import ICIJApp +from neo4j_app.icij_worker import AsyncApp, WorkerType from neo4j_app.typing_ import PercentProgress # TODO: at a high level it's a waste to have to repeat code for each fixture level, @@ -51,7 +50,7 @@ # https://docs.pytest.org/en/6.2.x/fixture.html#dynamic-scope -APP = ICIJApp(name="test-app") +APP = AsyncApp(name="test-app") DATA_DIR = Path(__file__).parents[3].joinpath(".data") TEST_PROJECT = 
"test_project" @@ -142,15 +141,15 @@ def event_loop(): @pytest.fixture(scope="session") -def test_config() -> AppConfig: - config = AppConfig( +def test_config() -> ServiceConfig: + config = ServiceConfig( elasticsearch_address=f"http://127.0.0.1:{ELASTICSEARCH_TEST_PORT}", es_default_page_size=5, neo4j_app_host="127.0.0.1", neo4j_port=NEO4J_TEST_PORT, neo4j_user=NEO4J_TEST_USER, neo4j_password=NEO4J_TEST_PASSWORD, - neo4j_app_worker_type=WorkerType.MOCK, + neo4j_app_worker_type=WorkerType.mock, test=True, neo4j_app_async_app=f"{__name__}.APP", neo4j_app_async_dependencies=f"{__name__}.TEST_WORKER_DEPS", @@ -165,7 +164,7 @@ def test_config() -> AppConfig: @pytest.fixture(scope="session") -def test_app_session(test_config: AppConfig) -> FastAPI: +def test_app_session(test_config: ServiceConfig) -> FastAPI: return create_app(test_config) @@ -208,8 +207,8 @@ def test_client_with_async( es_test_client: ESClient, # Same for neo4j neo4j_test_session: neo4j.AsyncSession, - test_async_app: ICIJApp, - test_config: AppConfig, + test_async_app: AsyncApp, + test_config: ServiceConfig, ) -> Generator[TestClient, None, None]: # pylint: disable=unused-argument # pylint: disable=unused-argument @@ -652,8 +651,8 @@ async def sleep_for( @pytest.fixture(scope="session") -def test_async_app(test_config: AppConfig) -> ICIJApp: - return test_config.to_async_app() +def test_async_app(test_config: ServiceConfig) -> AsyncApp: + return AsyncApp.load(test_config.neo4j_app_async_app) @pytest.fixture() diff --git a/neo4j-app/neo4j_app/tests/icij_worker/conftest.py b/neo4j-app/neo4j_app/tests/icij_worker/conftest.py index 298b76b9..945b9702 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/conftest.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/conftest.py @@ -7,7 +7,6 @@ import threading from abc import ABC from datetime import datetime -from functools import cached_property from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -15,17 +14,22 @@ import pytest import pytest_asyncio from fastapi.encoders import jsonable_encoder +from pydantic import Field -from neo4j_app.core import AppConfig +from neo4j_app import AppConfig +from neo4j_app.app.dependencies import FASTAPI_LIFESPAN_DEPS from neo4j_app.core.utils.pydantic import safe_copy from neo4j_app.icij_worker import ( + AsyncApp, EventPublisher, - ICIJApp, Task, TaskError, TaskEvent, TaskResult, TaskStatus, + Worker, + WorkerConfig, + WorkerType, ) from neo4j_app.icij_worker.exceptions import ( TaskAlreadyExists, @@ -33,7 +37,6 @@ UnknownTask, ) from neo4j_app.icij_worker.task_manager import TaskManager -from neo4j_app.icij_worker.worker import ProcessWorkerMixin from neo4j_app.typing_ import PercentProgress @@ -236,23 +239,33 @@ def _get_db_task(self, db: Dict, task_id: str, project: str) -> Dict: raise UnknownTask(task_id) from e -class MockWorker(ProcessWorkerMixin, MockEventPublisher): +class MockWorkerConfig(WorkerConfig): + type: str = Field(const=True, default=WorkerType.mock) + db_path: Path + + +@Worker.register(WorkerType.mock) +class MockWorker(Worker, MockEventPublisher): def __init__( self, - app: ICIJApp, + app: AsyncApp, worker_id: str, db_path: Path, lock: Union[threading.Lock, multiprocessing.Lock], + **kwargs, ): - super().__init__(app, worker_id) + super().__init__(app, worker_id, **kwargs) MockEventPublisher.__init__(self, db_path, lock) self._worker_id = worker_id self._logger_ = logging.getLogger(__name__) - # TODO: not sure why this one is not inherited - @cached_property - def logged_named(self) -> str: - return 
super().logged_named + @classmethod + def _from_config(cls, config: MockWorkerConfig, **extras) -> MockWorker: + worker = cls(db_path=config.db_path, **extras) + return worker + + def _to_config(self) -> MockWorkerConfig: + return MockWorkerConfig(db_path=self._db_path) async def _save_result(self, result: TaskResult, project: str): task_key = self._task_key(task_id=result.task_id, project=project) @@ -272,10 +285,6 @@ async def _save_error(self, error: TaskError, task: Task, project: str): db[self._error_collection][task_key] = errors self._write(db) - @property - def _logger(self) -> logging.Logger: - return self._logger_ - def _get_db_errors(self, task_id: str, project: str) -> List[TaskError]: key = self._task_key(task_id=task_id, project=project) with self.db_lock: @@ -358,7 +367,7 @@ async def _consume(self) -> Tuple[Task, str]: k, t = min(queued, key=lambda x: x[1].created_at) project = eval(k)[1] # pylint: disable=eval-used return t, project - await asyncio.sleep(self.config.neo4j_app_task_queue_poll_interval_s) + await asyncio.sleep(self.config.task_queue_poll_interval_s) class Recoverable(ValueError): @@ -366,8 +375,10 @@ class Recoverable(ValueError): @pytest.fixture(scope="function") -def test_failing_async_app(test_config: AppConfig) -> ICIJApp: - app = ICIJApp(name="test-app", config=test_config) +def test_failing_async_app( + test_config: AppConfig, # pylint: disable=unused-argument +) -> AsyncApp: + app = AsyncApp(name="test-app", dependencies=FASTAPI_LIFESPAN_DEPS) already_failed = False @app.task("recovering_task", recover_from=(Recoverable,)) diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py index 654abb51..37b7a326 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py @@ -5,14 +5,14 @@ import pytest -from neo4j_app.core import AppConfig -from neo4j_app.icij_worker import ICIJApp +from neo4j_app.app.dependencies import FASTAPI_LIFESPAN_DEPS +from neo4j_app.icij_worker import AsyncApp from neo4j_app.tests.icij_worker.conftest import MockWorker @pytest.fixture(scope="module") -def test_app(test_config: AppConfig) -> ICIJApp: - app = ICIJApp(name="test-app", config=test_config) +def test_app() -> AsyncApp: + app = AsyncApp(name="test-app", dependencies=FASTAPI_LIFESPAN_DEPS) @app.task async def hello_word(greeted: str): @@ -22,9 +22,11 @@ async def hello_word(greeted: str): @pytest.fixture(scope="function") -def mock_worker(test_async_app: ICIJApp, tmpdir: Path) -> MockWorker: +def mock_worker(test_async_app: AsyncApp, tmpdir: Path) -> MockWorker: db_path = Path(tmpdir) / "db.json" MockWorker.fresh_db(db_path) lock = threading.Lock() - worker = MockWorker(test_async_app, "test-worker", db_path, lock) + worker = MockWorker( + test_async_app, "test-worker", db_path, lock, teardown_dependencies=False + ) return worker diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_neo4j.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_neo4j.py index b23c1676..99dbbeac 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_neo4j.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_neo4j.py @@ -9,7 +9,7 @@ from neo4j_app.core.neo4j.projects import project_db_session from neo4j_app.core.utils.pydantic import safe_copy from neo4j_app.icij_worker import ( - ICIJApp, + AsyncApp, Task, TaskError, TaskEvent, @@ -17,7 +17,7 @@ TaskStatus, ) from neo4j_app.icij_worker.task_manager.neo4j import Neo4JTaskManager 
-from neo4j_app.icij_worker.worker.neo4j import Neo4jAsyncWorker +from neo4j_app.icij_worker.worker.neo4j import Neo4jWorker from neo4j_app.tests.conftest import ( TEST_PROJECT, fail_if_exception, @@ -25,8 +25,8 @@ @pytest.fixture(scope="function") -def worker(test_app: ICIJApp, neo4j_app_driver: neo4j.AsyncDriver) -> Neo4jAsyncWorker: - worker = Neo4jAsyncWorker(test_app, "test-worker", neo4j_app_driver) +def worker(test_app: AsyncApp, neo4j_app_driver: neo4j.AsyncDriver) -> Neo4jWorker: + worker = Neo4jWorker(test_app, "test-worker", neo4j_app_driver) return worker @@ -39,9 +39,7 @@ async def _count_locks(driver: neo4j.AsyncDriver, project: str) -> int: return counts["nLocks"] -async def test_worker_consume_task( - populate_tasks: List[Task], worker: Neo4jAsyncWorker -): +async def test_worker_consume_task(populate_tasks: List[Task], worker: Neo4jWorker): # pylint: disable=unused-argument # Given project = TEST_PROJECT @@ -61,7 +59,7 @@ async def test_worker_consume_task( async def test_worker_negatively_acknowledge( - populate_tasks: List[Task], worker: Neo4jAsyncWorker + populate_tasks: List[Task], worker: Neo4jWorker ): # pylint: disable=unused-argument # When @@ -79,7 +77,7 @@ async def test_worker_negatively_acknowledge( async def test_worker_negatively_acknowledge_and_requeue( - populate_tasks: List[Task], worker: Neo4jAsyncWorker + populate_tasks: List[Task], worker: Neo4jWorker ): # pylint: disable=unused-argument # Given @@ -115,7 +113,7 @@ async def test_worker_negatively_acknowledge_and_requeue( assert n_locks == 0 -async def test_worker_save_result(populate_tasks: List[Task], worker: Neo4jAsyncWorker): +async def test_worker_save_result(populate_tasks: List[Task], worker: Neo4jWorker): # Given task_manager = Neo4JTaskManager(worker.driver, max_queue_size=10) project = TEST_PROJECT @@ -135,7 +133,7 @@ async def test_worker_save_result(populate_tasks: List[Task], worker: Neo4jAsync async def test_worker_should_raise_when_saving_existing_result( - populate_tasks: List[Task], worker: Neo4jAsyncWorker + populate_tasks: List[Task], worker: Neo4jWorker ): # Given project = TEST_PROJECT @@ -152,33 +150,8 @@ async def test_worker_should_raise_when_saving_existing_result( await worker.save_result(result=task_result, project=project) -async def test_worker_save_error(populate_tasks: List[Task], worker: Neo4jAsyncWorker): - # pylint: disable=unused-argument - # Given - task_manager = Neo4JTaskManager(worker.driver, max_queue_size=10) - project = TEST_PROJECT - error = TaskError( - id="error-id", - title="someErrorTitle", - detail="with_details", - occurred_at=datetime.now(), - ) - - # When - task, _ = await worker.consume() - await worker.save_error(error=error, task=task, project=project) - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) - - # Then - # We don't expect the task status to be updated by saving the error, the negative - # acknowledgment will do it - assert saved_task == task - assert saved_errors == [error] - - async def test_worker_acknowledgment_cm( - populate_tasks: List[Task], worker: Neo4jAsyncWorker + populate_tasks: List[Task], worker: Neo4jWorker ): # Given created = populate_tasks[0] @@ -204,3 +177,28 @@ async def test_worker_acknowledgment_cm( count_locks_query = "MATCH (lock:_TaskLock) RETURN count(*) as nLocks" recs, _, _ = await worker.driver.execute_query(count_locks_query) assert recs[0]["nLocks"] == 0 + + +async def 
test_worker_save_error(populate_tasks: List[Task], worker: Neo4jWorker): + # pylint: disable=unused-argument + # Given + task_manager = Neo4JTaskManager(worker.driver, max_queue_size=10) + project = TEST_PROJECT + error = TaskError( + id="error-id", + title="someErrorTitle", + detail="with_details", + occurred_at=datetime.now(), + ) + + # When + task, _ = await worker.consume() + await worker.save_error(error=error, task=task, project=project) + saved_task = await task_manager.get_task(task_id=task.id, project=project) + saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) + + # Then + # We don't expect the task status to be updated by saving the error, the negative + # acknowledgment will do it + assert saved_task == task + assert saved_errors == [error] diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_process.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_process.py index afd9529b..09bd6970 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_process.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_process.py @@ -4,13 +4,11 @@ import pytest -from neo4j_app.icij_worker.worker import ProcessWorkerMixin +from neo4j_app.icij_worker import Worker @pytest.mark.parametrize("signal", [Signals.SIGINT, Signals.SIGTERM]) -def test_worker_signal_handler( - mock_worker: ProcessWorkerMixin, signal: Signals, caplog -): +def test_worker_signal_handler(mock_worker: Worker, signal: Signals, caplog): # pylint: disable=protected-access # Given caplog.set_level(logging.INFO) diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py index abbec02c..ba64bc9b 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py @@ -12,7 +12,7 @@ import pytest from neo4j_app.icij_worker import ( - ICIJApp, + AsyncApp, Task, TaskError, TaskEvent, @@ -26,7 +26,7 @@ @pytest.fixture(scope="function") -def mock_failing_worker(test_failing_async_app: ICIJApp, tmpdir: Path) -> MockWorker: +def mock_failing_worker(test_failing_async_app: AsyncApp, tmpdir: Path) -> MockWorker: db_path = Path(tmpdir) / "db.json" MockWorker.fresh_db(db_path) lock = threading.Lock() diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/worker_main.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/worker_main.py deleted file mode 100644 index e167fba4..00000000 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/worker_main.py +++ /dev/null @@ -1,51 +0,0 @@ -import asyncio -import multiprocessing -import sys -import tempfile -from contextlib import contextmanager -from json import JSONDecodeError -from pathlib import Path - -from neo4j_app.core import AppConfig -from neo4j_app.icij_worker import Neo4jAsyncWorker -from neo4j_app.tests.icij_worker.conftest import MockWorker - -_FMT = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s" -_DATE_FMT = "%H:%M:%S" - - -@contextmanager -def db_path_cm(test: bool): - if not test: - yield None - else: - with tempfile.NamedTemporaryFile(prefix="db") as db_f: - yield Path(db_f.name) - - -async def main(): - # Setup logger main logger - config_path = Path(sys.argv[1]) - worker_id = sys.argv[2] - try: - config = AppConfig.parse_file(config_path) - except JSONDecodeError: - with config_path.open() as f: - config = AppConfig.from_java_properties(f) - with multiprocessing.Manager() as m: - with db_path_cm(config.test) as db_path: - if db_path is not None: - # TODO: this will 
erase the DB each time, it should be done outside in - # case of multiple workers - MockWorker.fresh_db(db_path) - lock = m.Lock() - worker = MockWorker.from_config( - config, worker_id, db_path=db_path, lock=lock - ) - else: - worker = Neo4jAsyncWorker.from_config(config, worker_id) - await worker.work_forever() - - -if __name__ == "__main__": - asyncio.run(main())
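The hunks above replace the old ProcessWorkerMixin/Neo4jAsyncWorker pair with a registrable WorkerConfig/Worker pattern. Below is a minimal sketch, not part of this patch, of how another backend could plug in; it only assumes the APIs visible in the diff (WorkerConfig, Worker.register, _from_config, set_config), and the AMQP names and registry key are hypothetical.

from pydantic import Field

from neo4j_app.icij_worker import AsyncApp, Worker, WorkerConfig


class AmqpWorkerConfig(WorkerConfig):
    # `type` is the registry key resolving which worker backend to instantiate
    type: str = Field(const=True, default="amqp")  # hypothetical backend key
    broker_url: str = "amqp://127.0.0.1:5672"


@Worker.register("amqp")  # hypothetical backend key
class AmqpWorker(Worker):
    # The task consumption / acknowledgment hooks (_consume, _save_result,
    # _save_error, ...) are omitted for brevity; a real backend must implement
    # them the same way Neo4jWorker and MockWorker do.
    def __init__(self, app: AsyncApp, worker_id: str, broker_url: str, **kwargs):
        super().__init__(app, worker_id, **kwargs)
        self._broker_url = broker_url

    @classmethod
    def _from_config(cls, config: "AmqpWorkerConfig", **extras) -> "AmqpWorker":
        # Mirrors Neo4jWorker._from_config: build the worker from its config, then
        # attach the config so `self.config` keeps working
        worker = cls(broker_url=config.broker_url, **extras)
        worker.set_config(config)
        return worker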
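tasks/dependencies.py wires worker dependencies as (label, enter, exit) tuples consumed through module-level lifespan getters. The short sketch below, not part of this patch, shows how an extra dependency could be appended; only the tuple convention and DependencyInjectionError come from the diff, while aiohttp, _HTTP_SESSION and lifespan_http_session are illustrative assumptions.

from typing import Optional

import aiohttp  # illustrative third-party dependency

from neo4j_app.icij_worker.utils.dependencies import DependencyInjectionError
from neo4j_app.tasks.dependencies import WORKER_LIFESPAN_DEPS

_HTTP_SESSION: Optional[aiohttp.ClientSession] = None


async def http_session_enter(**_):
    # Enter callbacks receive the injected keyword arguments (config, worker_id, ...)
    # and are free to ignore them
    global _HTTP_SESSION
    _HTTP_SESSION = aiohttp.ClientSession()


async def http_session_exit(exc_type, exc_value, trace):
    # Exit callbacks receive the exception info, mirroring es_client_exit
    if _HTTP_SESSION is not None:
        await _HTTP_SESSION.close()


def lifespan_http_session() -> aiohttp.ClientSession:
    # Same convention as lifespan_es_client / lifespan_neo4j_driver
    if _HTTP_SESSION is None:
        raise DependencyInjectionError("http session")
    return _HTTP_SESSION


CUSTOM_WORKER_DEPS = WORKER_LIFESPAN_DEPS + [
    ("HTTP session creation", http_session_enter, http_session_exit),
]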