Skip to content

Commit

Permalink
Merge pull request #664 from fractal-analytics-platform/debug_background_task
Browse files Browse the repository at this point in the history

Use StaticPool, for sqlite engine (ref #661)
  • Loading branch information
tcompa authored May 8, 2023
2 parents 7e30d7e + c944fd1 commit 011af83
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 80 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
image: postgres
env:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: fractal
POSTGRES_DB: fractal_test
options: >-
--health-cmd pg_isready
--health-interval 10s
Expand Down Expand Up @@ -62,7 +62,7 @@ jobs:
run: poetry install --with dev --no-interaction --all-extras

- name: Test with pytest
run: poetry run coverage run -m pytest
run: poetry run coverage run -m pytest tests

- name: Upload coverage data
uses: actions/upload-artifact@v3
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# 1.2.5

* Fix bug in task collection when using sqlite (\#664).
* Improve error handling in workflow-apply endpoint (\#665).
* Fix a bug upon project removal in the presence of project-related jobs (\#666). Note: this removes the `ApplyWorkflow.Project` attribute.

Expand Down
133 changes: 66 additions & 67 deletions fractal_server/app/api/v1/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,81 +60,80 @@ async def _background_collect_pip(
directory.
"""

# Note: anext(get_db()) is only available for python>=3.10
db = await get_db().__anext__()
async for db in get_db():

state: State = await db.get(State, state_id)
state: State = await db.get(State, state_id)

logger_name = task_pkg.package.replace("/", "_")
logger = set_logger(
logger_name=logger_name,
log_file_path=get_log_path(venv_path),
)

logger.debug("Start background task collection")
data = TaskCollectStatus(**state.data)
data.info = None

try:
# install
logger.debug("Task-collection status: installing")
data.status = "installing"

state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
task_list = await create_package_environment_pip(
venv_path=venv_path,
task_pkg=task_pkg,
logger_name = task_pkg.package.replace("/", "_")
logger = set_logger(
logger_name=logger_name,
log_file_path=get_log_path(venv_path),
)

# collect
logger.debug("Task-collection status: collecting")
data.status = "collecting"
state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
tasks = await _insert_tasks(task_list=task_list, db=db)

# finalise
logger.debug("Task-collection status: finalising")
collection_path = get_collection_path(venv_path)
data.task_list = tasks
with collection_path.open("w") as f:
json.dump(data.sanitised_dict(), f)

# Update DB
data.status = "OK"
data.log = get_collection_log(venv_path)
state.data = data.sanitised_dict()
db.add(state)
await db.merge(state)
await db.commit()

# Write last logs to file
logger.debug("Task-collection status: OK")
logger.info("Background task collection completed successfully")
close_logger(logger)
await db.close()
logger.debug("Start background task collection")
data = TaskCollectStatus(**state.data)
data.info = None

except Exception as e:
# Write last logs to file
logger.debug("Task-collection status: fail")
logger.info(f"Background collection failed. Original error: {e}")
close_logger(logger)
try:
# install
logger.debug("Task-collection status: installing")
data.status = "installing"

state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
task_list = await create_package_environment_pip(
venv_path=venv_path,
task_pkg=task_pkg,
logger_name=logger_name,
)

# Update db
data.status = "fail"
data.info = f"Original error: {e}"
data.log = get_collection_log(venv_path)
state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
await db.close()
# collect
logger.debug("Task-collection status: collecting")
data.status = "collecting"
state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
tasks = await _insert_tasks(task_list=task_list, db=db)

# finalise
logger.debug("Task-collection status: finalising")
collection_path = get_collection_path(venv_path)
data.task_list = tasks
with collection_path.open("w") as f:
json.dump(data.sanitised_dict(), f)

# Update DB
data.status = "OK"
data.log = get_collection_log(venv_path)
state.data = data.sanitised_dict()
db.add(state)
await db.merge(state)
await db.commit()

# Write last logs to file
logger.debug("Task-collection status: OK")
logger.info("Background task collection completed successfully")
close_logger(logger)
await db.close()

except Exception as e:
# Write last logs to file
logger.debug("Task-collection status: fail")
logger.info(f"Background collection failed. Original error: {e}")
close_logger(logger)

# Update db
data.status = "fail"
data.info = f"Original error: {e}"
data.log = get_collection_log(venv_path)
state.data = data.sanitised_dict()
await db.merge(state)
await db.commit()
await db.close()

# Delete corrupted package dir
shell_rmtree(venv_path)
# Delete corrupted package dir
shell_rmtree(venv_path)


async def _insert_tasks(
Expand Down
23 changes: 17 additions & 6 deletions fractal_server/app/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session as DBSyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ...config import get_settings
from ...logger import set_logger
Expand Down Expand Up @@ -52,18 +53,28 @@ def set_db(cls):
"the database cannot be guaranteed."
)

# Set some sqlite-specific options
if settings.DB_ENGINE == "sqlite":
engine_kwargs_async = dict(poolclass=StaticPool)
engine_kwargs_sync = dict(
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
else:
engine_kwargs_async = {}
engine_kwargs_sync = {}

cls._engine_async = create_async_engine(
settings.DATABASE_URL, echo=settings.DB_ECHO, future=True
settings.DATABASE_URL,
echo=settings.DB_ECHO,
future=True,
**engine_kwargs_async,
)
cls._engine_sync = create_engine(
settings.DATABASE_SYNC_URL,
echo=settings.DB_ECHO,
future=True,
connect_args=(
{"check_same_thread": False}
if settings.DB_ENGINE == "sqlite"
else {}
),
**engine_kwargs_sync,
)

cls._async_session_maker = sessionmaker(
Expand Down
7 changes: 3 additions & 4 deletions tests/fixtures_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ def get_patched_settings(temp_path: Path):

settings.DB_ENGINE = DB_ENGINE
if DB_ENGINE == "sqlite":
settings.SQLITE_PATH = (
f"{temp_path.as_posix()}/_test.db?mode=memory&cache=shared"
)
settings.SQLITE_PATH = f"{temp_path.as_posix()}/_test.db"
elif DB_ENGINE == "postgres":
settings.DB_ENGINE = "postgres"
settings.POSTGRES_USER = "postgres"
settings.POSTGRES_PASSWORD = "postgres"
settings.POSTGRES_DB = "fractal"
settings.POSTGRES_DB = "fractal_test"
else:
raise ValueError

Expand Down Expand Up @@ -202,6 +200,7 @@ async def db_create_tables(override_settings):
engine = DB.engine_sync()
metadata = SQLModel.metadata
metadata.create_all(engine)

yield

metadata.drop_all(engine)
Expand Down
120 changes: 120 additions & 0 deletions tests/test_background_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
See https://github.com/fractal-analytics-platform/fractal-server/issues/661
"""
from typing import Any
from typing import AsyncGenerator

import pytest
from asgi_lifespan import LifespanManager
from devtools import debug
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

from fractal_server.app.api import router_default
from fractal_server.app.db import DBSyncSession
from fractal_server.app.db import get_db
from fractal_server.app.db import get_sync_db
from fractal_server.app.models import State
from fractal_server.logger import set_logger

logger = set_logger(__name__)


# BackgroundTasks functions


async def bgtask_sync_db(state_id: int):
    """
    Background-task function that updates a `State` row through a *sync*
    database session.
    """
    # Open a fresh sync session (independent of the request's session).
    session: DBSyncSession = next(get_sync_db())
    db_state = session.get(State, state_id)
    db_state.data = {"c": "d"}
    session.merge(db_state)
    session.commit()
    session.close()


async def bgtask_async_db(state_id: int):
    """
    Background-task function that updates a `State` row through an *async*
    database session.
    """
    logger.critical("bgtask_async_db START")
    # get_db is an async generator; iterate it to obtain a fresh session.
    async for session in get_db():
        db_state = await session.get(State, state_id)
        db_state.data = {"a": "b"}
        await session.merge(db_state)
        await session.commit()
        await session.close()
        logger.critical("bgtask_async_db END")


# New endpoints and client


@router_default.get("/test_async")
async def run_background_task_async(
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Endpoint that calls bgtask_async_db in background."""
    logger.critical("START run_background_task_async")
    # Create a State row so that the background task has something to update.
    new_state = State()
    db.add(new_state)
    await db.commit()
    debug(new_state)
    # Capture the id before closing the session.
    state_id = new_state.id
    await db.close()
    logger.critical("END run_background_task_async")

    # Scheduled task runs after the response has been sent.
    background_tasks.add_task(bgtask_async_db, state_id)


@router_default.get("/test_sync")
async def run_background_task_sync(
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Endpoint that calls bgtask_sync_db in background."""

    logger.critical("START run_background_task_sync")
    # Create a State row so that the background task has something to update.
    new_state = State()
    db.add(new_state)
    await db.commit()
    debug(new_state)
    # Capture the id before closing the session.
    state_id = new_state.id
    await db.close()
    logger.critical("END run_background_task_sync")

    # Scheduled task runs after the response has been sent.
    background_tasks.add_task(bgtask_sync_db, state_id)


@pytest.fixture
async def client_for_bgtasks(
    app: FastAPI,
    db: AsyncSession,
) -> AsyncGenerator[AsyncClient, Any]:
    """Client which includes the two new endpoints."""

    # Mount the default router (with the two endpoints above) under a
    # dedicated prefix.
    app.include_router(router_default, prefix="/test_bgtasks")
    async with AsyncClient(app=app, base_url="http://test") as client:
        async with LifespanManager(app):
            yield client


async def test_async_db(db, client_for_bgtasks):
    """Call the run_background_task_async endpoint."""
    # NOTE(review): the original URL "http://test_bgtasks/test_async" put the
    # router prefix into the URL *host*, so the request path was "/test_async"
    # and missed the "/test_bgtasks" prefix. Use a path relative to the
    # client's base_url instead, and actually assert on the response.
    res = await client_for_bgtasks.get("/test_bgtasks/test_async")
    debug(res)
    assert res.status_code == 200


async def test_sync_db(db, db_sync, client_for_bgtasks):
    """Call the run_background_task_sync endpoint."""
    # NOTE(review): the original URL "http://test_bgtasks/test_sync" put the
    # router prefix into the URL *host*, so the request path was "/test_sync"
    # and missed the "/test_bgtasks" prefix. Use a path relative to the
    # client's base_url instead, and actually assert on the response.
    res = await client_for_bgtasks.get("/test_bgtasks/test_sync")
    debug(res)
    assert res.status_code == 200
2 changes: 1 addition & 1 deletion tests/test_startup_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _prepare_config_and_db(_tmp_path: Path):
[
"POSTGRES_USER=postgres",
"POSTGRES_PASSWORD=postgres",
"POSTGRES_DB=fractal",
"POSTGRES_DB=fractal_test",
]
)
elif DB_ENGINE == "sqlite":
Expand Down

0 comments on commit 011af83

Please sign in to comment.