From d97be878d13ae89924e5c103d75c475c9b72127f Mon Sep 17 00:00:00 2001 From: Etienne Jodry Date: Thu, 6 Feb 2025 14:34:35 +0100 Subject: [PATCH] Introduce Schema component -> more efficient serialization. Improve field restriction and support 2 levels. --- src/biodm/components/__init__.py | 1 + .../components/controllers/admincontroller.py | 2 - .../components/controllers/controller.py | 3 +- .../controllers/resourcecontroller.py | 56 +++-- .../components/controllers/s3controller.py | 4 +- src/biodm/components/schema.py | 48 ++++ src/biodm/components/services/dbservice.py | 206 +++++++++++------- src/biodm/components/services/s3service.py | 4 +- src/biodm/components/table.py | 2 +- src/biodm/managers/dbmanager.py | 10 +- src/biodm/schemas/error.py | 4 +- src/biodm/schemas/group.py | 5 +- src/biodm/schemas/listgroup.py | 3 +- src/biodm/schemas/refresh.py | 4 +- src/biodm/schemas/upload.py | 4 +- src/biodm/schemas/user.py | 4 +- src/biodm/tables/group.py | 3 + src/biodm/tables/listgroup.py | 1 + src/biodm/tables/upload.py | 2 +- src/biodm/tables/upload_part.py | 3 +- src/biodm/tables/user.py | 3 +- src/biodm/utils/biodm.py | 1 + src/biodm/utils/security.py | 13 +- src/example/entities/schemas/dataset.py | 3 +- src/example/entities/schemas/file.py | 5 +- src/example/entities/schemas/project.py | 4 +- src/example/entities/tables/dataset.py | 8 +- src/example/entities/tables/file.py | 2 +- src/example/entities/tables/project.py | 2 +- src/example/entities/tables/visualization.py | 4 +- src/tests/integration/kc/test_keycloak.py | 2 +- src/tests/integration/kc/test_permissions.py | 9 +- src/tests/integration/s3/test_files.py | 2 +- src/tests/unit/conftest.py | 16 +- src/tests/unit/test_resource.py | 3 +- src/tests/unit/test_versioning.py | 2 +- 36 files changed, 298 insertions(+), 150 deletions(-) create mode 100644 src/biodm/components/schema.py diff --git a/src/biodm/components/__init__.py b/src/biodm/components/__init__.py index 69727d5..44fdf5b 100644 --- a/src/biodm/components/__init__.py +++ b/src/biodm/components/__init__.py @@ -1,4 +1,5 @@ # Explicit re-export for mypy strict. 
+from .schema import Schema as Schema from .table import Base as Base from .table import S3File as S3File from .table import Versioned as Versioned diff --git a/src/biodm/components/controllers/admincontroller.py b/src/biodm/components/controllers/admincontroller.py index b3b3dd7..7b07f58 100644 --- a/src/biodm/components/controllers/admincontroller.py +++ b/src/biodm/components/controllers/admincontroller.py @@ -1,5 +1,3 @@ -from marshmallow import Schema - from biodm.utils.security import admin_required, login_required from .resourcecontroller import ResourceController diff --git a/src/biodm/components/controllers/controller.py b/src/biodm/components/controllers/controller.py index 21cc3e2..a35595a 100644 --- a/src/biodm/components/controllers/controller.py +++ b/src/biodm/components/controllers/controller.py @@ -6,7 +6,7 @@ from io import BytesIO from typing import Any, Iterable, List, Dict, TYPE_CHECKING, Optional -from marshmallow.schema import Schema +# from marshmallow.schema import Schema from marshmallow.exceptions import ValidationError from sqlalchemy.exc import MissingGreenlet from starlette.requests import Request @@ -14,6 +14,7 @@ import starlette.routing as sr from biodm import config +from biodm.components import Schema from biodm.component import ApiComponent from biodm.exceptions import ( DataError, PayloadJSONDecodingError, AsyncDBError, SchemaError diff --git a/src/biodm/components/controllers/resourcecontroller.py b/src/biodm/components/controllers/resourcecontroller.py index f269d5e..84e564b 100644 --- a/src/biodm/components/controllers/resourcecontroller.py +++ b/src/biodm/components/controllers/resourcecontroller.py @@ -284,7 +284,7 @@ async def _extract_body(self, request: Request) -> bytes: :rtype: bytes """ body = await request.body() - if body in (b'{}', b'[]', b'[{}]'): + if not body or body in (b'{}', b'[]', b'[{}]'): raise PayloadEmptyError("No input data.") return body @@ -307,18 +307,41 @@ def _extract_fields( fields = query_params.pop('fields', None) fields = fields.split(',') if fields else None - if fields: # User input case, check and raise. - fields = set(fields) | self.table.pk + if fields: # User input case, check validity. + fields = set(fields) + nested = [] for field in fields: - if field not in self.schema.dump_fields.keys(): - raise DataError(f"Requested field {field} does not exists.") - self.svc.check_allowed_nested(fields, user_info=user_info) - - else: # Default case, gracefully populate allowed fields. - fields = [ - k for k,v in self.schema.dump_fields.items() - ] - fields = self.svc.takeout_unallowed_nested(fields, user_info=user_info) + chain = field.split('.') + if len(chain) > 2: + raise QueryError("Requested fields can be set only on two levels.") + if chain[0] not in self.schema.dump_fields: + raise QueryError( + f"Requested field {field} does not exists at {self.prefix}." + ) + if len(chain) > 1: + nschema = self.schema.dump_fields[chain[0]] + match nschema: + case List(): + nschema = nschema.inner.schema + case Nested(): + nschema = nschema.schema + if chain[1] not in nschema.dump_fields: + raise QueryError( + f"Requested field {field} does not exist " + f"for child resource at {self.prefix}." 
+ ) + if chain[0] in self.table.relationships: + nested.append(chain[0]) + + self.svc.check_allowed_nested(nested, user_info=user_info) + + else: # Default case, pass down all allowed dump_fields + fields = self.svc.takeout_unallowed_nested( + self.schema.dump_fields.keys(), + user_info=user_info + ) + + fields = fields | self.table.pk # fetch pk in any case. return fields def _extract_query_params(self, queryparams: QueryParams) -> Dict[str, Any]: @@ -555,7 +578,7 @@ async def read_nested(self, request: Request) -> Response: f"Unknown collection {nested_attribute} of {self.table.__class__.__name__}" ) - # Serialization and field extraction done by target controller. + # Serialization and field extraction done by target controller. ctrl: ResourceController = ( target_rel .mapper @@ -720,12 +743,17 @@ async def filter(self, request: Request) -> Response: """ params = self._extract_query_params(request.query_params) fields = self._extract_fields(params, user_info=request.user) + + ser_fields = [] + for f in fields: + ser_fields.append(f.split('.')[0]) + count = bool(params.pop('count', 0)) result = await self.svc.filter( fields=fields, params=params, user_info=request.user, - serializer=partial(self.serialize, many=True, only=fields), + serializer=partial(self.serialize, many=True, only=ser_fields), ) # Prepare response object. diff --git a/src/biodm/components/controllers/s3controller.py b/src/biodm/components/controllers/s3controller.py index 5169943..91f0b0c 100644 --- a/src/biodm/components/controllers/s3controller.py +++ b/src/biodm/components/controllers/s3controller.py @@ -1,11 +1,11 @@ from typing import List, Type -from marshmallow import Schema, RAISE, ValidationError +from marshmallow import RAISE, ValidationError import starlette.routing as sr from starlette.requests import Request from starlette.responses import Response, PlainTextResponse -from biodm.components import S3File +from biodm.components import S3File, Schema from biodm.components.services import S3Service from biodm.components.table import Base from biodm.schemas import PartsEtagSchema diff --git a/src/biodm/components/schema.py b/src/biodm/components/schema.py new file mode 100644 index 0000000..141313d --- /dev/null +++ b/src/biodm/components/schema.py @@ -0,0 +1,48 @@ +import marshmallow as ma + +from marshmallow.utils import get_value as ma_get_value, missing +from sqlalchemy.orm import make_transient +from sqlalchemy.orm.exc import DetachedInstanceError + + +from biodm.utils.utils import to_it + +"""Below is a way to check if data is properly loaded before running serialization.""" + + +SKIP_VALUES = (None, [], {}, '', '[]', '{}',) + + +def gettattr_unbound(obj, key: int | str, default=missing): + try: + return ma_get_value(obj, key, default) + except DetachedInstanceError: + return default + + +class Schema(ma.Schema): + # def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + # for field in self.fields.values(): + # field.get_value = partial(field.get_value, accessor=gettattr_unbound) + + @ma.pre_dump + def turn_to_transient(self, data, **kwargs): + """Avoids serialization fetchig extra data from the database on the fly.""" + for one in to_it(data): + make_transient(one) + return data + + @ma.post_dump + def remove_skip_values(self, data, **kwargs): + """Removes un-necessary empty values from resulting dict.""" + return { + key: value for key, value in data.items() + if value not in SKIP_VALUES + } + + # def get_attribute(self, obj, attr, default): + # try: + # return 
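A minimal sketch of how the new Schema base behaves on dump; TagSchema below is illustrative and not part of this patch. turn_to_transient detaches ORM instances before dumping so serialization cannot trigger extra database fetches, and remove_skip_values drops empty values from the resulting dict:

    from marshmallow.fields import Integer, List, String
    from biodm.components import Schema

    class TagSchema(Schema):
        """Hypothetical schema, for illustration only."""
        id = Integer()
        name = String()
        datasets = List(String())  # stand-in for a nested relationship field

    schema = TagSchema()
    # The post_dump hook can be exercised directly: empty values are stripped.
    schema.remove_skip_values({"id": 1, "name": "rna-seq", "datasets": []})
    # -> {"id": 1, "name": "rna-seq"}
    # During a real dump(), turn_to_transient first calls make_transient() on each
    # SQLAlchemy instance, so only data already loaded by the query is serialized.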
super().get_attribute(obj, attr, default) + # except DetachedInstanceError: + # return None diff --git a/src/biodm/components/services/dbservice.py b/src/biodm/components/services/dbservice.py index c356b87..db51e45 100644 --- a/src/biodm/components/services/dbservice.py +++ b/src/biodm/components/services/dbservice.py @@ -10,11 +10,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( - load_only, selectinload, joinedload, ONETOMANY, MANYTOONE, Relationship + load_only, selectinload, joinedload, ONETOMANY, MANYTOONE, Relationship, Load ) from sqlalchemy.sql import Delete, Select -from sqlalchemy.sql.selectable import Alias - from biodm import config from biodm.component import ApiService from biodm.components import Base @@ -378,23 +376,22 @@ def _svc_from_rel_name(self, key: str) -> DatabaseService: else: return rel.mapper.entity.svc - def check_allowed_nested(self, fields: List[str], user_info: UserInfo) -> None: + def check_allowed_nested(self, nested_fields: List[str], user_info: UserInfo) -> None: """Checks whether all user requested fields are allowed by static permissions. - :param fields: list of fields + :param nested_fields: list of relationship fields :type fields: List[str] :param user_info: user info :type user_info: UserInfo :raises UnauthorizedError: protected nested field required without sufficient authorization """ - nested, _ = partition(fields, lambda x: x in self.table.relationships) - for name in nested: - target_svc = self._svc_from_rel_name(name) + for field in nested_fields: + target_svc = self._svc_from_rel_name(field) if target_svc._login_required("read") and not user_info.is_authenticated: raise UnauthorizedError() if not self._group_required("read", user_info.groups): - raise UnauthorizedError(f"Insufficient group privileges to retrieve {name}.") + raise UnauthorizedError(f"Insufficient group privileges to retrieve {field}.") def takeout_unallowed_nested(self, fields: List[str], user_info: UserInfo) -> List[str]: """Take out fields not allowed by static permissions. @@ -406,6 +403,9 @@ def takeout_unallowed_nested(self, fields: List[str], user_info: UserInfo) -> Li :return: List of fields with unallowed ones taken out. :rtype: List[str] """ + if not user_info: + return fields + nested, fields = partition(fields, lambda x: x in self.table.relationships) def ncheck(name): @@ -568,6 +568,7 @@ def _restrict_select_on_fields( :return: statement restricted on field list :rtype: Select """ + # fields = self.takeout_unallowed_nested(fields, user_info=user_info) nested, fields = partition(fields, lambda x: x in self.table.relationships) # Exclude hybrid properties. 
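The field restriction applied in the hunk just below can be illustrated with a standalone sketch; Dataset and files are names from the example app, and the splitting mirrors what _restrict_select_on_fields now does with a two-level fields set:

    # e.g. GET /datasets?fields=name,version,files.id,files.filename
    fields = {"id", "version", "name", "files.id", "files.filename"}

    top_level = set()
    nested_fields = {}
    for f in fields:
        rel, _, col = f.partition(".")
        if col:
            nested_fields.setdefault(rel, []).append(col)
        else:
            top_level.add(f)

    # top_level     == {"id", "version", "name"}
    #   -> Load(Dataset).load_only(Dataset.id, Dataset.version, Dataset.name)
    # nested_fields == {"files": ["id", "filename"]}
    #   -> joinedload(Dataset.files) + Load(File).load_only(File.id, File.filename)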
_, fields = partition( @@ -579,49 +580,48 @@ def _restrict_select_on_fields( ) stmt = self._apply_read_permissions(user_info, stmt) - # Fields + nested_fields = {} + + if fields: + # Process fields argument, to distinguish between this level and lower levels: + nfields, fields = partition(fields, lambda x: len(x.split('.')) > 1) + for nf in nfields: + spl = nf.split('.') + nested_fields[spl[0]] = nested_fields.get(spl[0], []) + nested_fields[spl[0]].extend(spl[1::]) + + # Fields at self.table stmt = stmt.options( - load_only( - *[ - getattr(self.table, f) - for f in fields - ] - ), + Load(self.table).load_only(*[getattr(self.table, f) for f in fields]) ) if fields else stmt - for n in nested: - relationship = self.table.relationships[n] - target = self.table if isinstance(relationship.target, Alias) else relationship.mapper.entity - - if relationship.direction in (MANYTOONE, ONETOMANY): - if target == self.table: - stmt = stmt.options( - joinedload( - getattr(self.table, n) - # TODO: possible optimization, see if there is a way to infer innerjoin. - # innerjoin=relationship.direction is ONETOMANY -> Wrong - ) - ) - else: - rel_stmt = select(target) - rel_stmt = ( - target - .svc - ._apply_read_permissions(user_info, rel_stmt) - ) - - stmt = stmt.join_from( - self.table, - rel_stmt.subquery(), - isouter=True - ) - else: - # TODO: check permissions ? - # Possible edge cases in o2o relationships ?? - stmt = stmt.options( - selectinload( - getattr(self.table, n) - ) + for n in set(nested) | nested_fields.keys(): + rel = self.table.relationships[n] + target = rel.mapper.entity + + # Get relationship fields + stmt = stmt.options(joinedload(getattr(self.table, n))) + stmt = stmt.options( + Load(target).load_only(*[getattr(target, f) for f in nested_fields.get(n)]) + ) if nested_fields.get(n) else stmt + + # Filter based on permissions. + if rel.direction in (MANYTOONE, ONETOMANY): # TODO: Handle else ? -> MANYTOMANY + rel_stmt = select(target) + rel_stmt = ( + target + .svc + ._apply_read_permissions(user_info, rel_stmt) + ).subquery() + + stmt = stmt.join_from( + self.table, + rel_stmt, + onclause=unevalled_all([ + getattr(self.table, local.name) == getattr(rel_stmt.columns, remote.name) + for local, remote in rel.local_remote_pairs + ]), + isouter=True ) return stmt @@ -770,42 +770,38 @@ def _filter_parse_field_cond(self, stmt: Select, field: str, values: Sequence[st # Field equality conditions. stmt = stmt.where( unevalled_or([ - col == v# ctype(v) -> Already casted + col == v# ctype(v) -> Already casted by Controller for v in values ]) ) if values else stmt return stmt - async def filter( + def _filter_apply_parameters( self, - fields: List[str], - params: Dict[str, str], - count: bool = False, - stmt_only: bool = False, - user_info: UserInfo | None = None, - **kwargs - ) -> List[Base]: - """READ rows filted on query parameters.""" - # Get special parameters - offset = int(params.pop('start', 0)) - limit = int(params.pop('end', config.LIMIT)) - reverse = params.pop('reverse', None) # TODO: ? - - # start building statement. - stmt = select(self.table).distinct() - - # For lower level(s) propagation. - propagate = {"start": offset, "end": limit, "reverse": reverse} - nested_conditions = {} + stmt: Select, + params: Dict[str, Any], + nested_params: Dict[str, Any] + ) -> Select: + """Apply query parameters on statement. 
+ :param stmt: Statement under construction + :type stmt: Select + :param params: Parsed query parameters + :type params: Dict[str, Any] + :param nested_params: Nested query parameters for children resources filtering. + :type nested_conditions: Dict[str, Any] + :raises EndpointError: _description_ + :return: Statement with all conditions applied + :rtype: Select + """ # Track on which fields to aggregate in case. aggregate_conditions = {'fields': [], 'conditions': []} for dskey, values in params.items(): attr = dskey.split(".") - # A bit redondant since fields get checked against schema in _extract_query_params. + # A bit redondant since fields get checked against schema by Controller. if attr[0] not in self.table.relationships.keys() + self.table.__table__.columns.keys(): raise EndpointError(f"Unknown field {attr[0]} of table {self.table.__name__}") @@ -822,27 +818,65 @@ async def filter( aggregate_conditions['fields'].append(attr[0]) stmt = self._filter_parse_num_op(stmt, attr[0], values) - else: # Nested filter case, prepare for recursive call below. - nested_attr = ".".join(attr[1::]) - nested_conditions[attr[0]] = nested_conditions.get(attr[0], {}) - nested_conditions[attr[0]][nested_attr] = values + else: # Nested filter case, prepare for recursive call. + nested_params[attr[0]] = nested_params.get(attr[0], {}) + nested_params[attr[0]][".".join(attr[1::])] = values # Handle aggregations after everything else. for cond in aggregate_conditions['conditions']: - stmt = self._filter_parse_num_op(stmt, **cond, aggregation=aggregate_conditions['fields']) + stmt = self._filter_parse_num_op( + stmt, **cond, aggregation=aggregate_conditions['fields'] + ) + + return stmt + + async def filter( + self, + fields: List[str], + params: Dict[str, str], + count: bool = False, + stmt_only: bool = False, + user_info: UserInfo | None = None, + **kwargs + ) -> List[Base]: + """READ rows filted on query parameters.""" + # Get special parameters + offset = int(params.pop('start', 0)) + limit = int(params.pop('end', config.LIMIT)) + reverse = params.pop('reverse', None) # TODO: ? + + # start building statement. + stmt = select(self.table).distinct() + + # For lower level(s) propagation. + propagate = {"start": offset, "end": limit, "reverse": reverse} + nested_params = {} + + # Apply parameters + stmt = self._filter_apply_parameters( + stmt, params=params, nested_params=nested_params + ) # Get the fields without conditions normally # Importantly, the joins in that method are outer -> Not filtering. - stmt = self._restrict_select_on_fields(stmt, fields - nested_conditions.keys(), user_info) + stmt = self._restrict_select_on_fields(stmt, fields, user_info) + # Prepare recursive call for nested filters, and do an (inner) left join -> Filtering. - for nf_key, nf_conditions in nested_conditions.items(): + # fields fetched below are only relevant for condtions. + for nf_key, nf_conditions in nested_params.items(): nf_svc = self._svc_from_rel_name(nf_key) nf_fields = nf_svc.table.pk | set(nf_conditions.keys()) nf_conditions.update(propagate) # Take in special parameters. nf_stmt = ( - await nf_svc.filter(nf_fields, nf_conditions, stmt_only=True, user_info=user_info) + await nf_svc.filter( + nf_fields, + nf_conditions, + stmt_only=True, + user_info=user_info + ) ).subquery() + stmt = stmt.join_from( self.table, nf_stmt, @@ -1016,10 +1050,6 @@ def patch(ins, mapping): # Insert main object. item = await self._insert(composite.item, **kwargs) - # Needed so that permissions are taken into account before writing. 
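A standalone sketch of how query parameters are split between the current resource and its children before the recursive filter call; the request and names are illustrative, taken from the example app:

    # e.g. GET /datasets?name=ds_test&files.filename=seq.csv
    params = {"name": "ds_test", "files.filename": "seq.csv"}

    local_params = {}
    nested_params = {}
    for key, values in params.items():
        head, _, rest = key.partition(".")
        if rest:
            nested_params.setdefault(head, {})[rest] = values
        else:
            local_params[key] = values

    # local_params  == {"name": "ds_test"}  -> WHERE conditions on Dataset itself
    # nested_params == {"files": {"filename": "seq.csv"}}
    #   -> handed to File.svc.filter(..., stmt_only=True) and joined back onto
    #      Dataset, so only datasets with a matching file are returned.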
- if self.permission_relationships: - await session.flush() - # Populate nested objects into main object. for key, sub in composite.nested.items(): await getattr(item.awaitable_attrs, key) @@ -1065,6 +1095,16 @@ def patch(ins, mapping): raise NotImplementedError else: setattr(item, key, await svc._insert(delay, **kwargs)) + + # Write to DB and fetch most up to date relationships. + # -> That is necessary for entities containing permissions + # -> For others this is an opionionated choice that after a write all + #  items should return their most up to date data. + + # => This operation effectively does one I/O operation per inserted item + #  so it impacts performance. + await session.flush() + await session.refresh(item, rels.keys()) return item # pylint: disable=arguments-differ diff --git a/src/biodm/components/services/s3service.py b/src/biodm/components/services/s3service.py index dbdf853..6606c79 100644 --- a/src/biodm/components/services/s3service.py +++ b/src/biodm/components/services/s3service.py @@ -125,9 +125,11 @@ async def complete_multipart( session=session ) upload = await getattr(file.awaitable_attrs, 'upload') + upload_id = await getattr(upload.awaitable_attrs, 's3_uploadId') + complete = self.s3.complete_multipart_upload( object_name=await self.gen_key(file, session=session), - upload_id=upload.s3_uploadId, + upload_id=upload_id, parts=parts ) if ( diff --git a/src/biodm/components/table.py b/src/biodm/components/table.py index 77b9d39..14ac865 100644 --- a/src/biodm/components/table.py +++ b/src/biodm/components/table.py @@ -304,7 +304,7 @@ class S3File: @declared_attr def upload(cls) -> Mapped["Upload"]: - return relationship(backref="file", foreign_keys=[cls.upload_id]) + return relationship(backref="file", foreign_keys=[cls.upload_id], lazy="joined") dl_count = Column(Integer, nullable=False, server_default='0') diff --git a/src/biodm/managers/dbmanager.py b/src/biodm/managers/dbmanager.py index dc38f1b..706ea2e 100644 --- a/src/biodm/managers/dbmanager.py +++ b/src/biodm/managers/dbmanager.py @@ -116,10 +116,12 @@ async def wrapper(*args, **kwargs) -> Any | str | None: ) # Call and serialize result if requested. db_result = await db_exec(*args, session=session, **kwargs) - result = await session.run_sync( - lambda _, data: serializer(data), db_result - ) if serializer else db_result - return result + if serializer: + # Serialization is going to make instances transient => commit beforehand. + await session.commit() + # Run sync as running in async may cause greenlet_spawn errors. 
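A condensed sketch of the serialization hand-off in DatabaseManager; db_exec stands for the wrapped service coroutine and the real decorator also opens the session itself. Committing first matters because Schema.turn_to_transient detaches the instances, and run_sync keeps the synchronous marshmallow dump on the session's greenlet:

    async def wrapper(*args, session, serializer=None, **kwargs):
        db_result = await db_exec(*args, session=session, **kwargs)
        if serializer is None:
            return db_result
        # Dumping makes instances transient, so persist pending state beforehand.
        await session.commit()
        # Run the sync serializer through run_sync to avoid greenlet_spawn errors
        # from attribute access outside the async context.
        return await session.run_sync(lambda _, data: serializer(data), db_result)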
+ return await session.run_sync(lambda _, data: serializer(data), db_result) + return db_result wrapper.__annotations__ = db_exec.__annotations__ wrapper.__name__ = db_exec.__name__ diff --git a/src/biodm/schemas/error.py b/src/biodm/schemas/error.py index 968b7c5..c15fe10 100644 --- a/src/biodm/schemas/error.py +++ b/src/biodm/schemas/error.py @@ -1,6 +1,8 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String, Number +from biodm.components import Schema + class ErrorSchema(Schema): """Schema for errors returned by Api, mostly for apispec purposes.""" diff --git a/src/biodm/schemas/group.py b/src/biodm/schemas/group.py index 5551661..b3e23ca 100644 --- a/src/biodm/schemas/group.py +++ b/src/biodm/schemas/group.py @@ -1,6 +1,7 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String, List, Nested +from biodm.components import Schema from .user import UserSchema class GroupSchema(Schema): @@ -9,4 +10,4 @@ class GroupSchema(Schema): users = List(Nested(lambda: UserSchema(load_only=['groups']))) children = List(Nested(lambda: GroupSchema(load_only=['users', 'children', 'parent']))) - parent = Nested('GroupSchema', dump_only=True) + parent = Nested(lambda: GroupSchema(load_only=['users', 'children', 'parent']), dump_only=True) diff --git a/src/biodm/schemas/listgroup.py b/src/biodm/schemas/listgroup.py index 5addba2..5abf83f 100644 --- a/src/biodm/schemas/listgroup.py +++ b/src/biodm/schemas/listgroup.py @@ -1,6 +1,7 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import List, Nested, Integer +from biodm.components import Schema from .group import GroupSchema class ListGroupSchema(Schema): diff --git a/src/biodm/schemas/refresh.py b/src/biodm/schemas/refresh.py index e5b79e8..0bea721 100644 --- a/src/biodm/schemas/refresh.py +++ b/src/biodm/schemas/refresh.py @@ -1,6 +1,8 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String +from biodm.components import Schema + class RefreshSchema(Schema): """Schema for logout""" diff --git a/src/biodm/schemas/upload.py b/src/biodm/schemas/upload.py index 8eaf0dc..cee8cf9 100644 --- a/src/biodm/schemas/upload.py +++ b/src/biodm/schemas/upload.py @@ -1,8 +1,10 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String, List, Nested, Integer +from biodm.components import Schema from biodm.utils.utils import check_hash + class PartsEtagSchema(Schema): PartNumber = Integer() ETag = String(validate=check_hash) diff --git a/src/biodm/schemas/user.py b/src/biodm/schemas/user.py index c4b679c..eca5913 100644 --- a/src/biodm/schemas/user.py +++ b/src/biodm/schemas/user.py @@ -1,6 +1,8 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String, List, Nested +from biodm.components import Schema + class UserSchema(Schema): """Schema for Keycloak Users. 
id field is purposefully left out as we manage it internally.""" diff --git a/src/biodm/tables/group.py b/src/biodm/tables/group.py index bccfe2b..b384bf5 100644 --- a/src/biodm/tables/group.py +++ b/src/biodm/tables/group.py @@ -22,6 +22,7 @@ class Group(Base): users: Mapped[List["User"]] = relationship( secondary=asso_user_group, back_populates="groups", + lazy="joined" # init=False, ) @@ -80,6 +81,7 @@ def __repr__(self): foreign_keys=[Group_alias.path], uselist=False, viewonly=True, + # lazy="joined" ) @@ -89,4 +91,5 @@ def __repr__(self): foreign_keys=[Group_alias.parent_path], uselist=True, viewonly=True, + lazy="joined" ) diff --git a/src/biodm/tables/listgroup.py b/src/biodm/tables/listgroup.py index dc600d2..fd4c75d 100644 --- a/src/biodm/tables/listgroup.py +++ b/src/biodm/tables/listgroup.py @@ -16,4 +16,5 @@ class ListGroup(Base): groups: Mapped[List["Group"]] = relationship( secondary=asso_list_group, + lazy="joined" ) diff --git a/src/biodm/tables/upload.py b/src/biodm/tables/upload.py index 96b52cd..d043f77 100644 --- a/src/biodm/tables/upload.py +++ b/src/biodm/tables/upload.py @@ -11,4 +11,4 @@ class Upload(Base): id: Mapped[int] = mapped_column(primary_key=True) s3_uploadId: Mapped[str] = mapped_column(nullable=True) - parts: Mapped[List["UploadPart"]] = relationship(back_populates="upload") + parts: Mapped[List["UploadPart"]] = relationship(back_populates="upload", lazy="joined") diff --git a/src/biodm/tables/upload_part.py b/src/biodm/tables/upload_part.py index 1051057..66406d2 100644 --- a/src/biodm/tables/upload_part.py +++ b/src/biodm/tables/upload_part.py @@ -16,5 +16,6 @@ class UploadPart(Base): upload: Mapped["Upload"] = relationship( back_populates="parts", foreign_keys=[upload_id], - single_parent=True + single_parent=True, + lazy="joined", ) diff --git a/src/biodm/tables/user.py b/src/biodm/tables/user.py index 7b742c0..448b67b 100644 --- a/src/biodm/tables/user.py +++ b/src/biodm/tables/user.py @@ -24,5 +24,6 @@ class User(Base): groups: Mapped[List["Group"]] = relationship( secondary=asso_user_group, - back_populates="users" + back_populates="users", + lazy="joined" ) diff --git a/src/biodm/utils/biodm.py b/src/biodm/utils/biodm.py index ef7183e..afa0298 100644 --- a/src/biodm/utils/biodm.py +++ b/src/biodm/utils/biodm.py @@ -2,6 +2,7 @@ from sqlalchemy.orm.relationships import _RelationshipDeclared + if TYPE_CHECKING: from biodm.components import Base diff --git a/src/biodm/utils/security.py b/src/biodm/utils/security.py index 312e7a7..4e2411e 100644 --- a/src/biodm/utils/security.py +++ b/src/biodm/utils/security.py @@ -6,7 +6,7 @@ from inspect import getmembers, ismethod from typing import TYPE_CHECKING, List, Tuple, Set, ClassVar, Type, Any, Dict -from marshmallow import fields, Schema +from marshmallow import fields# Schema from starlette.authentication import BaseUser from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Receive, Scope, Send @@ -258,7 +258,8 @@ class ASSO_PERM_{TABLE}_{FIELD}(Base): entity = relationship( Table, foreign_keys=[pk_1_Table, ..., pk_n_Table], - backref=backref(f'perm_{field}', uselist=False) + backref=backref(f'perm_{field}', uselist=False), + lazy="joined" ) :param app: Api object, used to declare service. @@ -290,7 +291,8 @@ class ASSO_PERM_{TABLE}_{FIELD}(Base): rel_name, uselist=False, # all\{refresh-expire} + delete-orphan: important. 
- cascade="save-update, merge, delete, expunge, delete-orphan" + cascade="save-update, merge, delete, expunge, delete-orphan", + lazy="joined" ), foreign_keys="[" + ",".join( [ @@ -324,7 +326,8 @@ class ASSO_PERM_{TABLE}_{FIELD}(Base): "ListGroup", cascade="save-update, merge, delete, delete-orphan", foreign_keys=[c], - single_parent=True + single_parent=True, + lazy="joined" ) } ) @@ -365,7 +368,7 @@ def _gen_perm_schema(table: Type['Base'], fkey: str, verbs: List[str]): # back reference - probably unwanted. # schema_columns['entity'] = fields.Nested(table.ctrl.schema) - + from biodm.components import Schema return type( f"AssoPerm{table.__name__.capitalize()}{fkey.capitalize()}Schema", (Schema,), diff --git a/src/example/entities/schemas/dataset.py b/src/example/entities/schemas/dataset.py index e5ad93c..43d4845 100644 --- a/src/example/entities/schemas/dataset.py +++ b/src/example/entities/schemas/dataset.py @@ -1,6 +1,7 @@ -from marshmallow import Schema #, validate, pre_load, ValidationError +# from marshmallow import Schema #, validate, pre_load, ValidationError from marshmallow.fields import String, Date, List, Nested, Integer, UUID +from biodm.components import Schema from biodm.schemas import UserSchema from .project import ProjectSchema diff --git a/src/example/entities/schemas/file.py b/src/example/entities/schemas/file.py index 99d83de..73a4a7c 100644 --- a/src/example/entities/schemas/file.py +++ b/src/example/entities/schemas/file.py @@ -1,7 +1,10 @@ -from marshmallow import Schema, validate +# from marshmallow import Schema, validate from marshmallow.fields import String, List, Nested, Integer, Bool from biodm.schemas import UploadSchema +from biodm.components import Schema + + class FileSchema(Schema): id = Integer() filename = String() diff --git a/src/example/entities/schemas/project.py b/src/example/entities/schemas/project.py index ef3e524..bd4f0fa 100644 --- a/src/example/entities/schemas/project.py +++ b/src/example/entities/schemas/project.py @@ -1,6 +1,8 @@ -from marshmallow import Schema +# from marshmallow import Schema from marshmallow.fields import String, List, Nested, Integer +from biodm.components import Schema + class ProjectSchema(Schema): id = Integer() diff --git a/src/example/entities/tables/dataset.py b/src/example/entities/tables/dataset.py index 8dd971a..9dad347 100644 --- a/src/example/entities/tables/dataset.py +++ b/src/example/entities/tables/dataset.py @@ -46,10 +46,10 @@ class Dataset(Versioned, Base): # # relationships # policy - cascade="save-update, merge" ? 
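The relationship changes just below all follow one pattern: eager "joined" loading, so that by the time instances are made transient for dumping their related rows are already in memory. Declaring a new relationship in that style would look like this; Sample is a hypothetical table and the uppercase foreign-key target follows the example app's naming:

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import Mapped, mapped_column, relationship

    from biodm.components import Base


    class Sample(Base):  # hypothetical table, for illustration only
        id: Mapped[int] = mapped_column(primary_key=True)
        project_id: Mapped[int] = mapped_column(ForeignKey("PROJECT.id"))

        # lazy="joined": fetched together with the parent query, so serializing a
        # transient Sample never has to go back to the database for its project.
        project: Mapped["Project"] = relationship(
            foreign_keys=[project_id], lazy="joined"
        )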
- contact: Mapped["User"] = relationship(foreign_keys=[contact_username]) - tags: Mapped[Set["Tag"]] = relationship(secondary=asso_dataset_tag, uselist=True) - project: Mapped["Project"] = relationship(back_populates="datasets") - files: Mapped[List["File"]] = relationship(back_populates="dataset") + contact: Mapped["User"] = relationship(foreign_keys=[contact_username], lazy="joined") + tags: Mapped[Set["Tag"]] = relationship(secondary=asso_dataset_tag, uselist=True, lazy="joined") + project: Mapped["Project"] = relationship(back_populates="datasets", lazy="joined") + files: Mapped[List["File"]] = relationship(back_populates="dataset", lazy="joined") # # permission_lv2: Mapped["Permission_lv2"] = relationship() diff --git a/src/example/entities/tables/file.py b/src/example/entities/tables/file.py index 8107945..f02b042 100644 --- a/src/example/entities/tables/file.py +++ b/src/example/entities/tables/file.py @@ -48,7 +48,7 @@ async def key(self) -> str: ) # relationships - dataset: Mapped["Dataset"] = relationship(back_populates="files", foreign_keys=[dataset_id, dataset_version]) + dataset: Mapped["Dataset"] = relationship(back_populates="files", foreign_keys=[dataset_id, dataset_version], lazy="joined") # # dataset: Mapped["Dataset"] = relationship(back_populates="files", foreign_keys=[dataset_id, dataset_version]) # # dataset: Mapped["Dataset"] = relationship('Dataset', primaryjoin="and_(Dataset.id == File.dataset_id, Dataset.version == File.dataset_version)") diff --git a/src/example/entities/tables/project.py b/src/example/entities/tables/project.py index 6a0efad..bbef56f 100644 --- a/src/example/entities/tables/project.py +++ b/src/example/entities/tables/project.py @@ -21,7 +21,7 @@ class Project(Base): # updated_at = Column(TIMESTAMP(timezone=True), server_default=text('now()')) # relationshipsNoneNone - datasets: Mapped[List["Dataset"]] = relationship(back_populates="project") + datasets: Mapped[List["Dataset"]] = relationship(back_populates="project", lazy="joined") # visualizations: Mapped[List["Visualization"]] = relationship(back_populates="project") # analyses: Mapped[List["Analysis"]] = relationship(back_populates="project") diff --git a/src/example/entities/tables/visualization.py b/src/example/entities/tables/visualization.py index b2aa485..63e0583 100644 --- a/src/example/entities/tables/visualization.py +++ b/src/example/entities/tables/visualization.py @@ -21,7 +21,7 @@ class Visualization(Base): # id_k8sinstance: Mapped[int] = mapped_column(ForeignKey("K8SINSTANCE.id")) # Relationships - user: Mapped["User"] = relationship(foreign_keys=[user_username]) + user: Mapped["User"] = relationship(foreign_keys=[user_username], lazy="joined") # project: Mapped["Project"] = relationship(back_populates="visualizations", lazy="select") - file: Mapped["File"] = relationship(foreign_keys=[file_id]) + file: Mapped["File"] = relationship(foreign_keys=[file_id], lazy="joined") # k8sinstance: Mapped["K8sInstance"] = relationship(foreign_keys=[id_k8sinstance], lazy="select") diff --git a/src/tests/integration/kc/test_keycloak.py b/src/tests/integration/kc/test_keycloak.py index a951763..9e81246 100644 --- a/src/tests/integration/kc/test_keycloak.py +++ b/src/tests/integration/kc/test_keycloak.py @@ -195,7 +195,7 @@ def test_create_groups_with_parent(srv_endpoint, utils, admin_header): assert len(json_parent) == 1 json_parent = json_parent[0] - assert json_parent['parent'] is None + assert 'parent' not in json_parent assert len(json_parent['children']) == 2 assert 
json_parent['children'][0]['path'] == child1['path'] assert json_parent['children'][1]['path'] == child2['path'] diff --git a/src/tests/integration/kc/test_permissions.py b/src/tests/integration/kc/test_permissions.py index ee0821c..df6f939 100644 --- a/src/tests/integration/kc/test_permissions.py +++ b/src/tests/integration/kc/test_permissions.py @@ -219,16 +219,19 @@ def test_read_dataset_no_read_perm(srv_endpoint): f'{srv_endpoint}/datasets', headers=headers1 ) - json_response1 = json.loads(response1.text) + assert response1.status_code == 200 + response2 = requests.get( f'{srv_endpoint}/datasets', headers=headers2 ) + assert response2.status_code == 200 + + json_response2 = json.loads(response2.text) + json_response1 = json.loads(response1.text) - assert response1.status_code == 200 - assert response2.status_code == 200 assert len(json_response1) == 1 assert str(json_response1[0]['name']) == str(dataset1['name']) assert str(json_response1[0]['project_id']) == str(dataset1['project_id']) diff --git a/src/tests/integration/s3/test_files.py b/src/tests/integration/s3/test_files.py index 1820899..2cbf165 100644 --- a/src/tests/integration/s3/test_files.py +++ b/src/tests/integration/s3/test_files.py @@ -107,7 +107,7 @@ def test_file_readiness(srv_endpoint): json_file = json.loads(response.text) assert json_file['ready'] == True - assert json_file['upload'] == None + assert 'upload' not in json_file @pytest.mark.dependency(name="test_file_upload") diff --git a/src/tests/unit/conftest.py b/src/tests/unit/conftest.py index cd3f3d8..06257c7 100644 --- a/src/tests/unit/conftest.py +++ b/src/tests/unit/conftest.py @@ -8,7 +8,7 @@ from starlette.testclient import TestClient from biodm.api import Api -from biodm.components import Base, Versioned +from biodm.components import Base, Versioned, Schema from biodm.components.controllers import ResourceController # SQLAlchemy @@ -41,8 +41,8 @@ class A(Base): y = sa.Column(sa.Integer, nullable=True) id_c: Mapped[Optional[int]] = mapped_column(sa.Integer, sa.ForeignKey("C.id"), nullable=True) - bs: Mapped[List["B"]] = relationship(secondary=asso_a_b, uselist=True, lazy="select") - c: Mapped["C"] = relationship(foreign_keys=[id_c], backref="ca", lazy="select") + bs: Mapped[List["B"]] = relationship(secondary=asso_a_b, uselist=True, lazy="joined") + c: Mapped["C"] = relationship(foreign_keys=[id_c], backref="ca", lazy="joined") class B(Versioned, Base): @@ -60,11 +60,11 @@ class D(Versioned, Base): info = sa.Column(sa.String, nullable=False) - cs: Mapped[List["C"]] = relationship(secondary=asso_c_d, uselist=True, lazy="select") + cs: Mapped[List["C"]] = relationship(secondary=asso_c_d, uselist=True, lazy="joined") # Schemas -class ASchema(ma.Schema): +class ASchema(Schema): id = ma.fields.Integer() x = ma.fields.Integer() y = ma.fields.Integer() @@ -74,21 +74,21 @@ class ASchema(ma.Schema): c = ma.fields.Nested("CSchema") -class BSchema(ma.Schema): +class BSchema(Schema): id = ma.fields.Integer() version = ma.fields.Integer() name = ma.fields.String() -class CSchema(ma.Schema): +class CSchema(Schema): id = ma.fields.Integer() data = ma.fields.String() ca = ma.fields.Nested("ASchema") -class DSchema(ma.Schema): +class DSchema(Schema): id = ma.fields.Integer() version = ma.fields.Integer() diff --git a/src/tests/unit/test_resource.py b/src/tests/unit/test_resource.py index ef0508c..6e6b71a 100644 --- a/src/tests/unit/test_resource.py +++ b/src/tests/unit/test_resource.py @@ -1,4 +1,5 @@ from copy import deepcopy +import time import pytest import json @@ -39,7 
+40,6 @@ def test_create_composite_resource(client): oracle['id'] = 1 oracle['id_c'] = 1 oracle['c']['id'] = 1 - oracle['c']['ca'] = {} for i, x in enumerate(oracle['bs']): x['id'] = i+1 x['version'] = 1 @@ -273,7 +273,6 @@ def test_read_nested_collection(client): create = client.post('/as', content=json_bytes(item)) assert create.status_code == 201 - response = client.get('/as/1/bs') assert response.status_code == 200 diff --git a/src/tests/unit/test_versioning.py b/src/tests/unit/test_versioning.py index b09ebf5..e49a4c5 100644 --- a/src/tests/unit/test_versioning.py +++ b/src/tests/unit/test_versioning.py @@ -142,7 +142,7 @@ def test_update_nested_list_after_release_of_parent_resource(client): content=json_bytes(update_nested) ) oracle_nested = [update_nested['cs'][0]] - oracle_nested[0].update({'id': 3, 'ca': {}}) + oracle_nested[0].update({'id': 3}) assert update_response.status_code == 201 release_json = json.loads(update_response.text)
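Because remove_skip_values strips empty values from responses, client-side checks move from value comparisons (is None, == {}) to key-presence checks, as the test updates above show. A hypothetical test in the style of the integration suite, reusing its srv_endpoint and admin_header fixtures:

    import json

    import requests


    def test_present_keys_are_never_empty(srv_endpoint, admin_header):
        """Keys whose value would be None/[]/{}/'' are not serialized at all."""
        response = requests.get(f"{srv_endpoint}/groups", headers=admin_header)
        assert response.status_code == 200

        for group in json.loads(response.text):
            # e.g. a group without a parent simply has no 'parent' key.
            assert all(value not in (None, [], {}, "") for value in group.values())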