Skip to content

Commit

Permalink
chore: refactor test dependency injection
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Jan 31, 2024
1 parent 7636a26 commit 6bbaf38
Show file tree
Hide file tree
Showing 17 changed files with 395 additions and 372 deletions.
12 changes: 5 additions & 7 deletions neo4j-app/neo4j_app/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
"log_level",
]

_DEFAULT_ASYNC_DEPS = "neo4j_app.tasks.ASYNC_APP_LIFESPAN_DEPS"
_DEFAULT_DEPS = "neo4j_app.app.dependencies.HTTP_SERVICE_LIFESPAN_DEPS"


class ServiceConfig(AppConfig):
neo4j_app_async_dependencies: Optional[str] = "neo4j_app.tasks.WORKER_LIFESPAN_DEPS"
neo4j_app_async_app: str = "neo4j_app.tasks.app"
neo4j_app_async_app: Optional[str] = "neo4j_app.tasks.app"
neo4j_app_dependencies: Optional[str] = _DEFAULT_DEPS
neo4j_app_gunicorn_workers: int = 1
neo4j_app_host: str = "127.0.0.1"
neo4j_app_n_async_workers: int = 1
Expand All @@ -50,11 +53,6 @@ def to_worker_config(self, **kwargs) -> WorkerConfig:
kwargs = copy(kwargs)
for suffix in _SHARED_WITH_NEO4J_WORKER_CONFIG_PREFIXED:
kwargs[suffix] = getattr(self, f"neo4j_app_{suffix}")

if self.test:
from neo4j_app.tests.icij_worker.conftest import MockWorkerConfig

return MockWorkerConfig(**kwargs)
from neo4j_app.icij_worker.worker.neo4j import Neo4jWorkerConfig

for k in _SHARED_WITH_NEO4J_WORKER_CONFIG:
Expand Down
134 changes: 26 additions & 108 deletions neo4j-app/neo4j_app/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import os
import tempfile
from contextlib import asynccontextmanager
from multiprocessing.managers import SyncManager
from pathlib import Path
from typing import Optional, cast
from typing import Dict, Optional, cast

