Skip to content

Commit

Permalink
Chores: add some error handling for operators, fix unit tests for new…
Browse files Browse the repository at this point in the history
… sqlite id generation
  • Loading branch information
Etienne Jodry authored and Etienne Jodry committed Jan 17, 2025
1 parent 22956e5 commit 1806cda
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 27 deletions.
2 changes: 0 additions & 2 deletions src/biodm/components/controllers/resourcecontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def __init__(
# Inst schema, and set registry entry for apispec.
schema_cls = schema if schema else self._infer_schema()
self.__class__.schema = schema_cls(unknown=RAISE)
# TODO [prio-low]: Improve dynamic schema instanciation, at serializer generation time
# To handle nested cases. Implies storing dynamically generated schemas in a registry on the side.
register_runtime_schema(schema_cls, self.__class__.schema)
self._infuse_schema_in_apispec_docstrings()

Expand Down
42 changes: 31 additions & 11 deletions src/biodm/components/services/dbservice.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Database service: Translates requests data into SQLA statements and execute."""
from abc import ABCMeta
from typing import Callable, List, Sequence, Any, Dict, overload, Literal, Type, Set
from uuid import uuid4

from marshmallow.orderedset import OrderedSet
from sqlalchemy import select, delete, or_, func
from sqlalchemy import Column, select, delete, or_, func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.hybrid import hybrid_property
Expand All @@ -28,7 +29,8 @@
from biodm.utils.utils import unevalled_all, unevalled_or, to_it, partition


SUPPORTED_NUM_OPERATORS = ("gt", "ge", "lt", "le", "min", "max", "min_v", "max_v")
NUM_OPERATORS = ("gt", "ge", "lt", "le", "min", "max", "min_v", "max_v")
AGG_OPERATORS = ("min_v", "max_v")


class DatabaseService(ApiService, metaclass=ABCMeta):
Expand Down Expand Up @@ -677,19 +679,37 @@ def _filter_parse_num_op(self, stmt: Select, field: str, operator: str) -> Selec
:return: Select statement with operator condition applied.
:rtype: Select
"""
col, ctype = self.table.colinfo(field)
try:
col, ctype = self.table.colinfo(field)
except KeyError:
raise EndpointError(
f"Unknown field {field} of table {self.table.__name__}"
)

def op_err(op: str, name: str, ftype: type):
raise EndpointError(
f"Invalid operator {op} on column {name} of type {ftype} "
f"in table {self.table.__name__}"
)

match operator.strip(')').split('('):
case [("gt" | "ge" | "lt" | "le") as op, arg]:
if ctype not in (int, float):
op_err(op, col.name, ctype)

op_fct: Callable = getattr(col, f"__{op}__")
return stmt.where(op_fct(ctype(arg)))

case [("min" | "max") as op, arg]:
if ctype == str:
op_err(op, col.name, ctype)

op_fct: Callable = getattr(func, op)
sub = select(op_fct(col))
return stmt.where(col == sub.scalar_subquery())

case [("min_v" | "max_v") as op, arg]:
if col.name is not 'version':
if col.name != 'version':
raise EndpointError("min_v and max_v are 'version' column exclusive filters.")

# TODO [prio-low]: This could be probably be improved to apply min, max on group-by
Expand All @@ -699,23 +719,24 @@ def _filter_parse_num_op(self, stmt: Select, field: str, operator: str) -> Selec

assert col in self.pk # invariant

label = "agg_col" + str(uuid4())[:4] # gen unique label
pk_no_col = [k for k in self.table.pk if k != col.name]
sub = select(
* [getattr(self.table, k) for k in pk_no_col]
+ [op_fct(col).label("max_col")]
+ [op_fct(col).label(label)]
)
sub = sub.group_by(*pk_no_col)
sub = sub.subquery()

return stmt.join(sub, onclause=unevalled_all([
getattr(self.table, k) == getattr(sub.c, k)
for k in pk_no_col
] + [getattr(self.table, col.name) == getattr(sub.c, "max_col")]))
] + [getattr(self.table, col.name) == getattr(sub.c, label)]))

case _:
raise EndpointError(
f"Expecting either 'field=v1,v2' pairs or integrer"
f" operators 'field.op([v])' op in {SUPPORTED_NUM_OPERATORS}")
f"Expecting either 'field=v1,v2' pairs or "
f" operators 'field.op([v])' op in {NUM_OPERATORS + AGG_OPERATORS}")

