Skip to content

Commit

Permalink
Merge pull request #1867 from fractal-analytics-platform/1866-link-co…
Browse files Browse the repository at this point in the history
…llectionstatev2-to-taskgroupv2-with-taskgroupv2_id-optional-foreign-key

Add foreign key `CollectionStateV2.taskgroupv2_id`
  • Loading branch information
tcompa authored Oct 9, 2024
2 parents 015ae2c + b2cb85d commit a24a3f4
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ into pre-release sections below.

* API:
* Enforce non-duplication constraints on `TaskGroupV2` (\#1865).
* Add cascade operations to `DELETE /api/v2/task-group/{task_group_id}/` and to `DELETE /admin/v2/task-group/{task_group_id}/` (\#1867).
* Database:
* Add `taskgroupv2_id` foreign key to `CollectionStateV2` (\#1867).

# 2.7.0a3

Expand Down
1 change: 1 addition & 0 deletions fractal_server/app/models/v2/collection_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class CollectionStateV2(SQLModel, table=True):

id: Optional[int] = Field(default=None, primary_key=True)
taskgroupv2_id: Optional[int] = Field(foreign_key="taskgroupv2.id")
data: dict[str, Any] = Field(sa_column=Column(JSON), default={})
timestamp: datetime = Field(
default_factory=get_timestamp,
Expand Down
21 changes: 21 additions & 0 deletions fractal_server/app/routes/admin/v2/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fractal_server.app.db import AsyncSession
from fractal_server.app.db import get_async_db
from fractal_server.app.models import UserOAuth
from fractal_server.app.models.v2 import CollectionStateV2
from fractal_server.app.models.v2 import TaskGroupV2
from fractal_server.app.models.v2 import WorkflowTaskV2
from fractal_server.app.routes.auth import current_active_superuser
Expand All @@ -20,9 +21,12 @@
)
from fractal_server.app.schemas.v2 import TaskGroupReadV2
from fractal_server.app.schemas.v2 import TaskGroupUpdateV2
from fractal_server.logger import set_logger

router = APIRouter()

logger = set_logger(__name__)


@router.get("/{task_group_id}/", response_model=TaskGroupReadV2)
async def query_task_group(
Expand Down Expand Up @@ -128,6 +132,23 @@ async def delete_task_group(
detail=f"TaskV2 {workflow_tasks[0].task_id} is still in use",
)

# Cascade operations: set foreign-keys to null for CollectionStateV2 which
# are in relationship with the current TaskGroupV2
logger.debug("Start of cascade operations on CollectionStateV2.")
stm = select(CollectionStateV2).where(
CollectionStateV2.taskgroupv2_id == task_group_id
)
res = await db.execute(stm)
collection_states = res.scalars().all()
for collection_state in collection_states:
logger.debug(
f"Setting CollectionStateV2[{collection_state.id}].taskgroupv2_id "
"to None."
)
collection_state.taskgroupv2_id = None
db.add(collection_state)
logger.debug("End of cascade operations on CollectionStateV2.")

await db.delete(task_group)
await db.commit()

Expand Down
18 changes: 18 additions & 0 deletions fractal_server/app/routes/api/v2/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from fractal_server.app.db import get_async_db
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.models.v2 import CollectionStateV2
from fractal_server.app.models.v2 import TaskGroupV2
from fractal_server.app.models.v2 import WorkflowTaskV2
from fractal_server.app.routes.auth import current_active_user
Expand Down Expand Up @@ -104,6 +105,23 @@ async def delete_task_group(
detail=f"TaskV2 {workflow_tasks[0].task_id} is still in use",
)

# Cascade operations: set foreign-keys to null for CollectionStateV2 which
# are in relationship with the current TaskGroupV2
logger.debug("Start of cascade operations on CollectionStateV2.")
stm = select(CollectionStateV2).where(
CollectionStateV2.taskgroupv2_id == task_group_id
)
res = await db.execute(stm)
collection_states = res.scalars().all()
for collection_state in collection_states:
logger.debug(
f"Setting CollectionStateV2[{collection_state.id}].taskgroupv2_id "
"to None."
)
collection_state.taskgroupv2_id = None
db.add(collection_state)
logger.debug("End of cascade operations on CollectionStateV2.")

await db.delete(task_group)
await db.commit()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""CollectionStateV2.taskgroupv2_id
Revision ID: d82ee0dc1e48
Revises: 742b74e1cc6e
Create Date: 2024-10-09 14:13:59.288582
"""
import sqlalchemy as sa
from alembic import op


# revision identifiers, used by Alembic.
revision = "d82ee0dc1e48"
down_revision = "742b74e1cc6e"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("collectionstatev2", schema=None) as batch_op:
batch_op.add_column(
sa.Column("taskgroupv2_id", sa.Integer(), nullable=True)
)
batch_op.create_foreign_key(
batch_op.f("fk_collectionstatev2_taskgroupv2_id_taskgroupv2"),
"taskgroupv2",
["taskgroupv2_id"],
["id"],
)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("collectionstatev2", schema=None) as batch_op:
batch_op.drop_constraint(
batch_op.f("fk_collectionstatev2_taskgroupv2_id_taskgroupv2"),
type_="foreignkey",
)
batch_op.drop_column("taskgroupv2_id")

# ### end Alembic commands ###
36 changes: 36 additions & 0 deletions tests/v2/02_models/test_tasks_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.models.v2 import CollectionStateV2
from fractal_server.app.models.v2 import TaskGroupV2
from fractal_server.app.models.v2 import TaskV2
from fractal_server.config import get_settings
Expand Down Expand Up @@ -113,3 +114,38 @@ async def test_task_group_v2(db):
assert task_group is None
assert task1 is None
assert task3 is None


async def test_collection_state(db):

user = UserOAuth(email="[email protected]", hashed_password="1234")
db.add(user)
await db.commit()
await db.refresh(user)

task_group = TaskGroupV2(
user_id=user.id,
origin="wheel-file",
pkg_name="package-name",
)
db.add(task_group)
await db.commit()
await db.refresh(task_group)

state = CollectionStateV2(taskgroupv2_id=task_group.id)
db.add(state)
await db.commit()

assert state.taskgroupv2_id == task_group.id

await db.delete(task_group)

settings = Inject(get_settings)
if settings.DB_ENGINE == "sqlite":
await db.commit()
await db.refresh(state)
assert state.taskgroupv2_id is not None
else:
with pytest.raises(IntegrityError):
await db.commit()
await db.rollback()
10 changes: 10 additions & 0 deletions tests/v2/03_api/test_api_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from fractal_server.app.models import TaskGroupV2
from fractal_server.app.models import UserGroup
from fractal_server.app.models.v2 import CollectionStateV2
from fractal_server.app.routes.api.v2._aux_functions import (
_workflow_insert_task,
)
Expand Down Expand Up @@ -812,6 +813,12 @@ async def test_task_group_admin(
task = await task_factory_v2(user_id=user.id, source="source")
await workflowtask_factory_v2(workflow_id=workflow.id, task_id=task.id)

state = CollectionStateV2(taskgroupv2_id=task_group_1["id"])
db.add(state)
await db.commit()
await db.refresh(state)
assert state.taskgroupv2_id == task_group_1["id"]

async with MockCurrentUser(user_kwargs={"is_superuser": True}):
res = await client.delete(f"{PREFIX}/task-group/{task_group_1['id']}/")
assert res.status_code == 204
Expand All @@ -825,3 +832,6 @@ async def test_task_group_admin(
f"{PREFIX}/task-group/{task.taskgroupv2_id}/"
)
assert res.status_code == 422

await db.refresh(state)
assert state.taskgroupv2_id is None
12 changes: 11 additions & 1 deletion tests/v2/03_api/test_api_task_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserGroup
from fractal_server.app.models.v2 import CollectionStateV2

PREFIX = "/api/v2/task-group"

Expand Down Expand Up @@ -93,10 +94,16 @@ async def test_get_task_group_list(
assert len(res.json()) == 2


async def test_delete_task_group(client, MockCurrentUser, task_factory_v2):
async def test_delete_task_group(client, MockCurrentUser, task_factory_v2, db):
async with MockCurrentUser() as user1:
task = await task_factory_v2(user_id=user1.id, source="source")

state = CollectionStateV2(taskgroupv2_id=task.taskgroupv2_id)
db.add(state)
await db.commit()
await db.refresh(state)
assert state.taskgroupv2_id == task.taskgroupv2_id

async with MockCurrentUser():
res = await client.delete(f"{PREFIX}/{task.taskgroupv2_id}/")
assert res.status_code == 403
Expand All @@ -107,6 +114,9 @@ async def test_delete_task_group(client, MockCurrentUser, task_factory_v2):
res = await client.delete(f"{PREFIX}/{task.taskgroupv2_id}/")
assert res.status_code == 404

await db.refresh(state)
assert state.taskgroupv2_id is None


async def test_delete_task_group_fail(
project_factory_v2,
Expand Down

0 comments on commit a24a3f4

Please sign in to comment.