Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce non-duplication constraints on task groups #1865

Merged
56 changes: 56 additions & 0 deletions fractal_server/app/routes/api/v2/_aux_functions_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,59 @@ async def _get_valid_user_group_id(
user_id=user_id, user_group_id=user_group_id, db=db
)
return user_group_id


async def _verify_non_duplication_user_constraint(
db: AsyncSession,
user_id: int,
pkg_name: str,
version: Optional[str],
):
stm = (
select(TaskGroupV2)
.where(TaskGroupV2.user_id == user_id)
.where(TaskGroupV2.pkg_name == pkg_name)
.where(TaskGroupV2.version == version) # FIXME test with None
)
res = await db.execute(stm)
duplicate = res.scalars().all()
if duplicate:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
"There is already a TaskGroupV2 with "
f"({pkg_name=}, {version=}, {user_id=})."
),
)


async def _verify_non_duplication_group_constraint(
db: AsyncSession,
user_group_id: int,
ychiucco marked this conversation as resolved.
Show resolved Hide resolved
pkg_name: str,
version: Optional[str],
):
if user_group_id is None:
ychiucco marked this conversation as resolved.
Show resolved Hide resolved
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
"`_verify_non_duplication_group_constraint` cannot be called "
f"with {user_group_id=}."
),
)
stm = (
select(TaskGroupV2)
.where(TaskGroupV2.user_group_id == user_group_id)
.where(TaskGroupV2.pkg_name == pkg_name)
.where(TaskGroupV2.version == version)
)
res = await db.execute(stm)
duplicate = res.scalars().all()
if duplicate:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
"There is already a TaskGroupV2 with "
f"({pkg_name=}, {version=}, {user_group_id=})."
),
)
17 changes: 15 additions & 2 deletions fractal_server/app/routes/api/v2/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from ._aux_functions_tasks import _get_task_full_access
from ._aux_functions_tasks import _get_task_read_access
from ._aux_functions_tasks import _get_valid_user_group_id
from ._aux_functions_tasks import _verify_non_duplication_group_constraint
from ._aux_functions_tasks import _verify_non_duplication_user_constraint
from fractal_server.app.db import AsyncSession
from fractal_server.app.db import get_async_db
from fractal_server.app.models import LinkUserGroup
Expand Down Expand Up @@ -221,14 +223,25 @@ async def create_task(
)
# Add task
db_task = TaskV2(**task.dict(), owner=owner, type=task_type)

pkg_name = db_task.name
await _verify_non_duplication_user_constraint(
db=db, pkg_name=pkg_name, user_id=user.id, version=db_task.version
)
if user_group_id is not None:
ychiucco marked this conversation as resolved.
Show resolved Hide resolved
await _verify_non_duplication_group_constraint(
db=db,
pkg_name=pkg_name,
user_group_id=user_group_id,
version=db_task.version,
)
db_task_group = TaskGroupV2(
user_id=user.id,
user_group_id=user_group_id,
active=True,
task_list=[db_task],
origin="other",
pkg_name=task.name,
version=db_task.version,
pkg_name=pkg_name,
)
db.add(db_task_group)
await db.commit()
Expand Down
9 changes: 8 additions & 1 deletion fractal_server/app/routes/api/v2/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ._aux_functions_tasks import _get_task_group_full_access
from ._aux_functions_tasks import _get_task_group_read_access
from ._aux_functions_tasks import _verify_non_duplication_group_constraint
from fractal_server.app.db import AsyncSession
from fractal_server.app.db import get_async_db
from fractal_server.app.models import LinkUserGroup
Expand Down Expand Up @@ -124,7 +125,13 @@ async def patch_task_group(
user_id=user.id,
db=db,
)

