Skip to content

Commit

Permalink
fix: migrate all projects
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Aug 4, 2023
1 parent ce61544 commit 25dc5ed
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 82 deletions.
28 changes: 14 additions & 14 deletions neo4j-app/neo4j_app/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from neo4j_app.app.main import main_router
from neo4j_app.app.named_entities import named_entities_router
from neo4j_app.core import AppConfig
from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schema
from neo4j_app.core.neo4j.migrations import delete_all_migrations_tx
from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schemas
from neo4j_app.core.neo4j.migrations import delete_all_migrations
from neo4j_app.core.utils.logging import DifferedLoggingMessage

_REQUEST_VALIDATION_ERROR = "Request Validation Error"
Expand Down Expand Up @@ -100,7 +100,7 @@ def create_app(config: AppConfig) -> FastAPI:
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(Exception, internal_exception_handler)
app.add_event_handler("startup", app.state.config.setup_loggers)
app.add_event_handler("startup", functools.partial(migrate_app_db, app))
app.add_event_handler("startup", functools.partial(migrate_app_dbs, app))
app.include_router(main_router())
app.include_router(documents_router())
app.include_router(named_entities_router())
Expand All @@ -109,19 +109,19 @@ def create_app(config: AppConfig) -> FastAPI:
return app


async def migrate_app_db(app: FastAPI):
async def migrate_app_dbs(app: FastAPI):
config: AppConfig = app.state.config
async with config.to_neo4j_driver() as driver:
async with driver.session() as sess:
logger.info("Running schema migrations at application startup...")
if config.force_migrations:
await sess.execute_write(delete_all_migrations_tx)
await migrate_db_schema(
sess,
registry=MIGRATIONS,
timeout_s=config.neo4j_app_migration_timeout_s,
throttle_s=config.neo4j_app_migration_throttle_s,
)
logger.info("Running schema migrations at application startup...")
if config.force_migrations:
# TODO: improve this as it could lead to race conditions...
await delete_all_migrations(driver)
await migrate_db_schemas(
driver,
registry=MIGRATIONS,
timeout_s=config.neo4j_app_migration_timeout_s,
throttle_s=config.neo4j_app_migration_throttle_s,
)


