diff --git a/neo4j-app/neo4j_app/app/utils.py b/neo4j-app/neo4j_app/app/utils.py index 9e2e791b..a01e2764 100644 --- a/neo4j-app/neo4j_app/app/utils.py +++ b/neo4j-app/neo4j_app/app/utils.py @@ -18,8 +18,8 @@ from neo4j_app.app.main import main_router from neo4j_app.app.named_entities import named_entities_router from neo4j_app.core import AppConfig -from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schema -from neo4j_app.core.neo4j.migrations import delete_all_migrations_tx +from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schemas +from neo4j_app.core.neo4j.migrations import delete_all_migrations from neo4j_app.core.utils.logging import DifferedLoggingMessage _REQUEST_VALIDATION_ERROR = "Request Validation Error" @@ -100,7 +100,7 @@ def create_app(config: AppConfig) -> FastAPI: app.add_exception_handler(StarletteHTTPException, http_exception_handler) app.add_exception_handler(Exception, internal_exception_handler) app.add_event_handler("startup", app.state.config.setup_loggers) - app.add_event_handler("startup", functools.partial(migrate_app_db, app)) + app.add_event_handler("startup", functools.partial(migrate_app_dbs, app)) app.include_router(main_router()) app.include_router(documents_router()) app.include_router(named_entities_router()) @@ -109,19 +109,19 @@ def create_app(config: AppConfig) -> FastAPI: return app -async def migrate_app_db(app: FastAPI): +async def migrate_app_dbs(app: FastAPI): config: AppConfig = app.state.config async with config.to_neo4j_driver() as driver: - async with driver.session() as sess: - logger.info("Running schema migrations at application startup...") - if config.force_migrations: - await sess.execute_write(delete_all_migrations_tx) - await migrate_db_schema( - sess, - registry=MIGRATIONS, - timeout_s=config.neo4j_app_migration_timeout_s, - throttle_s=config.neo4j_app_migration_throttle_s, - ) + logger.info("Running schema migrations at application startup...") + if config.force_migrations: + # TODO: improve this as is could lead to race conditions... + await delete_all_migrations(driver) + await migrate_db_schemas( + driver, + registry=MIGRATIONS, + timeout_s=config.neo4j_app_migration_timeout_s, + throttle_s=config.neo4j_app_migration_throttle_s, + ) def _display_errors(errors: List[Dict]) -> str: diff --git a/neo4j-app/neo4j_app/core/neo4j/__init__.py b/neo4j-app/neo4j_app/core/neo4j/__init__.py index 079f1d0a..3e194017 100644 --- a/neo4j-app/neo4j_app/core/neo4j/__init__.py +++ b/neo4j-app/neo4j_app/core/neo4j/__init__.py @@ -7,7 +7,7 @@ from .imports import Neo4Import, Neo4jImportWorker from .migrations import Migration -from .migrations.migrate import MigrationError, migrate_db_schema +from .migrations.migrate import MigrationError, migrate_db_schemas from .migrations.migrations import ( create_document_and_ne_id_unique_constraint_tx, create_migration_unique_constraint_tx, diff --git a/neo4j-app/neo4j_app/core/neo4j/migrations/__init__.py b/neo4j-app/neo4j_app/core/neo4j/migrations/__init__.py index 21f690ac..0853c284 100644 --- a/neo4j-app/neo4j_app/core/neo4j/migrations/__init__.py +++ b/neo4j-app/neo4j_app/core/neo4j/migrations/__init__.py @@ -1 +1 @@ -from .migrate import Migration, delete_all_migrations_tx +from .migrate import Migration, delete_all_migrations diff --git a/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py b/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py index 4c77c1b2..745f1dab 100644 --- a/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py +++ b/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py @@ -26,6 +26,8 @@ MigrationFn = Callable[[neo4j.AsyncTransaction], Coroutine] +_NEO4J_SYSTEM_DB = "system" + _MIGRATION_TIMEOUT_MSG = """Migration timeout expired ! Please check that a migration is indeed in progress. If the application is in a \ deadlock restart it forcing the migration index cleanup.""" @@ -156,57 +158,94 @@ async def migrations_tx(tx: neo4j.AsyncTransaction) -> List[Neo4jMigration]: return migrations -async def delete_all_migrations_tx(tx: neo4j.AsyncTransaction): +async def delete_all_migrations(driver: neo4j.AsyncDriver): query = f"""MATCH (m:{MIGRATION_NODE}) DETACH DELETE m """ - await tx.run(query) + await driver.execute_query(query) + + +async def retrieve_project_dbs(neo4j_driver: neo4j.AsyncDriver) -> List[str]: + query = f"""SHOW DATABASES YIELD name, currentStatus +WHERE currentStatus = "online" and name <> "{_NEO4J_SYSTEM_DB}" +RETURN name +""" + records, _, _ = await neo4j_driver.execute_query(query) + project_dbs = [rec["name"] for rec in records] + return project_dbs -async def migrate_db_schema( - neo4j_session: neo4j.AsyncSession, +async def migrate_db_schemas( + neo4j_driver: neo4j.AsyncDriver, registry: MigrationRegistry, *, timeout_s: float, throttle_s: float, ): + project_dbs = await retrieve_project_dbs(neo4j_driver) + tasks = [ + migrate_project_db_schema( + neo4j_driver, registry, db=db, timeout_s=timeout_s, throttle_s=throttle_s + ) + for db in project_dbs + ] + await asyncio.gather(*tasks) + + +# Retrieve all project DBs + + +async def migrate_project_db_schema( + neo4j_driver: neo4j.AsyncDriver, + registry: MigrationRegistry, + db: str, + *, + timeout_s: float, + throttle_s: float, +): + logger.info("Migrating project DB %s", db) start = time.monotonic() if not registry: return todo = sorted(registry, key=lambda m: m.version) - while "Waiting for DB to be migrated or for a timeout": - elapsed = time.monotonic() - start - if elapsed > timeout_s: - logger.error(_MIGRATION_TIMEOUT_MSG) - raise MigrationError(_MIGRATION_TIMEOUT_MSG) - migrations = await neo4j_session.execute_read(migrations_tx) - in_progress = [m for m in migrations if m.status is MigrationStatus.IN_PROGRESS] - if len(in_progress) > 1: - raise MigrationError(f"Found several migration in progress: {in_progress}") - if in_progress: - logger.info( - "Found that %s is in progress, waiting for %s seconds...", - in_progress[0].label, - throttle_s, - ) - await asyncio.sleep(throttle_s) - continue - done = [m for m in migrations if m.status is MigrationStatus.DONE] - if done: - current_version = max((m.version for m in done)) - todo = [m for m in todo if m.version > current_version] - if not todo: - break - try: - await _migration_wrapper(neo4j_session, todo[0]) - todo = todo[1:] - continue - except ConstraintError: - logger.info( - "Migration %s has just started somewhere else, " - " waiting for %s seconds...", - todo[0].label, - throttle_s, - ) - await asyncio.sleep(throttle_s) - continue + async with neo4j_driver.session(database=db) as sess: + while "Waiting for DB to be migrated or for a timeout": + elapsed = time.monotonic() - start + if elapsed > timeout_s: + logger.error(_MIGRATION_TIMEOUT_MSG) + raise MigrationError(_MIGRATION_TIMEOUT_MSG) + migrations = await sess.execute_read(migrations_tx) + in_progress = [ + m for m in migrations if m.status is MigrationStatus.IN_PROGRESS + ] + if len(in_progress) > 1: + raise MigrationError( + f"Found several migration in progress: {in_progress}" + ) + if in_progress: + logger.info( + "Found that %s is in progress, waiting for %s seconds...", + in_progress[0].label, + throttle_s, + ) + await asyncio.sleep(throttle_s) + continue + done = [m for m in migrations if m.status is MigrationStatus.DONE] + if done: + current_version = max((m.version for m in done)) + todo = [m for m in todo if m.version > current_version] + if not todo: + break + try: + await _migration_wrapper(sess, todo[0]) + todo = todo[1:] + continue + except ConstraintError: + logger.info( + "Migration %s has just started somewhere else, " + " waiting for %s seconds...", + todo[0].label, + throttle_s, + ) + await asyncio.sleep(throttle_s) + continue diff --git a/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py b/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py index 195ab0d1..9319db2d 100644 --- a/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py +++ b/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py @@ -7,7 +7,7 @@ import pytest_asyncio import neo4j_app -from neo4j_app.core.neo4j import FIRST_MIGRATION, Migration, migrate_db_schema +from neo4j_app.core.neo4j import FIRST_MIGRATION, Migration, migrate_db_schemas from neo4j_app.core.neo4j.migrations import migrate from neo4j_app.core.neo4j.migrations.migrate import ( MigrationError, @@ -21,12 +21,12 @@ @pytest_asyncio.fixture(scope="function") async def _migration_index_and_constraint( - neo4j_test_session: neo4j.AsyncSession, -) -> neo4j.AsyncSession: - await migrate_db_schema( - neo4j_test_session, _BASE_REGISTRY, timeout_s=30, throttle_s=0.1 + neo4j_test_driver: neo4j.AsyncDriver, +) -> neo4j.AsyncDriver: + await migrate_db_schemas( + neo4j_test_driver, _BASE_REGISTRY, timeout_s=30, throttle_s=0.1 ) - return neo4j_test_session + return neo4j_test_driver async def _create_indexes_tx(tx: neo4j.AsyncTransaction): @@ -70,33 +70,33 @@ async def _drop_constraint_tx(tx: neo4j.AsyncTransaction): ], ) async def test_migrate_db_schema( - _migration_index_and_constraint: neo4j.AsyncSession, # pylint: disable=invalid-name + _migration_index_and_constraint: neo4j.AsyncDriver, # pylint: disable=invalid-name registry: List[Migration], expected_indexes: Set[str], not_expected_indexes: Set[str], ): # Given - neo4j_session = _migration_index_and_constraint + neo4j_driver = _migration_index_and_constraint # When - await migrate_db_schema(neo4j_session, registry, timeout_s=10, throttle_s=0.1) + await migrate_db_schemas(neo4j_driver, registry, timeout_s=10, throttle_s=0.1) # Then - index_res = await neo4j_session.run("SHOW INDEXES") + index_res, _, _ = await neo4j_driver.execute_query("SHOW INDEXES") existing_indexes = set() - async for rec in index_res: + for rec in index_res: existing_indexes.add(rec["name"]) missing_indexes = expected_indexes - existing_indexes assert not missing_indexes assert not not_expected_indexes.intersection(existing_indexes) if registry: - db_migrations_res = await neo4j_session.run( + db_migrations_recs, _, _ = await neo4j_driver.execute_query( "MATCH (m:Migration) RETURN m as migration" ) db_migrations = [ Neo4jMigration.from_neo4j(rec, key="migration") - async for rec in db_migrations_res + for rec in db_migrations_recs ] assert len(db_migrations) == len(registry) + 1 assert all(m.status is MigrationStatus.DONE for m in db_migrations) @@ -107,26 +107,27 @@ async def test_migrate_db_schema( @pytest.mark.asyncio async def test_migrate_db_schema_should_raise_after_timeout( - neo4j_test_session_session: neo4j.AsyncSession, + neo4j_test_driver_session: neo4j.AsyncDriver, ): # Given - neo4j_session = neo4j_test_session_session + neo4j_driver = neo4j_test_driver_session registry = [_MIGRATION_0] # When expected_msg = "Migration timeout expired" with pytest.raises(MigrationError, match=expected_msg): - await migrate_db_schema(neo4j_session, registry, timeout_s=0, throttle_s=0.1) + await migrate_db_schemas(neo4j_driver, registry, timeout_s=0, throttle_s=0.1) @pytest.mark.asyncio async def test_migrate_db_schema_should_wait_when_other_migration_in_progress( caplog, monkeypatch, - _migration_index_and_constraint: neo4j.AsyncSession, # pylint: disable=invalid-name + _migration_index_and_constraint: neo4j.AsyncDriver, + # pylint: disable=invalid-name ): # Given - neo4j_session_0 = _migration_index_and_constraint + neo4j_driver_0 = _migration_index_and_constraint caplog.set_level(logging.INFO, logger=neo4j_app.__name__) async def mocked_get_migrations( @@ -148,8 +149,8 @@ async def mocked_get_migrations( with pytest.raises(MigrationError, match=expected_msg): timeout_s = 0.5 wait_s = 0.1 - await migrate_db_schema( - neo4j_session_0, + await migrate_db_schemas( + neo4j_driver_0, [_MIGRATION_0, _MIGRATION_1], timeout_s=timeout_s, throttle_s=wait_s, @@ -164,10 +165,12 @@ async def mocked_get_migrations( @pytest.mark.asyncio async def test_migrate_db_schema_should_wait_when_other_migration_just_started( - monkeypatch, caplog, _migration_index_and_constraint # pylint: disable=invalid-name + monkeypatch, + caplog, + _migration_index_and_constraint: neo4j.AsyncDriver, # pylint: disable=invalid-name ): # Given - neo4j_session = _migration_index_and_constraint + neo4j_driver = _migration_index_and_constraint caplog.set_level(logging.INFO, logger=neo4j_app.__name__) async def mocked_get_migrations( @@ -186,7 +189,7 @@ async def mocked_get_migrations( started: $started }) """ - await neo4j_session.run( + await neo4j_driver.execute_query( query, version=str(_MIGRATION_0.version), started=datetime.now() ) try: @@ -195,8 +198,8 @@ async def mocked_get_migrations( with pytest.raises(MigrationError, match=expected_msg): timeout_s = 0.5 wait_s = 0.1 - await migrate_db_schema( - neo4j_session, + await migrate_db_schemas( + neo4j_driver, [_MIGRATION_0], timeout_s=timeout_s, throttle_s=wait_s, @@ -209,4 +212,5 @@ async def mocked_get_migrations( ) finally: # Don't forget to cleanup other the DB will be locked - await wipe_db(neo4j_session) + async with neo4j_driver.session(database="neo4j") as sess: + await wipe_db(sess)