if task_group_update.user_group_id is not None:
ychiucco marked this conversation as resolved.
Show resolved Hide resolved
await _verify_non_duplication_group_constraint(
db=db,
pkg_name=task_group.pkg_name,
version=task_group.version,
user_group_id=task_group_update.user_group_id,
)
for key, value in task_group_update.dict(exclude_unset=True).items():
if (key == "user_group_id") and (value is not None):
await _verify_user_belongs_to_group(
Expand Down
37 changes: 27 additions & 10 deletions tests/v2/03_api/test_api_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ async def test_post_task(client, MockCurrentUser):

task = TaskCreateV2(
name="task_name",
source=f"{TASK_SOURCE}-parallel",
command_parallel="task_command_parallel",
)
res = await client.post(
f"{PREFIX}/", json=task.dict(exclude_unset=True)
)
# TaskGroupV2 with same (pkg_name, version, user_id)
assert res.status_code == 422

task = TaskCreateV2(
name="task_name2",
# Parallel
source=f"{TASK_SOURCE}-parallel",
command_parallel="task_command_parallel",
Expand All @@ -147,15 +158,14 @@ async def test_post_task(client, MockCurrentUser):
assert res.status_code == 201
assert res.json()["type"] == "parallel"
task = TaskCreateV2(
name="task_name",
name="task_name3",
# Non Parallel
source=f"{TASK_SOURCE}-non_parallel",
command_non_parallel="task_command_non_parallel",
)
res = await client.post(
f"{PREFIX}/", json=task.dict(exclude_unset=True)
)
debug(res.json())
assert res.status_code == 201
assert res.json()["type"] == "non_parallel"

Expand All @@ -181,8 +191,9 @@ async def test_post_task(client, MockCurrentUser):
# Case 1: (username, slurm_user) = (None, None)
user_kwargs = dict(username=None, is_verified=True)
user_settings_dict = dict(slurm_user=None)
payload = dict(name="task", command_parallel="cmd")
payload = dict(command_parallel="cmd")
payload["source"] = "source_x"
payload["name"] = "x"
async with MockCurrentUser(
user_kwargs=user_kwargs, user_settings_dict=user_settings_dict
):
Expand All @@ -205,6 +216,7 @@ async def test_post_task(client, MockCurrentUser):
user_kwargs = dict(username=None, is_verified=True)
user_settings_dict = dict(slurm_user=SLURM_USER)
payload["source"] = "source_z"
payload["name"] = "z"
async with MockCurrentUser(
user_kwargs=user_kwargs, user_settings_dict=user_settings_dict
):
Expand All @@ -214,6 +226,7 @@ async def test_post_task(client, MockCurrentUser):
user_kwargs = dict(username=USERNAME, is_verified=True)
user_settings_dict = dict(slurm_user=None)
payload["source"] = "source_xyz"
payload["name"] = "xyz"
async with MockCurrentUser(
user_kwargs=user_kwargs, user_settings_dict=user_settings_dict
):
Expand Down Expand Up @@ -282,19 +295,21 @@ async def test_post_task_user_group_id(
await db.commit()
await db.refresh(team1_group)

args = dict(name="a", command_non_parallel="cmd")
args = dict(command_non_parallel="cmd")

async with MockCurrentUser(user_kwargs=dict(is_verified=True)):

# No query parameter
res = await client.post(f"{PREFIX}/", json=dict(source="1", **args))
res = await client.post(
f"{PREFIX}/", json=dict(name="a", source="1", **args)
)
assert res.status_code == 201
taskgroup = await db.get(TaskGroupV2, res.json()["taskgroupv2_id"])
assert taskgroup.user_group_id == default_user_group.id

# Private task
res = await client.post(
f"{PREFIX}/?private=true", json=dict(source="2", **args)
f"{PREFIX}/?private=true", json=dict(name="b", source="2", **args)
)
assert res.status_code == 201
taskgroup = await db.get(TaskGroupV2, res.json()["taskgroupv2_id"])
Expand All @@ -303,7 +318,7 @@ async def test_post_task_user_group_id(
# Specific usergroup id / OK
res = await client.post(
f"{PREFIX}/?user_group_id={default_user_group.id}",
json=dict(source="3", **args),
json=dict(name="c", source="3", **args),
)
assert res.status_code == 201
taskgroup = await db.get(TaskGroupV2, res.json()["taskgroupv2_id"])
Expand All @@ -312,15 +327,15 @@ async def test_post_task_user_group_id(
# Specific usergroup id / not belonging
res = await client.post(
f"{PREFIX}/?user_group_id={team1_group.id}",
json=dict(source="4", **args),
json=dict(name="d", source="4", **args),
)
assert res.status_code == 403
debug(res.json())

# Conflicting query parameters
res = await client.post(
f"{PREFIX}/?private=true&user_group_id={default_user_group.id}",
json=dict(source="5", **args),
json=dict(name="e", source="5", **args),
)
assert res.status_code == 422
debug(res.json())
Expand All @@ -333,7 +348,9 @@ async def test_post_task_user_group_id(
),
"MONKEY",
)
res = await client.post(f"{PREFIX}/", json=dict(source="4", **args))
res = await client.post(
f"{PREFIX}/", json=dict(name="f", source="4", **args)
)
assert res.status_code == 404


Expand Down
12 changes: 11 additions & 1 deletion tests/v2/03_api/test_api_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ async def test_delete_task_group_fail(
assert res.status_code == 422


async def test_patch_task_group(client, MockCurrentUser, task_factory_v2):
async def test_patch_task_group(
client, MockCurrentUser, task_factory_v2, default_user_group
):
async with MockCurrentUser() as user1:

task = await task_factory_v2(user_id=user1.id, source="source")
Expand Down Expand Up @@ -157,6 +159,14 @@ async def test_patch_task_group(client, MockCurrentUser, task_factory_v2):
)
assert res.status_code == 404

# Already linked UserGroup

res = await client.patch(
f"{PREFIX}/{task.taskgroupv2_id}/",
json=dict(user_group_id=default_user_group.id),
)
assert res.status_code == 422

async with MockCurrentUser():

# Unauthorized
Expand Down
14 changes: 14 additions & 0 deletions tests/v2/03_api/test_unit_aux_functions_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from fractal_server.app.routes.api.v2._aux_functions_tasks import (
_get_task_read_access,
)
from fractal_server.app.routes.api.v2._aux_functions_tasks import (
_verify_non_duplication_group_constraint,
)
from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME


Expand Down Expand Up @@ -150,3 +153,14 @@ async def test_get_task_require_active(db, task_factory_v2):
await _get_task_read_access(
task_id=task.id, user_id=user.id, db=db, require_active=True
)


async def test_unit_verify_non_duplication_group_constraint(db):
ychiucco marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(HTTPException):
# fail because `user_group_id=None`
await _verify_non_duplication_group_constraint(
db=db,
user_group_id=None,
pkg_name="foo",
version=None,
)
Loading