def _display_errors(errors: List[Dict]) -> str:
Expand Down
2 changes: 1 addition & 1 deletion neo4j-app/neo4j_app/core/neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .imports import Neo4Import, Neo4jImportWorker
from .migrations import Migration
from .migrations.migrate import MigrationError, migrate_db_schema
from .migrations.migrate import MigrationError, migrate_db_schemas
from .migrations.migrations import (
create_document_and_ne_id_unique_constraint_tx,
create_migration_unique_constraint_tx,
Expand Down
2 changes: 1 addition & 1 deletion neo4j-app/neo4j_app/core/neo4j/migrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .migrate import Migration, delete_all_migrations_tx
from .migrate import Migration, delete_all_migrations
119 changes: 79 additions & 40 deletions neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

MigrationFn = Callable[[neo4j.AsyncTransaction], Coroutine]

_NEO4J_SYSTEM_DB = "system"

_MIGRATION_TIMEOUT_MSG = """Migration timeout expired !
Please check that a migration is indeed in progress. If the application is in a \
deadlock restart it forcing the migration index cleanup."""
Expand Down Expand Up @@ -156,57 +158,94 @@ async def migrations_tx(tx: neo4j.AsyncTransaction) -> List[Neo4jMigration]:
return migrations


async def delete_all_migrations_tx(tx: neo4j.AsyncTransaction):
async def delete_all_migrations(driver: neo4j.AsyncDriver):
query = f"""MATCH (m:{MIGRATION_NODE})
DETACH DELETE m
"""
await tx.run(query)
await driver.execute_query(query)


async def retrieve_project_dbs(neo4j_driver: neo4j.AsyncDriver) -> List[str]:
    """Return the names of the online, non-system databases of the server.

    Runs ``SHOW DATABASES`` through the driver and keeps every database whose
    ``currentStatus`` is ``"online"`` and whose name differs from
    ``_NEO4J_SYSTEM_DB``. The remaining databases are treated as project DBs.

    :param neo4j_driver: async driver used to execute the query
    :return: the database names, in the order the records were returned
    """
    show_dbs_query = f"""SHOW DATABASES YIELD name, currentStatus
WHERE currentStatus = "online" and name <> "{_NEO4J_SYSTEM_DB}"
RETURN name
"""
    # execute_query yields a 3-item result; only the records are needed here
    rows, _, _ = await neo4j_driver.execute_query(show_dbs_query)
    return [row["name"] for row in rows]


async def migrate_db_schema(
neo4j_session: neo4j.AsyncSession,
async def migrate_db_schemas(
    neo4j_driver: neo4j.AsyncDriver,
    registry: MigrationRegistry,
    *,
    timeout_s: float,
    throttle_s: float,
):
    """Apply the migration registry to every project database.

    The project databases are discovered with ``retrieve_project_dbs`` and one
    ``migrate_project_db_schema`` coroutine is created per database. They are
    all awaited together through ``asyncio.gather``, so the first exception
    raised by any of them propagates to the caller.

    :param neo4j_driver: async driver shared by all per-database migrations
    :param registry: migrations to apply to each database
    :param timeout_s: forwarded as-is to each per-database migration
    :param throttle_s: forwarded as-is to each per-database migration
    """
    db_names = await retrieve_project_dbs(neo4j_driver)
    per_db_migrations = (
        migrate_project_db_schema(
            neo4j_driver,
            registry,
            db=db_name,
            timeout_s=timeout_s,
            throttle_s=throttle_s,
        )
        for db_name in db_names
    )
    # Unpacking creates every coroutine before gather starts awaiting them
    await asyncio.gather(*per_db_migrations)


# Retrieve all project DBs


async def migrate_project_db_schema(
neo4j_driver: neo4j.AsyncDriver,
registry: MigrationRegistry,
db: str,
*,
timeout_s: float,
throttle_s: float,
):
logger.info("Migrating project DB %s", db)
start = time.monotonic()
if not registry:
return
todo = sorted(registry, key=lambda m: m.version)
while "Waiting for DB to be migrated or for a timeout":
elapsed = time.monotonic() - start
if elapsed > timeout_s:
logger.error(_MIGRATION_TIMEOUT_MSG)
raise MigrationError(_MIGRATION_TIMEOUT_MSG)
migrations = await neo4j_session.execute_read(migrations_tx)
in_progress = [m for m in migrations if m.status is MigrationStatus.IN_PROGRESS]
if len(in_progress) > 1:
raise MigrationError(f"Found several migration in progress: {in_progress}")
if in_progress:
logger.info(
"Found that %s is in progress, waiting for %s seconds...",
in_progress[0].label,
throttle_s,
)
await asyncio.sleep(throttle_s)
continue
done = [m for m in migrations if m.status is MigrationStatus.DONE]
if done:
current_version = max((m.version for m in done))
todo = [m for m in todo if m.version > current_version]
if not todo:
break
try:
await _migration_wrapper(neo4j_session, todo[0])
todo = todo[1:]
continue
except ConstraintError:
logger.info(
"Migration %s has just started somewhere else, "
" waiting for %s seconds...",
todo[0].label,
throttle_s,
)
await asyncio.sleep(throttle_s)
continue
async with neo4j_driver.session(database=db) as sess:
while "Waiting for DB to be migrated or for a timeout":
elapsed = time.monotonic() - start
if elapsed > timeout_s:
logger.error(_MIGRATION_TIMEOUT_MSG)
raise MigrationError(_MIGRATION_TIMEOUT_MSG)
migrations = await sess.execute_read(migrations_tx)
in_progress = [
m for m in migrations if m.status is MigrationStatus.IN_PROGRESS
]
if len(in_progress) > 1:
raise MigrationError(
f"Found several migration in progress: {in_progress}"
)
if in_progress:
logger.info(
"Found that %s is in progress, waiting for %s seconds...",
in_progress[0].label,
throttle_s,
)
await asyncio.sleep(throttle_s)
continue
done = [m for m in migrations if m.status is MigrationStatus.DONE]
if done:
current_version = max((m.version for m in done))
todo = [m for m in todo if m.version > current_version]
if not todo:
break
try:
await _migration_wrapper(sess, todo[0])
todo = todo[1:]
continue
except ConstraintError:
logger.info(
"Migration %s has just started somewhere else, "
" waiting for %s seconds...",
todo[0].label,
throttle_s,
)
await asyncio.sleep(throttle_s)
continue
56 changes: 30 additions & 26 deletions neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest_asyncio

import neo4j_app
from neo4j_app.core.neo4j import FIRST_MIGRATION, Migration, migrate_db_schema
from neo4j_app.core.neo4j import FIRST_MIGRATION, Migration, migrate_db_schemas
from neo4j_app.core.neo4j.migrations import migrate
from neo4j_app.core.neo4j.migrations.migrate import (
MigrationError,
Expand All @@ -21,12 +21,12 @@

@pytest_asyncio.fixture(scope="function")
async def _migration_index_and_constraint(
neo4j_test_session: neo4j.AsyncSession,
) -> neo4j.AsyncSession:
await migrate_db_schema(
neo4j_test_session, _BASE_REGISTRY, timeout_s=30, throttle_s=0.1
neo4j_test_driver: neo4j.AsyncDriver,
) -> neo4j.AsyncDriver:
await migrate_db_schemas(
neo4j_test_driver, _BASE_REGISTRY, timeout_s=30, throttle_s=0.1
)
return neo4j_test_session
return neo4j_test_driver


async def _create_indexes_tx(tx: neo4j.AsyncTransaction):
Expand Down Expand Up @@ -70,33 +70,33 @@ async def _drop_constraint_tx(tx: neo4j.AsyncTransaction):
],
)
async def test_migrate_db_schema(
_migration_index_and_constraint: neo4j.AsyncSession, # pylint: disable=invalid-name
_migration_index_and_constraint: neo4j.AsyncDriver, # pylint: disable=invalid-name
registry: List[Migration],
expected_indexes: Set[str],
not_expected_indexes: Set[str],
):
# Given
neo4j_session = _migration_index_and_constraint
neo4j_driver = _migration_index_and_constraint

# When
await migrate_db_schema(neo4j_session, registry, timeout_s=10, throttle_s=0.1)
await migrate_db_schemas(neo4j_driver, registry, timeout_s=10, throttle_s=0.1)

# Then
index_res = await neo4j_session.run("SHOW INDEXES")
index_res, _, _ = await neo4j_driver.execute_query("SHOW INDEXES")
existing_indexes = set()
async for rec in index_res:
for rec in index_res:
existing_indexes.add(rec["name"])
missing_indexes = expected_indexes - existing_indexes
assert not missing_indexes
assert not not_expected_indexes.intersection(existing_indexes)

if registry:
db_migrations_res = await neo4j_session.run(
db_migrations_recs, _, _ = await neo4j_driver.execute_query(
"MATCH (m:Migration) RETURN m as migration"
)
db_migrations = [
Neo4jMigration.from_neo4j(rec, key="migration")
async for rec in db_migrations_res
for rec in db_migrations_recs
]
assert len(db_migrations) == len(registry) + 1
assert all(m.status is MigrationStatus.DONE for m in db_migrations)
Expand All @@ -107,26 +107,27 @@ async def test_migrate_db_schema(

@pytest.mark.asyncio
async def test_migrate_db_schema_should_raise_after_timeout(
neo4j_test_session_session: neo4j.AsyncSession,
neo4j_test_driver_session: neo4j.AsyncDriver,
):
# Given
neo4j_session = neo4j_test_session_session
neo4j_driver = neo4j_test_driver_session
registry = [_MIGRATION_0]

# When
expected_msg = "Migration timeout expired"
with pytest.raises(MigrationError, match=expected_msg):
await migrate_db_schema(neo4j_session, registry, timeout_s=0, throttle_s=0.1)
await migrate_db_schemas(neo4j_driver, registry, timeout_s=0, throttle_s=0.1)


@pytest.mark.asyncio
async def test_migrate_db_schema_should_wait_when_other_migration_in_progress(
caplog,
monkeypatch,
_migration_index_and_constraint: neo4j.AsyncSession, # pylint: disable=invalid-name
_migration_index_and_constraint: neo4j.AsyncDriver,
# pylint: disable=invalid-name
):
# Given
neo4j_session_0 = _migration_index_and_constraint
neo4j_driver_0 = _migration_index_and_constraint
caplog.set_level(logging.INFO, logger=neo4j_app.__name__)

async def mocked_get_migrations(
Expand All @@ -148,8 +149,8 @@ async def mocked_get_migrations(
with pytest.raises(MigrationError, match=expected_msg):
timeout_s = 0.5
wait_s = 0.1
await migrate_db_schema(
neo4j_session_0,
await migrate_db_schemas(
neo4j_driver_0,
[_MIGRATION_0, _MIGRATION_1],
timeout_s=timeout_s,
throttle_s=wait_s,
Expand All @@ -164,10 +165,12 @@ async def mocked_get_migrations(

@pytest.mark.asyncio
async def test_migrate_db_schema_should_wait_when_other_migration_just_started(
monkeypatch, caplog, _migration_index_and_constraint # pylint: disable=invalid-name
monkeypatch,
caplog,
_migration_index_and_constraint: neo4j.AsyncDriver, # pylint: disable=invalid-name
):
# Given
neo4j_session = _migration_index_and_constraint
neo4j_driver = _migration_index_and_constraint
caplog.set_level(logging.INFO, logger=neo4j_app.__name__)

async def mocked_get_migrations(
Expand All @@ -186,7 +189,7 @@ async def mocked_get_migrations(
started: $started
})
"""
await neo4j_session.run(
await neo4j_driver.execute_query(
query, version=str(_MIGRATION_0.version), started=datetime.now()
)
try:
Expand All @@ -195,8 +198,8 @@ async def mocked_get_migrations(
with pytest.raises(MigrationError, match=expected_msg):
timeout_s = 0.5
wait_s = 0.1
await migrate_db_schema(
neo4j_session,
await migrate_db_schemas(
neo4j_driver,
[_MIGRATION_0],
timeout_s=timeout_s,
throttle_s=wait_s,
Expand All @@ -209,4 +212,5 @@ async def mocked_get_migrations(
)
finally:
# Don't forget to clean up, otherwise the DB will be locked
await wipe_db(neo4j_session)
async with neo4j_driver.session(database="neo4j") as sess:
await wipe_db(sess)

0 comments on commit 25dc5ed

Please sign in to comment.