import neo4j
from fastapi import FastAPI
Expand All @@ -15,12 +14,14 @@
from neo4j_app.icij_worker import (
EventPublisher,
Neo4jEventPublisher,
WorkerConfig,
)
from neo4j_app.icij_worker.backend.backend import WorkerBackend
from neo4j_app.icij_worker.task_manager import TaskManager
from neo4j_app.icij_worker.task_manager.neo4j import Neo4JTaskManager
from neo4j_app.icij_worker.utils import run_deps
from neo4j_app.icij_worker.utils.dependencies import DependencyInjectionError
from neo4j_app.icij_worker.utils.imports import import_variable
from neo4j_app.tasks.dependencies import (
config_enter,
create_project_registry_db_enter,
Expand All @@ -40,10 +41,7 @@
_EVENT_PUBLISHER: Optional[EventPublisher] = None
_MP_CONTEXT = None
_NEO4J_DRIVER: Optional[neo4j.AsyncDriver] = None
_PROCESS_MANAGER: Optional[SyncManager] = None
_TASK_MANAGER: Optional[TaskManager] = None
_TEST_DB_FILE: Optional[Path] = None
_TEST_LOCK: Optional[multiprocessing.Lock] = None
_WORKER_POOL_IS_RUNNING = False


Expand Down Expand Up @@ -89,86 +87,16 @@ def lifespan_mp_context():
return _MP_CONTEXT


def test_db_path_enter(**_):
config = cast(
ServiceConfig,
lifespan_config(),
)
if config.test:
# pylint: disable=consider-using-with
from neo4j_app.tests.icij_worker.conftest import DBMixin

global _TEST_DB_FILE
_TEST_DB_FILE = tempfile.NamedTemporaryFile(prefix="db", suffix=".json")

DBMixin.fresh_db(Path(_TEST_DB_FILE.name))
_TEST_DB_FILE.__enter__() # pylint: disable=unnecessary-dunder-call


def test_db_path_exit(exc_type, exc_value, trace):
if _TEST_DB_FILE is not None:
_TEST_DB_FILE.__exit__(exc_type, exc_value, trace)


def _lifespan_test_db_path() -> Path:
if _TEST_DB_FILE is None:
raise DependencyInjectionError("test db path")
return Path(_TEST_DB_FILE.name)


def test_process_manager_enter(**_):
global _PROCESS_MANAGER
_PROCESS_MANAGER = lifespan_mp_context().Manager()


def test_process_manager_exit(exc_type, exc_value, trace):
_PROCESS_MANAGER.__exit__(exc_type, exc_value, trace)


def lifespan_test_process_manager() -> SyncManager:
if _PROCESS_MANAGER is None:
raise DependencyInjectionError("process manager")
return _PROCESS_MANAGER


def _test_lock_enter(**_):
config = cast(
ServiceConfig,
lifespan_config(),
)
if config.test:
global _TEST_LOCK
_TEST_LOCK = lifespan_test_process_manager().Lock()


def _lifespan_test_lock() -> multiprocessing.Lock:
if _TEST_LOCK is None:
raise DependencyInjectionError("test lock")
return cast(multiprocessing.Lock, _TEST_LOCK)


def lifespan_worker_pool_is_running() -> bool:
return _WORKER_POOL_IS_RUNNING


def task_manager_enter(**_):
global _TASK_MANAGER
config = cast(
ServiceConfig,
lifespan_config(),
config = cast(ServiceConfig, lifespan_config())
_TASK_MANAGER = Neo4JTaskManager(
lifespan_neo4j_driver(), max_queue_size=config.neo4j_app_task_queue_size
)
if config.test:
from neo4j_app.tests.icij_worker.conftest import MockManager

_TASK_MANAGER = MockManager(
_lifespan_test_db_path(),
_lifespan_test_lock(),
max_queue_size=config.neo4j_app_task_queue_size,
)
else:
_TASK_MANAGER = Neo4JTaskManager(
lifespan_neo4j_driver(), max_queue_size=config.neo4j_app_task_queue_size
)


def lifespan_task_manager() -> TaskManager:
Expand All @@ -179,18 +107,7 @@ def lifespan_task_manager() -> TaskManager:

def event_publisher_enter(**_):
global _EVENT_PUBLISHER
config = cast(
ServiceConfig,
lifespan_config(),
)
if config.test:
from neo4j_app.tests.icij_worker.conftest import MockEventPublisher

_EVENT_PUBLISHER = MockEventPublisher(
_lifespan_test_db_path(), _lifespan_test_lock()
)
else:
_EVENT_PUBLISHER = Neo4jEventPublisher(lifespan_neo4j_driver())
_EVENT_PUBLISHER = Neo4jEventPublisher(lifespan_neo4j_driver())


def lifespan_event_publisher() -> EventPublisher:
Expand All @@ -200,35 +117,39 @@ def lifespan_event_publisher() -> EventPublisher:


@asynccontextmanager
async def run_app_deps(app: FastAPI):
async def run_http_service_deps(
app: FastAPI,
async_app: str,
worker_config: WorkerConfig,
worker_extras: Optional[Dict] = None,
):
config = app.state.config
n_workers = config.neo4j_app_n_async_workers
async with run_deps(
dependencies=FASTAPI_LIFESPAN_DEPS, ctx="FastAPI HTTP server", config=config
):
deps = import_variable(config.neo4j_app_dependencies)
async with run_deps(dependencies=deps, ctx="FastAPI HTTP server", config=config):
# Compute the support only once we know the neo4j driver deps has successfully
# completed
app.state.config = await config.with_neo4j_support()
worker_extras = {"teardown_dependencies": config.test}
config_extra = dict()
# Forward the past of the app config to load to the async app
async_app_extras = {"config_path": _lifespan_async_app_config_path()}
if config.test:
config_extra["db_path"] = _lifespan_test_db_path()
worker_extras["lock"] = _lifespan_test_lock()
worker_config = config.to_worker_config(**config_extra)
# config_extra = dict()
# # Forward the part of the app config to load to the async app
# async_app_extras = {"config_path": _lifespan_async_app_config_path()}
# if is_test:
# config_extra["db_path"] = _lifespan_test_db_path()
# TODO 1: set the async app config path inside the deps itself
# TODO 3: set the DB path in deps
with WorkerBackend.MULTIPROCESSING.run_cm(
config.neo4j_app_async_app,
async_app,
n_workers=n_workers,
config=worker_config,
worker_extras=worker_extras,
app_deps_extras=async_app_extras,
):
global _WORKER_POOL_IS_RUNNING
_WORKER_POOL_IS_RUNNING = True
yield
_WORKER_POOL_IS_RUNNING = False


FASTAPI_LIFESPAN_DEPS = [
HTTP_SERVICE_LIFESPAN_DEPS = [
("configuration reading", config_enter, None),
("loggers setup", loggers_enter, None),
(
Expand All @@ -240,9 +161,6 @@ async def run_app_deps(app: FastAPI):
("neo4j project registry creation", create_project_registry_db_enter, None),
("ES client creation", es_client_enter, es_client_exit),
(None, mp_context_enter, None),
(None, test_process_manager_enter, test_process_manager_exit),
(None, test_db_path_enter, test_db_path_exit),
(None, _test_lock_enter, None),
("task manager creation", task_manager_enter, None),
("event publisher creation", event_publisher_enter, None),
("neo4j DB migration", migrate_app_db_enter, None),
Expand Down
27 changes: 21 additions & 6 deletions neo4j-app/neo4j_app/app/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import traceback
from typing import Dict, Iterable, List, Optional
Expand All @@ -13,15 +14,15 @@

from neo4j_app.app import ServiceConfig
from neo4j_app.app.admin import admin_router
from neo4j_app.app.dependencies import run_app_deps
from neo4j_app.app.dependencies import run_http_service_deps
from neo4j_app.app.doc import DOCUMENT_TAG, NE_TAG, OTHER_TAG
from neo4j_app.app.documents import documents_router
from neo4j_app.app.graphs import graphs_router
from neo4j_app.app.main import main_router
from neo4j_app.app.named_entities import named_entities_router
from neo4j_app.app.projects import projects_router
from neo4j_app.app.tasks import tasks_router
from neo4j_app.icij_worker import AsyncApp
from neo4j_app.icij_worker import WorkerConfig

INTERNAL_SERVER_ERROR = "Internal Server Error"
_REQUEST_VALIDATION_ERROR = "Request Validation Error"
Expand Down Expand Up @@ -83,15 +84,29 @@ def _debug():
logger.info("im here")


def create_app(config: ServiceConfig, async_app: Optional[AsyncApp] = None) -> FastAPI:
def create_app(
config: ServiceConfig,
async_app: Optional[str] = None,
worker_config: WorkerConfig = None,
worker_extras: Optional[Dict] = None,
) -> FastAPI:
if bool(async_app) == bool(config.neo4j_app_async_app):
raise ValueError("Please provide exactly one config")
async_app = async_app or config.neo4j_app_async_app
if worker_config is None:
worker_config = config.to_worker_config()
lifespan = functools.partial(
run_http_service_deps,
async_app=async_app,
worker_config=worker_config,
worker_extras=worker_extras,
)
app = FastAPI(
title=config.doc_app_name,
openapi_tags=_make_open_api_tags([DOCUMENT_TAG, NE_TAG, OTHER_TAG]),
lifespan=run_app_deps,
lifespan=lifespan,
)
app.state.config = config
if async_app is not None:
app.state.async_app = async_app
app.add_exception_handler(RequestValidationError, request_validation_error_handler)
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(Exception, internal_exception_handler)
Expand Down
5 changes: 3 additions & 2 deletions neo4j-app/neo4j_app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from neo4j_app.core.utils.pydantic import (
IgnoreExtraModel,
LowerCamelCaseModel,
NoEnumModel,
safe_copy,
)

Expand All @@ -32,7 +33,7 @@ def _es_version() -> str:
return ".".join(str(num) for num in elasticsearch.__version__)


class AppConfig(LowerCamelCaseModel, IgnoreExtraModel):
class AppConfig(LowerCamelCaseModel, IgnoreExtraModel, NoEnumModel):
elasticsearch_address: str = "http://127.0.0.1:9200"
elasticsearch_version: str = Field(default_factory=_es_version, const=True)
es_doc_type_field: str = Field(alias="docTypeField", default="type")
Expand All @@ -44,7 +45,7 @@ class AppConfig(LowerCamelCaseModel, IgnoreExtraModel):
es_keep_alive: str = "1m"
force_migrations: bool = False
neo4j_app_log_level: str = "INFO"
neo4j_app_cancelled_task_refresh_interval_s: int = 2
neo4j_app_cancelled_tasks_refresh_interval_s: int = 2
neo4j_app_log_in_json: bool = False
neo4j_app_max_dumped_documents: Optional[int] = None
neo4j_app_max_records_in_memory: int = int(1e6)
Expand Down
6 changes: 0 additions & 6 deletions neo4j-app/neo4j_app/icij_worker/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def run(
n_workers: int,
config: WorkerConfig,
worker_extras: Optional[Dict] = None,
app_deps_extras: Optional[Dict] = None,
):
# This function is meant to be run as the main function of a Python command,
# in this case we want th main process to handle signals
Expand All @@ -30,7 +29,6 @@ def run(
config,
handle_signals=True,
worker_extras=worker_extras,
app_deps_extras=app_deps_extras,
):
pass

Expand All @@ -44,7 +42,6 @@ def run_cm(
n_workers: int,
config: WorkerConfig,
worker_extras: Optional[Dict] = None,
app_deps_extras: Optional[Dict] = None,
):
# This usage is meant for when a backend is run from another process which
# handles signals by itself
Expand All @@ -54,7 +51,6 @@ def run_cm(
config,
handle_signals=False,
worker_extras=worker_extras,
app_deps_extras=app_deps_extras,
):
yield

Expand All @@ -67,7 +63,6 @@ def _run_cm(
*,
handle_signals: bool = False,
worker_extras: Optional[Dict] = None,
app_deps_extras: Optional[Dict] = None,
):
if self is WorkerBackend.MULTIPROCESSING:
with run_workers_with_multiprocessing(
Expand All @@ -76,7 +71,6 @@ def _run_cm(
config,
handle_signals=handle_signals,
worker_extras=worker_extras,
app_deps_extras=app_deps_extras,
):
yield
else:
Expand Down
Loading

0 comments on commit 6bbaf38

Please sign in to comment.