Skip to content

Commit

Permalink
Type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Etienne Jodry authored and Etienne Jodry committed Apr 25, 2024
1 parent fc057a8 commit 4600c9a
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 100 deletions.
24 changes: 14 additions & 10 deletions src/biodm/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from asyncio import wait_for, TimeoutError
import logging
import logging.config
from typing import List
from typing import List, Optional

from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint
Expand Down Expand Up @@ -32,9 +32,7 @@ def __init__(self, app: ASGIApp, timeout: int=30) -> None:

async def dispatch(self, request, call_next):
try:
return await wait_for(
call_next(request),
timeout=self.timeout)
return await wait_for(call_next(request), timeout=self.timeout)
except TimeoutError:
return HTMLResponse("Request reached timeout.", status_code=504)

Expand All @@ -44,8 +42,8 @@ class HistoryMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, server_host: str) -> None:
self.server_host = server_host
super().__init__(app, self.dispatch)
async def dispatch(self, request: Request, call_next) -> Response:

async def dispatch(self, request, call_next):
if auth_header(request):
app = History.svc.app
username, _, _ = await extract_and_decode_token(app.kc, request)
Expand All @@ -56,7 +54,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
'method': request.method,
'content': body if body else ""
}
await History.svc.create(h, stmt_only=False, serializer=None)
await History.svc.create(h, stmt_only=False)
return await call_next(request)


Expand All @@ -71,7 +69,13 @@ class Api(Starlette):
"""
logger = logging.getLogger(__name__)

def __init__(self, config=None, controllers=[], routes=[], tables=None, schemas=None, *args, **kwargs):
def __init__(self,
config=None,
controllers: Optional[List[Controller]]=[],
routes: Optional[List[Route]]=[],
tables=None,
schemas=None,
*args, **kwargs):
self.tables = tables
self.schemas = schemas
self.config = config
Expand Down Expand Up @@ -102,8 +106,8 @@ def __init__(self, config=None, controllers=[], routes=[], tables=None, schemas=
because the services needs to access the app instance.
If more useful cases for this show up we might want to design a cleaner solution.
"""
History.svc = UnaryEntityService(app=self, table=History, pk=('timestamp', 'username_user'))
ListGroup.svc = CompositeEntityService(app=self, table=ListGroup, pk=('id',))
History.svc = UnaryEntityService(app=self, table=History)
ListGroup.svc = CompositeEntityService(app=self, table=ListGroup)

super(Api, self).__init__(routes=routes, *args, **kwargs)

Expand Down
7 changes: 4 additions & 3 deletions src/biodm/components/controllers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any
from typing import Any, TYPE_CHECKING

from marshmallow.schema import Schema, EXCLUDE

from biodm.utils.utils import json_response

if TYPE_CHECKING:
from biodm.api import Api

class HttpMethod(Enum):
GET = "GET"
Expand All @@ -20,7 +21,7 @@ class HttpMethod(Enum):
class Controller(ABC):
@classmethod
def init(cls, app) -> None:
cls.app = app
cls.app: Api = app
return cls()

# Routes
Expand Down
2 changes: 1 addition & 1 deletion src/biodm/components/controllers/kccontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class KCController(ResourceController):
"""Controller for entities managed by keycloak (i.e. User/Group)."""
def _infer_svc(self) -> KCService:
match self.entity.lower():
match self.resource.lower():
case "user":
return KCUserService
case "group":
Expand Down
38 changes: 17 additions & 21 deletions src/biodm/components/controllers/resourcecontroller.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from __future__ import annotations
from functools import partial
from typing import Any, Tuple
from typing import Any, Tuple, TYPE_CHECKING

from marshmallow.schema import Schema
from starlette.routing import Mount, Route
from biodm.components import Base
from biodm.components.services import (
DatabaseService,
UnaryEntityService,
CompositeEntityService,
)

from biodm.components.services import DatabaseService, UnaryEntityService, CompositeEntityService
from biodm.exceptions import InvalidCollectionMethod, EmptyPayloadException
from biodm.utils.utils import json_response

from .controller import HttpMethod, EntityController

if TYPE_CHECKING:
from biodm.components import Base
from marshmallow.schema import Schema


