From 8df1782b48f85d4093fed33fd0c83b2d948af115 Mon Sep 17 00:00:00 2001
From: Etienne Jodry
Date: Wed, 29 Jan 2025 16:15:02 +0100
Subject: [PATCH] Improve querystring parsing, with casting of arguments. Add
 error responses and handling.

---
 .../controllers/resourcecontroller.py        | 120 +++++++++++++--
 .../components/controllers/s3controller.py   |  31 +++-
 src/biodm/components/services/dbservice.py   | 141 ++++++++----------
 src/biodm/error.py                           |  10 +-
 src/biodm/exceptions.py                      |   6 +
 src/biodm/managers/s3manager.py              |  13 +-
 src/biodm/routing.py                         |   6 +-
 src/tests/unit/test_resource.py              |   2 +-
 8 files changed, 221 insertions(+), 108 deletions(-)

diff --git a/src/biodm/components/controllers/resourcecontroller.py b/src/biodm/components/controllers/resourcecontroller.py
index 5024c19..08c6148 100644
--- a/src/biodm/components/controllers/resourcecontroller.py
+++ b/src/biodm/components/controllers/resourcecontroller.py
@@ -7,6 +7,8 @@
 from types import MethodType
 from typing import TYPE_CHECKING, Callable, List, Set, Any, Dict, Type, Self
 
+from marshmallow import ValidationError
+from marshmallow.fields import Field, List, Nested, Date, DateTime, Number
 from marshmallow.schema import RAISE
 from marshmallow.class_registry import get_class
 from marshmallow.exceptions import RegistryError
@@ -22,9 +24,11 @@
     KCGroupService,
     KCUserService
 )
+from biodm.components.services.dbservice import Operator, ValuedOperator, NUM_OPERATORS, AGG_OPERATORS
 from biodm.exceptions import (
     DataError,
     EndpointError,
+    QueryError,
     ImplementionError,
     InvalidCollectionMethod,
     PayloadEmptyError,
@@ -42,6 +46,9 @@
     from marshmallow.schema import Schema
 
 
+SPECIAL_QUERYPARAMETERS = {'fields', 'count', 'start', 'end', 'reverse'}
+
+
 def overload_docstring(f: Callable):
     # flake8: noqa: E501 pylint: disable=line-too-long
     """Decorator to allow for docstring overloading.
@@ -289,10 +296,13 @@ def _extract_fields(
         """Extracts fields from request query parameters.
         Defaults to ``self.schema.dump_fields.keys()``.
 
-        :param request: incomming request
-        :type request: Request
-        :return: field list
-        :rtype: List[str]
+        :param query_params: query params
+        :type query_params: Dict[str, Any]
+        :param user_info: user info
+        :type user_info: UserInfo
+        :raises DataError: incorrect field name
+        :return: requested fields
+        :rtype: Set[str]
         """
         fields = query_params.pop('fields', None)
         fields = fields.split(',') if fields else None
@@ -311,6 +321,100 @@
             fields = self.svc.takeout_unallowed_nested(fields, user_info=user_info)
         return fields
 
+    def _extract_query_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """Extracts query parameters, casting them to the proper type.
+        Uses a custom deserialization to treat comma-separated values.
+
+        :param params: query params
+        :type params: Dict[str, Any]
+        :raises QueryError: incorrect parameter name or value
+        :return: parsed and cast query parameters
+        :rtype: Dict[str, Any]
+        """
+        def deserialize_with_error(field: Field, value: str):
+            try:
+                return field.deserialize(value)
+            except ValidationError as ve:
+                raise QueryError(str(ve.messages)) from ve
+
+        def check_is_numeric(field: Field, op: str, dskey: str):
+            if not isinstance(field, (Number, Date, DateTime)):
+                raise QueryError(
+                    f"Operator {op} in {dskey}, should be applied on a "
+                    "numerical or date field."
+                )
+
+        # Reincorporate extra query.
+        extra_query = params.pop('q', None)
+        if extra_query:
+            params.update(QueryParams(extra_query))
+
+        # Check parameter validity.
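+        # For illustration, with a numeric field 'x' and a nested list field 'bs'
+        # (hypothetical names, borrowed from the test fixtures), the loop below yields:
+        #   ?x=1,2        -> params['x'] = [1, 2]
+        #   ?x.gt(5)      -> params['x.gt(5)'] = ValuedOperator(op='gt', value=5)
+        #   ?x.min()      -> params['x.min()'] = Operator(op='min')
+        #   ?bs.name=bip  -> params['bs.name'] = 'bip', validated by the nested schema.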
+        for dskey, csval in params.items():
+            key = dskey.split('.')
+            # Handle specials -> no dot.
+            if key[0] in SPECIAL_QUERYPARAMETERS:
+                if len(key) > 1:
+                    raise QueryError(f"Invalid query parameter: {dskey}")
+                continue
+
+            # Fetch first field.
+            if key[0] not in self.schema.fields.keys():
+                raise QueryError(f"Invalid query parameter: {dskey}")
+            field = self.schema.fields[key[0]]
+
+            # Fetch rest of the chain.
+            for i, k in enumerate(key[1:]):
+                # Handle nested.
+                schema: Schema
+                match field:
+                    case List():
+                        schema = field.inner.schema
+                    case Nested():
+                        schema = field.schema
+                    case Field():  # Chain should end when hitting a field.
+                        if not i == len(key[1:])-1:
+                            raise QueryError(f"Invalid query parameter: {dskey}")
+                        break
+
+                if k not in schema.fields.keys():
+                    raise QueryError(f"Invalid query parameter: {dskey}")
+
+                # Maintain loop invariant.
+                field = schema.fields[k]
+
+            if not csval:  # Check operators.
+                if len(key) == 1:
+                    raise QueryError(f"Missing value for query parameter: {dskey}")
+
+                match k.strip(')').split('('):  # On last visited value.
+                    case [("gt" | "ge" | "lt" | "le") as op, arg]:
+                        check_is_numeric(field, op, dskey)
+                        params[dskey] = ValuedOperator(
+                            op=op, value=deserialize_with_error(field, arg)
+                        )
+
+                    case [("min" | "max" | "min_a" | "max_a" | "min_v" | "max_v") as op, arg]:
+                        check_is_numeric(field, op, dskey)
+                        if arg:
+                            raise QueryError(f"Operator {op} does not take a value")
+                        params[dskey] = Operator(op=op)
+
+                    case _:
+                        raise QueryError(
+                            f"Invalid operator {k} on {key[0]} in table {self.table.__name__}"
+                        )
+                continue
+
+            values = csval.split(',')
+
+            # Deserialize value(s).
+            if len(values) == 1:
+                params[dskey] = deserialize_with_error(field, values[0])
+            else:
+                params[dskey] = [
+                    deserialize_with_error(field, value)
+                    for value in values
+                ]
+        return params
+
     async def create(self, request: Request) -> Response:
         """CREATE.
 
@@ -597,13 +701,9 @@ async def filter(self, request: Request) -> Response:
             description: Wrong use of filters.
         """
         params = dict(request.query_params)
-
-        extra_query = params.pop('q', None)
-        if extra_query:
-            params.update(QueryParams(extra_query))
-
-        count = bool(params.pop('count', 0))
-        fields = self._extract_fields(params, user_info=request.user)
+        params = self._extract_query_params(params)
+        count = bool(params.pop('count', 0))
+        fields = self._extract_fields(params, user_info=request.user)
         result = await self.svc.filter(
             fields=fields,
             params=params,
diff --git a/src/biodm/components/controllers/s3controller.py b/src/biodm/components/controllers/s3controller.py
index 53ffcba..5169943 100644
--- a/src/biodm/components/controllers/s3controller.py
+++ b/src/biodm/components/controllers/s3controller.py
@@ -1,6 +1,6 @@
 from typing import List, Type
 
-from marshmallow import Schema, RAISE
+from marshmallow import Schema, RAISE, ValidationError
 import starlette.routing as sr
 from starlette.requests import Request
 from starlette.responses import Response, PlainTextResponse
@@ -9,7 +9,7 @@
 from biodm.components.services import S3Service
 from biodm.components.table import Base
 from biodm.schemas import PartsEtagSchema
-from biodm.exceptions import ImplementionError
+from biodm.exceptions import DataError, ImplementionError
 from biodm.utils.security import UserInfo
 from biodm.utils.utils import json_response
 from biodm.routing import Route
@@ -68,6 +68,12 @@ async def download(self, request: Request) -> Response:
                     application/json:
                         schema:
                             type: string
+            404:
+                description: Not found.
+            409:
+                description: Attempt to download a file that has not been uploaded yet.
+            500:
+                description: S3 bucket issue.
""" return PlainTextResponse( await self.svc.download( @@ -103,9 +109,20 @@ async def complete_multipart(self, request: Request): responses: 201: description: Completion confirmation 'Completed.' + 400: + description: Wrongly formatted completion notice. + 4O4: + description: Not found. + 500: + description: S3 Bucket issue. """ - await self.svc.complete_multipart( - pk_val=self._extract_pk_val(request), - parts=self.parts_etag_schema.loads(await request.body()) - ) - return json_response("Completed.", status_code=201) + flag = True + try: + parts = self.parts_etag_schema.loads(await request.body()) + await self.svc.complete_multipart( + pk_val=self._extract_pk_val(request), + parts=parts, + ) + return json_response("Completed.", status_code=201) + except ValidationError as ve: + raise DataError(str(ve.messages)) diff --git a/src/biodm/components/services/dbservice.py b/src/biodm/components/services/dbservice.py index d3108c6..9742189 100644 --- a/src/biodm/components/services/dbservice.py +++ b/src/biodm/components/services/dbservice.py @@ -1,10 +1,11 @@ """Database service: Translates requests data into SQLA statements and execute.""" from abc import ABCMeta +from datetime import datetime from typing import Callable, List, Sequence, Any, Dict, overload, Literal, Type, Set from uuid import uuid4 from marshmallow.orderedset import OrderedSet -from sqlalchemy import Column, Subquery, select, delete, or_, func +from sqlalchemy import Column, select, delete, or_, func from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.hybrid import hybrid_property @@ -29,8 +30,21 @@ from biodm.utils.utils import unevalled_all, unevalled_or, to_it, partition -NUM_OPERATORS = ("gt", "ge", "lt", "le", "min", "max") -AGG_OPERATORS = ("min_v", "max_v", "min_a", "max_a") +NUM_OPERATORS = ("gt", "ge", "lt", "le") +AGG_OPERATORS = ("min", "max", "min_v", "max_v", "min_a", "max_a") + + +from dataclasses import dataclass + + +@dataclass +class Operator: + op: str + + +@dataclass +class ValuedOperator(Operator): + value: Any class DatabaseService(ApiService, metaclass=ABCMeta): @@ -670,7 +684,7 @@ def _filter_parse_num_op( self, stmt: Select, field: str, - operator: str, + operator: Operator, aggregation: List[str] | None=None ) -> Select: """Applies numeric operator on a select statement. @@ -680,19 +694,12 @@ def _filter_parse_num_op( :param field: Field name to apply the operator on :type field: str :param operator: operator - :type operator: str + :type operator: Operator :raises EndpointError: Wrong operator :return: Select statement with operator condition applied. 
         :rtype: Select
         """
-        col, ctype = self.table.colinfo(field)
-
-        def op_err(op: str, name: str, ftype: type):
-            """Raises generic error."""
-            raise EndpointError(
-                f"Invalid operator {op} on column {name} of type {ftype} "
-                f"in table {self.table.__name__}"
-            )
+        col = self.table.col(field)
 
         def agg_stmt(stmt: Select, op: str, col: Column, aggregate_on: List[str]):
             """Generates aggregation statement."""
@@ -710,42 +717,32 @@
                 for k in aggregate_on
             ] + [getattr(self.table, col.name) == getattr(sub.c, label)]))
 
-        match operator.strip(')').split('('):
-            case [("gt" | "ge" | "lt" | "le") as op, arg]:
-                if ctype not in (int, float):
-                    op_err(op, col.name, ctype)
-
-                op_fct: Callable = getattr(col, f"__{op}__")
-                return stmt.where(op_fct(ctype(arg)))
-
-            case [("min" | "max") as op, arg]:
-                if ctype == str:
-                    op_err(op, col.name, ctype)
-
-                op_fct: Callable = getattr(func, op)
-                sub = select(op_fct(col))
-                return stmt.where(col == sub.scalar_subquery())
-
-            case [("min_v" | "max_v") as op, arg]:
-                if not self.table.is_versioned:
-                    raise EndpointError("min_v and max_v are versioned table exclusive filters.")
-
-                return agg_stmt(stmt, op, col, [k for k in self.table.pk if k != 'version'])
-
-            case [("min_a" | "max_a") as op, arg]:
-                if not aggregation:
-                    raise EndpointError(
-                        "min_a and max_a must be used in conjunction with other filters."
-                    )
-
-                return agg_stmt(stmt, op, col, aggregation)
-
-            case _:
-                raise EndpointError(
-                    f"Expecting either 'field=v1,v2' pairs or "
-                    f" operators 'field.op([v])' op in {NUM_OPERATORS + AGG_OPERATORS}")
+        if operator.op in NUM_OPERATORS:
+            assert isinstance(operator, ValuedOperator)
+
+            op_fct: Callable = getattr(col, f"__{operator.op}__")
+            return stmt.where(op_fct(operator.value))
+
+        elif operator.op in ("min", "max"):
+            op_fct: Callable = getattr(func, operator.op)
+            sub = select(op_fct(col))
+            return stmt.where(col == sub.scalar_subquery())
+
+        elif operator.op in ("min_v", "max_v"):
+            if not self.table.is_versioned:
+                raise EndpointError("min_v and max_v are versioned table exclusive filters.")
+
+            return agg_stmt(stmt, operator.op, col, [k for k in self.table.pk if k != 'version'])
+
+        elif operator.op in ("min_a", "max_a"):
+            if not aggregation:
+                raise EndpointError(
+                    "min_a and max_a must be used in conjunction with other filters."
+                )
+
+            return agg_stmt(stmt, operator.op, col, aggregation)
+
+        raise EndpointError(f"Unknown operator: {operator.op}")
 
-    def _filter_parse_field_cond(self, stmt: Select, field: str, values: List[str]) -> Select:
+    def _filter_parse_field_cond(self, stmt: Select, field: str, values: Sequence[str]) -> Select:
         """Applies field condition on a select statement.
 
         :param stmt: Statement under construction
@@ -759,24 +756,21 @@ def _filter_parse_field_cond(self, stmt: Select, field: str, values: List[str])
         :rtype: Select
         """
         col, ctype = self.table.colinfo(field)
-        wildcards, values = partition(values, cond=lambda x: "*" in x)
-        if wildcards and ctype is not str:
-            raise EndpointError(
-                "Using wildcard symbol '*' in /search is only allowed for text fields."
-            )
-
-        # Wildcards.
-        stmt = stmt.where(
-            unevalled_or([
-                col.like(str(w).replace("*", "%"))
-                for w in wildcards
-            ])
-        ) if wildcards else stmt
+        if ctype is str:
+            wildcards, values = partition(values, cond=lambda x: "*" in x)
+
+            # Wildcards.
+            stmt = stmt.where(
+                unevalled_or([
+                    col.like(str(w).replace("*", "%"))
+                    for w in wildcards
+                ])
+            ) if wildcards else stmt
 
         # Field equality conditions.
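+        # e.g. values == [1, 2] yields: WHERE col = 1 OR col = 2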
         stmt = stmt.where(
             unevalled_or([
-                col == ctype(v)
+                col == v  # Values already cast during query parsing.
                 for v in values
             ])
         ) if values else stmt
@@ -808,37 +802,34 @@ async def filter(
         # Track on which fields to aggregate in case.
         aggregate_conditions = {'fields': [], 'conditions': []}
 
-        for dskey, csval in params.items():
-            attr, values = dskey.split("."), csval.split(",")
+        for dskey, values in params.items():
+            attr = dskey.split(".")
 
+            # A bit redundant: fields were already validated in _extract_query_params.
             if attr[0] not in self.table.relationships.keys() + self.table.__table__.columns.keys():
-                raise EndpointError(
-                    f"Unknown field {attr[0]} of table {self.table.__name__}"
-                )
+                raise EndpointError(f"Unknown field {attr[0]} of table {self.table.__name__}")
 
             if len(attr) == 1:  # Field conditions.
                 aggregate_conditions['fields'].append(attr[0])
-                stmt = self._filter_parse_field_cond(stmt, attr[0], values)
+                stmt = self._filter_parse_field_cond(stmt, attr[0], to_it(values))
 
-            elif len(attr) == 2 and not csval:  # Operators.
-                field, operator = attr
-                match operator.strip(')').split('('):
-                    case [("min_a" | "max_a") as op, _]:
-                        aggregate_conditions['conditions'].append({'field': field, 'op': op})
-                    case _:
-                        aggregate_conditions['fields'].append(field)
-                        stmt = self._filter_parse_num_op(stmt, field, operator)
+            elif len(attr) == 2 and isinstance(values, Operator):  # Operators.
+                if values.op in ("min_a", "max_a"):
+                    aggregate_conditions['conditions'].append(
+                        {'field': attr[0], 'operator': values}
+                    )
+                else:
+                    aggregate_conditions['fields'].append(attr[0])
+                    stmt = self._filter_parse_num_op(stmt, attr[0], values)
 
             else:  # Nested filter case, prepare for recursive call below.
                 nested_attr = ".".join(attr[1::])
                 nested_conditions[attr[0]] = nested_conditions.get(attr[0], {})
-                nested_conditions[attr[0]][nested_attr] = csval
+                nested_conditions[attr[0]][nested_attr] = values
 
         # Handle aggregations after everything else.
         for cond in aggregate_conditions['conditions']:
-            stmt = self._filter_parse_num_op(
-                stmt, cond['field'], f"{cond['op']}()", aggregate_conditions['fields']
-            )
+            stmt = self._filter_parse_num_op(
+                stmt, **cond, aggregation=aggregate_conditions['fields']
+            )
 
         # Get the fields without conditions normally
         # Importantly, the joins in that method are outer -> Not filtering.
@@ -910,7 +901,7 @@ async def release(
             else:
                 queried_version = pk_val[i]
 
-        stmt = self._filter_parse_num_op(stmt, field="version", operator="max_v()")
+        stmt = self._filter_parse_num_op(stmt, field="version", operator=Operator(op="max_v"))
 
         # Get item with all columns - covers many-to-one relationships.
         fields = set(self.table.__table__.columns.keys())
diff --git a/src/biodm/error.py b/src/biodm/error.py
index b2b86f0..f66f2e4 100644
--- a/src/biodm/error.py
+++ b/src/biodm/error.py
@@ -8,6 +8,7 @@
     FailedUpdate,
     FileUploadCompleteError,
     PayloadJSONDecodingError,
+    QueryError,
     RequestError,
     FailedDelete,
     FailedRead,
@@ -19,7 +20,8 @@
     FileNotUploadedError,
     FileTooLargeError,
     DataError,
-    ReleaseVersionError
+    ReleaseVersionError,
+    ManagerError
 )
 
 
@@ -52,9 +54,9 @@ async def onerror(_, exc):
     )
 
     match exc:
-        case FileTooLargeError() | FileUploadCompleteError():
+        case FileTooLargeError():
             status = 400
-        case DataError() | EndpointError() | PayloadJSONDecodingError():
+        case DataError() | EndpointError() | QueryError() | PayloadJSONDecodingError():
             status = 400
         case FailedDelete() | FailedRead() | FailedUpdate():
             status = 404
@@ -69,6 +71,8 @@
             status = 409
         case PayloadEmptyError():
             status = 204
+        case ManagerError():
+            status = 500
         case TokenDecodingError():
             status = 503
         case UnauthorizedError():
diff --git a/src/biodm/exceptions.py b/src/biodm/exceptions.py
index 7e4b186..aa56963 100644
--- a/src/biodm/exceptions.py
+++ b/src/biodm/exceptions.py
@@ -9,6 +9,8 @@ def __init__(self, detail: str) -> None:
     # origin: Exception
     # def __init__
 
+class ManagerError(RequestError):
+    """Holds errors raised by managers, converted into 5XX errors."""
 
 class DBError(RuntimeError):
     """Raised when DB related errors are catched."""
@@ -77,6 +79,10 @@ class EndpointError(RequestError):
     """Raised when an endpoint is reached with wrong attributes, parameters and so on."""
 
 
+class QueryError(RequestError):
+    """Raised when an endpoint is reached with wrong query parameters."""
+
+
 ## Routing
 class InvalidCollectionMethod(RequestError):
     """Raised when a unit method is accesed as a collection."""
diff --git a/src/biodm/managers/s3manager.py b/src/biodm/managers/s3manager.py
index 5badd26..28d4282 100644
--- a/src/biodm/managers/s3manager.py
+++ b/src/biodm/managers/s3manager.py
@@ -2,11 +2,10 @@
 from typing import TYPE_CHECKING, Any, Dict, List
 
 from boto3 import client
-from botocore import response
 from botocore.exceptions import ClientError
 from starlette.datastructures import Secret
 
-from biodm.exceptions import FailedCreate, FailedRead, FileUploadCompleteError
+from biodm.exceptions import ManagerError
 from biodm.component import ApiManager
 
 if TYPE_CHECKING:
@@ -66,7 +65,7 @@ def create_presigned_download_url(self, object_name: str) -> str:
             )
 
         except ClientError as e:
-            raise FailedRead(str(e))
+            raise ManagerError(str(e))
 
     def create_multipart_upload(self, object_name: str) -> Dict[str, str]:
         """Create multipart upload
@@ -86,7 +85,7 @@ def create_multipart_upload(self, object_name: str) -> Dict[str, str]:
             )
 
         except ClientError as e:
-            raise FailedCreate(str(e))
+            raise ManagerError(str(e))
 
     def create_upload_part(self, object_name, upload_id, part_number):
         try:
@@ -102,7 +101,7 @@ def create_upload_part(self, object_name, upload_id, part_number):
             )
 
         except ClientError as e:
-            raise FailedCreate(str(e))
+            raise ManagerError(str(e))
 
     def complete_multipart_upload(
         self,
@@ -131,7 +130,7 @@ def complete_multipart_upload(
             )
 
         except ClientError as e:
-            raise FileUploadCompleteError(str(e.response['Error']))
+            raise ManagerError(str(e.response['Error']))
 
     def abort_multipart_upload(self, object_name: str, upload_id: str) -> Dict[str, str]:
         """Multipart upload termination notice
@@ -152,4 +151,4 @@ def abort_multipart_upload(self, object_name: str, upload_id: str) -> Dict[str,
             )
 
         except ClientError as e:
-            raise FailedCreate(str(e))
+            raise ManagerError(str(e))
diff --git a/src/biodm/routing.py b/src/biodm/routing.py
index e83753e..8bd835a 100644
--- a/src/biodm/routing.py
+++ b/src/biodm/routing.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Sequence, Callable, Awaitable, Coroutine, Any
+from typing import Sequence, Callable, Awaitable, Coroutine, Any
 
 import starlette.routing as sr
 from starlette.requests import Request
@@ -11,10 +11,6 @@
 from biodm.exceptions import UnauthorizedError
 
 
-if TYPE_CHECKING:
-    from biodm.utils.security import UserInfo
-
-
 class RequireAuthMiddleware(BaseHTTPMiddleware):
     async def dispatch(
         self,
diff --git a/src/tests/unit/test_resource.py b/src/tests/unit/test_resource.py
index 640166d..ef0508c 100644
--- a/src/tests/unit/test_resource.py
+++ b/src/tests/unit/test_resource.py
@@ -217,7 +217,7 @@ def test_filter_wrong_op(client):
     client.get('/as?x.lt=2')
 
 
-@pytest.mark.xfail(raises=exc.EndpointError)
+@pytest.mark.xfail(raises=exc.QueryError)
 def test_filter_wrong_wildcard(client):
     item = {'x': 1, 'y': 2, 'bs': [{'name': 'bip'}, {'name': 'bap'},]}