Skip to content

Commit

Permalink
Improve querystring parsing, with casting of arguments. add responses…
Browse files Browse the repository at this point in the history
… error and handling.
  • Loading branch information
Etienne Jodry authored and Etienne Jodry committed Jan 29, 2025
1 parent 87017d5 commit 8df1782
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 108 deletions.
120 changes: 110 additions & 10 deletions src/biodm/components/controllers/resourcecontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from types import MethodType
from typing import TYPE_CHECKING, Callable, List, Set, Any, Dict, Type, Self

from marshmallow import ValidationError
from marshmallow.fields import Field, List, Nested, Date, DateTime, Number
from marshmallow.schema import RAISE
from marshmallow.class_registry import get_class
from marshmallow.exceptions import RegistryError
Expand All @@ -22,9 +24,11 @@
KCGroupService,
KCUserService
)
from biodm.components.services.dbservice import Operator, ValuedOperator, NUM_OPERATORS, AGG_OPERATORS
from biodm.exceptions import (
DataError,
EndpointError,
QueryError,
ImplementionError,
InvalidCollectionMethod,
PayloadEmptyError,
Expand All @@ -42,6 +46,9 @@
from marshmallow.schema import Schema


SPECIAL_QUERYPARAMETERS = {'fields', 'count', 'start', 'end', 'reverse'}


def overload_docstring(f: Callable): # flake8: noqa: E501 pylint: disable=line-too-long
"""Decorator to allow for docstring overloading.
Expand Down Expand Up @@ -289,10 +296,13 @@ def _extract_fields(
"""Extracts fields from request query parameters.
Defaults to ``self.schema.dump_fields.keys()``.
:param request: incomming request
:type request: Request
:return: field list
:rtype: List[str]
:param query_params: query params
:type query_params: Dict[str, Any]
:param user_info: user info
:type user_info: UserInfo
:raises DataError: incorrect field name
:return: requested fields
:rtype: Set[str]
"""
fields = query_params.pop('fields', None)
fields = fields.split(',') if fields else None
Expand All @@ -311,6 +321,100 @@ def _extract_fields(
fields = self.svc.takeout_unallowed_nested(fields, user_info=user_info)
return fields

def _extract_query_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Extracts query parameters, casting them to the proper type.
Uses a custom deserialization to treat comma separated values.
:param query_params: query params
:type query_params: Dict[str, Any]
:raises DataError: incorrect parameter name or value
:return: requested fields
:rtype: Set[str]
"""
def deserialize_with_error(field: Field, value: str):
try:
return field.deserialize(value)
except ValidationError as ve:
raise QueryError(str(ve.messages))

def check_is_numeric(field: Field, op: str, dskey: str):
if not isinstance(field, (Number, Date, DateTime)):
raise QueryError(
f"Operator {op} in {dskey}, should be applied on a"
"numerical or date field."
)

# Reincorporate extra query
extra_query = params.pop('q', None)
if extra_query:
params.update(QueryParams(extra_query))

# Check parameter validity.
for dskey, csval in params.items():
key = dskey.split('.')
# Handle specials -> no dot.
if key[0] in SPECIAL_QUERYPARAMETERS:
if len(key) > 1:
QueryError(f"Invalid query parameter: {dskey}")
continue

# Fetch first field.
if key[0] not in self.schema.fields.keys():
raise QueryError(f"Invalid query parameter: {dskey}")
field = self.schema.fields[key[0]]

# Fetch rest of the chain.
for i, k in enumerate(key[1:]):
# Handle nested.
schema: Schema
match field:
case List():
schema = field.inner.schema
case Nested():
schema = field.schema
case Field(): # Chain should end when hitting a field.
if not i == len(key[1:])-1:
raise QueryError(f"Invalid query parameter: {dskey}")
break

if k not in schema.fields.keys():
raise QueryError(f"Invalid query parameter: {dskey}")

# Maintain loop invariant.
field = schema.fields[k]

if not csval: # Check operators.
match k.strip(')').split('('): # On last visited value.
case [("gt" | "ge" | "lt" | "le") as op, arg]:
check_is_numeric(field, op, dskey)
params[dskey] = ValuedOperator(
op=op, value=deserialize_with_error(field, arg)
)

case [("min" | "max" | "min_a" | "max_a" | "min_v" | "max_v") as op, arg]:
check_is_numeric(field, op, dskey)
if arg:
raise QueryError("[min|max][|_a|_v] Operators do not take a value")
params[dskey] = Operator(op=op)

case _:
raise QueryError(
f"Invalid operator {k} on {key[0]} in table {self.table.__name__}"
)
continue

values = csval.split(',')

# Deserialize value(s)
if len(values) == 1:
params[dskey] = deserialize_with_error(field, values[0])
else:
params[dskey] = [
deserialize_with_error(field, value)
for value in values
]
return params

async def create(self, request: Request) -> Response:
"""CREATE.
Expand Down Expand Up @@ -597,13 +701,9 @@ async def filter(self, request: Request) -> Response:
description: Wrong use of filters.
"""
params = dict(request.query_params)

extra_query = params.pop('q', None)
if extra_query:
params.update(QueryParams(extra_query))
count = bool(params.pop('count', 0))

fields = self._extract_fields(params, user_info=request.user)
params = self._extract_query_params(params)
count = bool(params.pop('count', 0))
result = await self.svc.filter(
fields=fields,
params=params,
Expand Down
31 changes: 24 additions & 7 deletions src/biodm/components/controllers/s3controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Type

from marshmallow import Schema, RAISE
from marshmallow import Schema, RAISE, ValidationError
import starlette.routing as sr
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse
Expand All @@ -9,7 +9,7 @@
from biodm.components.services import S3Service
from biodm.components.table import Base
from biodm.schemas import PartsEtagSchema
from biodm.exceptions import ImplementionError
from biodm.exceptions import DataError, ImplementionError
from biodm.utils.security import UserInfo
from biodm.utils.utils import json_response
from biodm.routing import Route
Expand Down Expand Up @@ -68,6 +68,12 @@ async def download(self, request: Request) -> Response:
application/json:
schema:
type: string
404:
description: Not found.
409:
description: Download a file which has not been uploaded.
500:
description: S3 Bucket issue.
"""
return PlainTextResponse(
await self.svc.download(
Expand Down Expand Up @@ -103,9 +109,20 @@ async def complete_multipart(self, request: Request):
responses:
201:
description: Completion confirmation 'Completed.'
400:
description: Wrongly formatted completion notice.
4O4:
description: Not found.
500:
description: S3 Bucket issue.
"""
await self.svc.complete_multipart(
pk_val=self._extract_pk_val(request),
parts=self.parts_etag_schema.loads(await request.body())
)
return json_response("Completed.", status_code=201)
flag = True
try:
parts = self.parts_etag_schema.loads(await request.body())
await self.svc.complete_multipart(
pk_val=self._extract_pk_val(request),
parts=parts,
)
return json_response("Completed.", status_code=201)
except ValidationError as ve:
raise DataError(str(ve.messages))
Loading

0 comments on commit 8df1782

Please sign in to comment.