def overload_docstring(f):
"""Decorator to allow for docstring overloading.
Expand Down Expand Up @@ -46,26 +45,23 @@ def __init__(self,
entity: str=None,
table: Base=None,
schema: Schema=None):
self.entity = entity if entity else self._infer_entity_name()
self.resource = entity if entity else self._infer_entity_name()
self.table = table if table else self._infer_table()
self.pk: Tuple[str, ...] = tuple(
str(pk).split('.')[-1]
for pk in self.table.__table__.primary_key.columns
)
self.svc = self._infer_svc()(app=self.app, table=self.table, pk=self.pk)
self.pk = tuple(self.table.pk())
self.svc = self._infer_svc()(app=self.app, table=self.table)
self.schema = schema() if schema else self._infer_schema()

def _infer_entity_name(self) -> str:
"""Infer entity name from controller name."""
return self.__class__.__name__.split("Controller")[0]

@property
def prefix(self):
def prefix(self) -> str:
"""Computes route path prefix from entity name."""
return '/' + self.entity.lower() + 's'
return '/' + self.resource.lower() + 's'

@property
def qp_id(self):
def qp_id(self) -> str:
"""Put primary key in queryparam form."""
return "".join(["{" + k + "}_" for k in self.pk])[:-1]

Expand All @@ -79,16 +75,16 @@ def _infer_svc(self) -> DatabaseService:

def _infer_table(self) -> Base:
try:
return self.app.tables.__dict__[self.entity]
return self.app.tables.__dict__[self.resource]
except:
raise ValueError(
f"{self.__class__.__name__} could not find {self.entity} Table."
f"{self.__class__.__name__} could not find {self.resource} Table."
" Alternatively if you are following another naming convention "
"you should provide it as 'table' arg when creating a new controller"
)

def _infer_schema(self) -> Schema:
isn = f"{self.entity}Schema"
isn = f"{self.resource}Schema"
try:
return self.app.schemas.__dict__[isn]()
except:
Expand Down
62 changes: 30 additions & 32 deletions src/biodm/components/services/dbservice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Any, overload, Tuple
from typing import List, Any, Tuple, TYPE_CHECKING
from contextlib import AsyncExitStack

from sqlalchemy import select, update, delete
Expand All @@ -12,6 +13,8 @@
from biodm.components import Base
from biodm.managers import DatabaseManager
from biodm.exceptions import FailedRead, FailedDelete, FailedUpdate
if TYPE_CHECKING:
from biodm.api import Api


SUPPORTED_INT_OPERATORS = ('gt', 'ge', 'lt', 'le')
Expand All @@ -21,7 +24,7 @@ class DatabaseService(ABC):
"""Root Service class: manages database transactions for entities."""
def __init__(self, app):
self.logger = app.logger
self.app = app
self.app: Api = app

@abstractmethod
async def create(self, data, stmt_only=False, **kwargs):
Expand Down Expand Up @@ -100,10 +103,10 @@ async def _delete(self, stmt: Delete, session: AsyncSession) -> None:

class UnaryEntityService(DatabaseService):
"""Generic Service class for non-composite entities."""
def __init__(self, app, table: Base, pk: Tuple[str, ...], *args, **kwargs):
def __init__(self, app, table: Base, *args, **kwargs):
# Entity info.
self._table = table
self.pk = tuple(table.col(key) for key in pk)
self.pk = tuple(table.col(name) for name in table.pk())
self.relationships = table.relationships()
# Enable entity - service - table linkage so everything is conveniently available.
table.svc = self
Expand All @@ -112,26 +115,13 @@ def __init__(self, app, table: Base, pk: Tuple[str, ...], *args, **kwargs):
super(UnaryEntityService, self).__init__(app=app, *args, **kwargs)

def __repr__(self) -> str:
return("{}({!r})".format(self.__class__.__name__, self._table.__name__))
return "{}({!r})".format(self.__class__.__name__, self._table.__name__)

@property
def table(self) -> Base:
return self._table

@property
def db(self):
return self.app.db

@property
def session(self) -> AsyncSession:
"""Important for when @in_session is applied on a service method."""
return self.db.session

@overload
async def create(self, data, stmt_only: bool=True, **kwargs) -> Insert:
"""..."""

async def create(self, data, stmt_only: bool=False, **kwargs) -> Base | List[Base]:
async def create(self, data, stmt_only: bool=False, **kwargs) -> Insert | Base | List[Base]:
"""CREATE one or many rows. data: schema validation result."""
stmt = insert(self.table)

Expand Down Expand Up @@ -182,7 +172,7 @@ async def create_update(self, pk_val, data: dict) -> Base:
item = self.table(**kw, **data)
return await self._merge(item)

def _parse_int_operators(self, attr):
def _parse_int_operators(self, attr) -> Tuple[str, str]:
input_op = attr.pop()
match input_op.strip(')').split('('):
case [('gt'| 'ge' | 'lt' | 'le') as op, arg]:
Expand Down Expand Up @@ -282,15 +272,15 @@ async def delete(self, pk_val, **kwargs) -> Any:

class CompositeEntityService(UnaryEntityService):
"""Special case for Composite Entities (i.e. containing nested entities attributes)."""
class CompositeInsert(object):
class CompositeInsert():
"""Class to manage composite entities insertions."""
def __init__(self, item: Insert, nested: dict={}, delayed: dict={}) -> None:
self.item = item
self.nested = nested
self.delayed = delayed

@DatabaseManager.in_session
async def _insert_composite(self, composite: CompositeInsert, session: AsyncSession) -> Base | None:
async def _insert_composite(self, composite: CompositeInsert, session: AsyncSession) -> Base:
# Insert all nested objects, and keep track.
for key, sub in composite.nested.items():
composite.nested[key] = await self._insert(sub, session)
Expand Down Expand Up @@ -319,21 +309,24 @@ async def _insert_composite(self, composite: CompositeInsert, session: AsyncSess
await session.commit()
return item

async def _insert(self, stmt: Insert | CompositeInsert, session: AsyncSession=None) -> Base | None:
async def _insert(self, stmt: Insert | CompositeInsert, session: AsyncSession) -> Base:
"""Redirect in case of composite insert. Mid-level: No need for in_session decorator."""
if isinstance(stmt, self.CompositeInsert):
return await self._insert_composite(stmt, session)
else:
return await super(CompositeEntityService, self)._insert(stmt, session)
return await super(CompositeEntityService, self)._insert(stmt, session)

async def _insert_many(self, stmt: Insert | List[CompositeInsert], session: AsyncSession=None) -> List[Base]:
async def _insert_many(self,
stmt: Insert | List[CompositeInsert],
session: AsyncSession) -> List[Base]:
"""Redirect in case of composite insert. Mid-level: No need for in_session decorator."""
if isinstance(stmt, Insert):
return await super(CompositeEntityService, self)._insert_many(stmt, session)
else:
return [await self._insert_composite(composite, session) for composite in stmt]
return [await self._insert_composite(composite, session) for composite in stmt]

async def _create_one(self, data: dict, stmt_only: bool=False, **kwargs) -> Base | CompositeInsert:
async def _create_one(self,
data: dict,
stmt_only: bool=False,
**kwargs) -> Base | CompositeInsert:
"""CREATE, accounting for nested entitites."""
nested = {}
delayed = {}
Expand Down Expand Up @@ -368,12 +361,16 @@ async def _create_one(self, data: dict, stmt_only: bool=False, **kwargs) -> Base
return composite if stmt_only else await self._insert_composite(composite, **kwargs)

# @DatabaseManager.in_session
async def _create_many(self, data: List[dict], stmt_only: bool=False, session: AsyncSession = None, **kwargs) -> List[Base] | List[CompositeInsert]:
async def _create_many(self,
data: List[dict],
stmt_only: bool=False,
session: AsyncSession = None,
**kwargs) -> List[Base] | List[CompositeInsert]:
"""Share session & top level stmt_only=True for list creation.
Issues a session.commit() after each insertion."""
async with AsyncExitStack() as stack:
session = session if session else (
await stack.enter_async_context(self.session()))
await stack.enter_async_context(self.app.db.session()))
composites = []
for one in data:
composites.append(
Expand All @@ -384,7 +381,8 @@ async def _create_many(self, data: List[dict], stmt_only: bool=False, session: A
await session.commit()
return composites

async def create(self, data: List[dict] | dict, **kwargs) -> Base | CompositeInsert | List[Base] | List[CompositeInsert]:
async def create(self, data: List[dict] | dict, **kwargs) -> (
Base | CompositeInsert | List[Base] | List[CompositeInsert]):
"""CREATE, Handle list and single case."""
f = self._create_many if isinstance(data, list) else self._create_one
return await f(data, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/biodm/components/services/kcservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ def kc(self):
return self.app.kc

@abstractmethod
async def _read_or_create(self, **kwargs):
async def _read_or_create(self, **kwargs) -> str:
"""Try to read from DB, create on keycloak side if not present. Return id."""
raise NotImplementedError


class KCGroupService(KCService):
async def _read_or_create(self, data: dict):
async def _read_or_create(self, data: dict) -> str:
try:
return (await self.read(data["name"])).id
except FailedRead:
Expand Down Expand Up @@ -53,7 +53,7 @@ async def delete(self, pk_val) -> Any:


class KCUserService(KCService):
async def _read_or_create(self, data, groups=[], group_ids=[]):
async def _read_or_create(self, data, groups=[], group_ids=[]) -> str:
try:
user = await self.read(data["username"])
for gid in group_ids:
Expand Down
Loading

0 comments on commit 4600c9a

Please sign in to comment.