def _filter_parse_field_cond(self, stmt: Select, field: str, values: List[str]) -> Select:
"""Applies field condition on a select statement.
Expand Down Expand Up @@ -854,16 +875,15 @@ async def release(
)
queried_version: int

# Slightly tweaked read version where we get max version column instead.
# Slightly tweaked read version where we get max agg version column instead.
stmt = select(self.table)
for i, col in enumerate(self.pk):
if col.name != 'version':
stmt = stmt.where(col == col.type.python_type(pk_val[i]))
else:
queried_version = pk_val[i]

sub = select(func.max(stmt.c.version)).scalar_subquery()
stmt = stmt.where(self.table.col('version') == sub)
stmt = self._filter_parse_num_op(stmt, field="version", operator="max_v()")

# Get item with all columns - covers many-to-one relationships.
fields = set(self.table.__table__.columns.keys())
Expand Down
2 changes: 1 addition & 1 deletion src/biodm/components/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def has_composite_pk_with_leading_id_sqlite(cls) -> bool:
return (
'sqlite' in str(config.DATABASE_URL) and
hasattr(cls, 'id') and
cls.pk.__len__ > 1
cls.pk.__len__() > 1
)

@staticmethod
Expand Down
17 changes: 11 additions & 6 deletions src/biodm/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, List, Tuple, Set, ClassVar, Type, Any, Dict

from marshmallow import fields, Schema
from starlette.authentication import BaseUser
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Receive, Scope, Send
from sqlalchemy import ForeignKeyConstraint, Column, ForeignKey
Expand All @@ -25,10 +26,7 @@
from biodm.managers import KeycloakManager


# TODO: [prio: not urgent]
# possible improvement, would be to rewrite the following classes using
# starlette builtins from starlette.middleware.authentication.
class UserInfo(aobject):
class UserInfo(aobject, BaseUser):
"""Hold user info for a given request.
If the request contains an authentication header, self.info shall return User Info, else None
Expand Down Expand Up @@ -101,6 +99,9 @@ def keycloak_admin(self):
return self._keycloak_admin


# TODO: [prio: not urgent]
# possible improvement, would be to rewrite the following middleware using
# starlette builtins from starlette.middleware.authentication.
class AuthenticationMiddleware:
"""Handle token decoding for incoming requests, populate request object with result."""
def __init__(self, app: ASGIApp) -> None:
Expand Down Expand Up @@ -142,7 +143,11 @@ async def lr_wrapper(controller, request, *args, **kwargs):


def group_required(groups: List[str]):
"""Decorator for endpoints requiring authenticated user to be part of one of the list of paths."""
"""Decorator for endpoints requiring authenticated user to be part of one of the list of paths.
"""
if not groups:
raise ImplementionError("@group_required applied with empty group list.")

def _group_required(f):
if f.__name__ == "create":
@wraps(f)
Expand All @@ -154,7 +159,7 @@ async def gr_write_wrapper(controller, request, *args, **kwargs):

@wraps(f)
async def gr_wrapper(controller, request, *args, **kwargs):
if request.user.is_authenticated: # TODO: check empty group list edge case
if request.user.is_authenticated and request.user.groups:
if any((ug in groups for ug in request.user.groups)):
return await f(controller, request, *args, **kwargs)

Expand Down
6 changes: 5 additions & 1 deletion src/tests/unit/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def test_create_composite_resource(client):
# May be in different orders.
oracle['bs'].sort(key=lambda x: x['id'])
json_response['bs'].sort(key=lambda x: x['id'])
assert oracle['bs'] == json_response['bs']

for bor, bres in zip(oracle['bs'],json_response['bs']):
bor['id'] == bres['id']
bor['version'] == bres['version']
bor['name'] == bres['name']


@pytest.mark.xfail(raises=exc.PayloadEmptyError)
Expand Down
13 changes: 7 additions & 6 deletions src/tests/unit/test_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_create_versioned_resource(client):

json_response = json.loads(response.text)

assert json_response['id'] == 1
assert json_response['version'] == 1
assert json_response['name'] == item['name']


def test_release_version(client):
Expand All @@ -24,22 +24,22 @@ def test_release_version(client):
response = client.post('/bs', content=json_bytes(item))

assert response.status_code == 201
item_res = json.loads(response.text)

update = {"name": "ver_updated"}
response = client.post('/bs/1_1/release', content=json_bytes(update))
response = client.post(f"/bs/{item_res['id']}_1/release", content=json_bytes(update))

assert response.status_code == 200
json_response = json.loads(response.text)

assert json_response['id'] == 1
assert json_response['version'] == 2
assert json_response['name'] == update['name']

response = client.get('/bs?id=1')
response = client.get(f"/bs?id={item_res['id']}")

assert response.status_code == 200
json_response = json.loads(response.text)

assert len(json_response) == 2
assert json_response[0]['version'] == 1
assert json_response[0]['name'] == item['name']
Expand All @@ -53,8 +53,9 @@ def test_no_update_version_resource_through_write(client):

response = client.post('/bs', content=json_bytes(item))
assert response.status_code == 201
item_res = json.loads(response.text)

update = {'id': '1', 'version': '1', 'name': '4321'}
update = {'id': item_res['id'], 'version': item_res['version'], 'name': '4321'}
response = client.post('/bs', content=json_bytes(update))
assert response.status_code == 409

Expand Down

0 comments on commit 1806cda

Please sign in to comment.