diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6db1a96711..96388dc021 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: image: postgres env: POSTGRES_PASSWORD: postgres - POSTGRES_DB: fractal + POSTGRES_DB: fractal_test options: >- --health-cmd pg_isready --health-interval 10s @@ -62,7 +62,7 @@ jobs: run: poetry install --with dev --no-interaction --all-extras - name: Test with pytest - run: poetry run coverage run -m pytest + run: poetry run coverage run -m pytest tests - name: Upload coverage data uses: actions/upload-artifact@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 34068b3c86..c27cdec095 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ # 1.2.5 +* Fix bug in task collection when using sqlite (\#664). * Improve error handling in workflow-apply endpoint (\#665). * Fix a bug upon project removal in the presence of project-related jobs (\#666). Note: this removes the `ApplyWorkflow.Project` attribute. diff --git a/fractal_server/app/api/v1/task.py b/fractal_server/app/api/v1/task.py index 364f76a275..f9eb030a00 100644 --- a/fractal_server/app/api/v1/task.py +++ b/fractal_server/app/api/v1/task.py @@ -60,81 +60,80 @@ async def _background_collect_pip( directory. """ - # Note: anext(get_db()) is only available for python>=3.10 - db = await get_db().__anext__() + async for db in get_db(): - state: State = await db.get(State, state_id) + state: State = await db.get(State, state_id) - logger_name = task_pkg.package.replace("/", "_") - logger = set_logger( - logger_name=logger_name, - log_file_path=get_log_path(venv_path), - ) - - logger.debug("Start background task collection") - data = TaskCollectStatus(**state.data) - data.info = None - - try: - # install - logger.debug("Task-collection status: installing") - data.status = "installing" - - state.data = data.sanitised_dict() - await db.merge(state) - await db.commit() - task_list = await create_package_environment_pip( - venv_path=venv_path, - task_pkg=task_pkg, + logger_name = task_pkg.package.replace("/", "_") + logger = set_logger( logger_name=logger_name, + log_file_path=get_log_path(venv_path), ) - # collect - logger.debug("Task-collection status: collecting") - data.status = "collecting" - state.data = data.sanitised_dict() - await db.merge(state) - await db.commit() - tasks = await _insert_tasks(task_list=task_list, db=db) - - # finalise - logger.debug("Task-collection status: finalising") - collection_path = get_collection_path(venv_path) - data.task_list = tasks - with collection_path.open("w") as f: - json.dump(data.sanitised_dict(), f) - - # Update DB - data.status = "OK" - data.log = get_collection_log(venv_path) - state.data = data.sanitised_dict() - db.add(state) - await db.merge(state) - await db.commit() - - # Write last logs to file - logger.debug("Task-collection status: OK") - logger.info("Background task collection completed successfully") - close_logger(logger) - await db.close() + logger.debug("Start background task collection") + data = TaskCollectStatus(**state.data) + data.info = None - except Exception as e: - # Write last logs to file - logger.debug("Task-collection status: fail") - logger.info(f"Background collection failed. Original error: {e}") - close_logger(logger) + try: + # install + logger.debug("Task-collection status: installing") + data.status = "installing" + + state.data = data.sanitised_dict() + await db.merge(state) + await db.commit() + task_list = await create_package_environment_pip( + venv_path=venv_path, + task_pkg=task_pkg, + logger_name=logger_name, + ) - # Update db - data.status = "fail" - data.info = f"Original error: {e}" - data.log = get_collection_log(venv_path) - state.data = data.sanitised_dict() - await db.merge(state) - await db.commit() - await db.close() + # collect + logger.debug("Task-collection status: collecting") + data.status = "collecting" + state.data = data.sanitised_dict() + await db.merge(state) + await db.commit() + tasks = await _insert_tasks(task_list=task_list, db=db) + + # finalise + logger.debug("Task-collection status: finalising") + collection_path = get_collection_path(venv_path) + data.task_list = tasks + with collection_path.open("w") as f: + json.dump(data.sanitised_dict(), f) + + # Update DB + data.status = "OK" + data.log = get_collection_log(venv_path) + state.data = data.sanitised_dict() + db.add(state) + await db.merge(state) + await db.commit() + + # Write last logs to file + logger.debug("Task-collection status: OK") + logger.info("Background task collection completed successfully") + close_logger(logger) + await db.close() + + except Exception as e: + # Write last logs to file + logger.debug("Task-collection status: fail") + logger.info(f"Background collection failed. Original error: {e}") + close_logger(logger) + + # Update db + data.status = "fail" + data.info = f"Original error: {e}" + data.log = get_collection_log(venv_path) + state.data = data.sanitised_dict() + await db.merge(state) + await db.commit() + await db.close() - # Delete corrupted package dir - shell_rmtree(venv_path) + # Delete corrupted package dir + shell_rmtree(venv_path) async def _insert_tasks( diff --git a/fractal_server/app/db/__init__.py b/fractal_server/app/db/__init__.py index 9a3cf37111..69dedc2f14 100644 --- a/fractal_server/app/db/__init__.py +++ b/fractal_server/app/db/__init__.py @@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import Session as DBSyncSession from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool from ...config import get_settings from ...logger import set_logger @@ -52,18 +53,28 @@ def set_db(cls): "the database cannot be guaranteed." ) + # Set some sqlite-specific options + if settings.DB_ENGINE == "sqlite": + engine_kwargs_async = dict(poolclass=StaticPool) + engine_kwargs_sync = dict( + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + else: + engine_kwargs_async = {} + engine_kwargs_sync = {} + cls._engine_async = create_async_engine( - settings.DATABASE_URL, echo=settings.DB_ECHO, future=True + settings.DATABASE_URL, + echo=settings.DB_ECHO, + future=True, + **engine_kwargs_async, ) cls._engine_sync = create_engine( settings.DATABASE_SYNC_URL, echo=settings.DB_ECHO, future=True, - connect_args=( - {"check_same_thread": False} - if settings.DB_ENGINE == "sqlite" - else {} - ), + **engine_kwargs_sync, ) cls._async_session_maker = sessionmaker( diff --git a/tests/fixtures_server.py b/tests/fixtures_server.py index 87d63b3386..f1b5d704ee 100644 --- a/tests/fixtures_server.py +++ b/tests/fixtures_server.py @@ -86,14 +86,12 @@ def get_patched_settings(temp_path: Path): settings.DB_ENGINE = DB_ENGINE if DB_ENGINE == "sqlite": - settings.SQLITE_PATH = ( - f"{temp_path.as_posix()}/_test.db?mode=memory&cache=shared" - ) + settings.SQLITE_PATH = f"{temp_path.as_posix()}/_test.db" elif DB_ENGINE == "postgres": settings.DB_ENGINE = "postgres" settings.POSTGRES_USER = "postgres" settings.POSTGRES_PASSWORD = "postgres" - settings.POSTGRES_DB = "fractal" + settings.POSTGRES_DB = "fractal_test" else: raise ValueError @@ -202,6 +200,7 @@ async def db_create_tables(override_settings): engine = DB.engine_sync() metadata = SQLModel.metadata metadata.create_all(engine) + yield metadata.drop_all(engine) diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py new file mode 100644 index 0000000000..ad00c34601 --- /dev/null +++ b/tests/test_background_tasks.py @@ -0,0 +1,120 @@ +""" +See https://github.com/fractal-analytics-platform/fractal-server/issues/661 +""" +from typing import Any +from typing import AsyncGenerator + +import pytest +from asgi_lifespan import LifespanManager +from devtools import debug +from fastapi import BackgroundTasks +from fastapi import Depends +from fastapi import FastAPI +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from fractal_server.app.api import router_default +from fractal_server.app.db import DBSyncSession +from fractal_server.app.db import get_db +from fractal_server.app.db import get_sync_db +from fractal_server.app.models import State +from fractal_server.logger import set_logger + +logger = set_logger(__name__) + + +# BackgroundTasks functions + + +async def bgtask_sync_db(state_id: int): + """ + This is a function to be executed as a background task, and it uses a + sync db session. + """ + new_db: DBSyncSession = next(get_sync_db()) + state = new_db.get(State, state_id) + state.data = {"c": "d"} + new_db.merge(state) + new_db.commit() + new_db.close() + + +async def bgtask_async_db(state_id: int): + """ + This is a function to be executed as a background task, and it uses an + async db session. + """ + logger.critical("bgtask_async_db START") + async for new_db in get_db(): + state = await new_db.get(State, state_id) + state.data = {"a": "b"} + await new_db.merge(state) + await new_db.commit() + await new_db.close() + logger.critical("bgtask_async_db END") + + +# New endpoints and client + + +@router_default.get("/test_async") +async def run_background_task_async( + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): + """Endpoint that calls bgtask_async_db in background.""" + logger.critical("START run_background_task_async") + state = State() + db.add(state) + await db.commit() + debug(state) + state_id = state.id + await db.close() + logger.critical("END run_background_task_async") + + background_tasks.add_task(bgtask_async_db, state_id) + + +@router_default.get("/test_sync") +async def run_background_task_sync( + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): + """Endpoint that calls bgtask_sync_db in background.""" + + logger.critical("START run_background_task_sync") + state = State() + db.add(state) + await db.commit() + debug(state) + state_id = state.id + await db.close() + logger.critical("END run_background_task_sync") + + background_tasks.add_task(bgtask_sync_db, state_id) + + +@pytest.fixture +async def client_for_bgtasks( + app: FastAPI, + db: AsyncSession, +) -> AsyncGenerator[AsyncClient, Any]: + """Client wich includes the two new endpoints.""" + + app.include_router(router_default, prefix="/test_bgtasks") + async with AsyncClient( + app=app, base_url="http://test" + ) as client, LifespanManager(app): + yield client + + +async def test_async_db(db, client_for_bgtasks): + """Call the run_background_task_async endpoint""" + res = await client_for_bgtasks.get("http://test_bgtasks/test_async") + debug(res) + + +async def test_sync_db(db, db_sync, client_for_bgtasks): + """Call the run_background_task_sync endpoint""" + res = await client_for_bgtasks.get("http://test_bgtasks/test_sync") + debug(res) diff --git a/tests/test_startup_commands.py b/tests/test_startup_commands.py index 93fa32cfeb..70fd976971 100644 --- a/tests/test_startup_commands.py +++ b/tests/test_startup_commands.py @@ -34,7 +34,7 @@ def _prepare_config_and_db(_tmp_path: Path): [ "POSTGRES_USER=postgres", "POSTGRES_PASSWORD=postgres", - "POSTGRES_DB=fractal", + "POSTGRES_DB=fractal_test", ] ) elif DB_ENGINE == "sqlite":