diff --git a/docs/developer_manual/advanced_use.rst b/docs/developer_manual/advanced_use.rst index d855c5f..8937483 100644 --- a/docs/developer_manual/advanced_use.rst +++ b/docs/developer_manual/advanced_use.rst @@ -183,6 +183,8 @@ a ``ResourceController``. .. code-block:: python :caption: s3controller.py + from biodm.routing import Route, PublicRoute + class S3Controller(ResourceController): """Controller for entities involving file management leveraging an S3Service.""" def _infer_svc(self) -> Type[S3Service]: @@ -200,7 +202,7 @@ a ``ResourceController``. prefix = f'{self.prefix}/{self.qp_id}/' file_routes = [ Route(f'{prefix}download', self.download, methods=[HttpMethod.GET]), - Route(f'{prefix}post_success', self.post_success, methods=[HttpMethod.GET]), + PublicRoute(f'{prefix}post_success', self.post_success, methods=[HttpMethod.GET]), ... ] # Set an extra attribute for later. @@ -420,3 +422,20 @@ A lot of that code has to do with retrieving async SQLAlchemy objects attributes # Generate a new form. await self.gen_upload_form(file, session=session) return file + +.. _dev-routing: + +Routing and Auth +---------------- + +As shown in the ``S3Controller`` example above, ``BioDM`` provides two +``Route`` classes: ``PublicRoute`` and ``Route``. + +In case you are defining your own routes you should use those instead of +starlette's ``Route``. + +Ultimately, this enables the config parameter ``REQUIRE_AUTH`` which, when set to ``True``, +will require authentication on all endpoints routed +with a plain ``Route`` while leaving endpoints marked with ``PublicRoute`` public. +This distinction can be important as, in the example above, the s3 bucket is **not** authenticated +when sending us a notice of successful file upload. diff --git a/docs/developer_manual/demo.rst b/docs/developer_manual/demo.rst index c12e6db..5b78f13 100644 --- a/docs/developer_manual/demo.rst +++ b/docs/developer_manual/demo.rst @@ -21,7 +21,7 @@ we will go over the following minimal example. 
# Tables class Dataset(bd.components.Versioned, bd.components.Base): - id = Column(Integer, primary_key=True, autoincrement=not 'sqlite' in config.DATABASE_URL) + id = Column(Integer, primary_key=True, autoincrement=not 'sqlite' in str(config.DATABASE_URL)) name : sao.Mapped[str] = sa.Column(sa.String(50), nullable=False) description : sao.Mapped[str] = sa.Column(sa.String(500), nullable=False) username_owner: sao.Mapped[int] = sa.Column(sa.ForeignKey("USER.username"), nullable=False) diff --git a/docs/developer_manual/permissions.rst b/docs/developer_manual/permissions.rst index 982ff3f..ce7412e 100644 --- a/docs/developer_manual/permissions.rst +++ b/docs/developer_manual/permissions.rst @@ -17,6 +17,17 @@ be provided in a ``.env`` file at the same level as your ``demo.py`` script. KC_CLIENT_ID= KC_CLIENT_SECRET= + +Server level: REQUIRE_AUTH +-------------------------- + +Setting the ``REQUIRE_AUTH=True`` config argument will make all routes, except the ones explicitly +marked public (such as ``/login`` and ``/[resources/|]schema``), require authentication. + + +See more at :ref:`dev-routing` + + Coarse: Static rule on a Controller endpoint --------------------------------------------- diff --git a/docs/user_manual.rst b/docs/user_manual.rst index 5554837..2f4791c 100644 --- a/docs/user_manual.rst +++ b/docs/user_manual.rst @@ -143,9 +143,11 @@ This triggers creation of a new row with a version increment. .. note:: - ``PUT /release`` is the way of updating versioned resources. - The endpoint ``PUT /`` (a.k.a ``update``) will not be available for such resources, and - any attempt at updating by reference through ``POST /`` will raise an error. + ``POST /release`` is the way of updating versioned resources. + The endpoint ``PUT /`` (a.k.a ``update``) is available, however it is meant to be used + in order to update nested objects and collections of that resource. 
Thus, + any attempt at updating a versioned resource through either ``PUT /`` or ``POST /`` + shall raise an error. **E.g.** @@ -178,17 +180,16 @@ and followed by: * Use ``nested.field=val`` to select on a nested attribute field * Use ``*`` in a string attribute for wildcards -* operators ``field.op([value])`` +* numeric operators ``field.op([value])`` * ``[lt, le, gt, ge]`` are supported with a value. * ``[min, max]`` are supported without a value -**e.g.** - .. note:: - When querying with ``curl``, don't forget to escape ``&`` symbol or enclose the whole url in quotes, else your scripting language may intepret it as several commands. + When querying with ``curl``, don't forget to escape the ``&`` symbol or enclose the whole url + in quotes, else your scripting language may interpret it as several commands. Query a nested collection @@ -368,6 +369,6 @@ otherwise. - Passing a top level group will allow all descending children group for that verb/resource tuple. - - Permissions are taken into account if and only if keyclaok functionalities are enabled. + - Permissions are taken into account if and only if keycloak functionalities are enabled. - Without keycloak, no token exchange -> No way of getting back protected data. 
diff --git a/src/biodm/api.py b/src/biodm/api.py index 93bdaad..8a10e01 100644 --- a/src/biodm/api.py +++ b/src/biodm/api.py @@ -55,17 +55,10 @@ def __init__(self, app: ASGIApp, server_host: str) -> None: super().__init__(app, self.dispatch) async def dispatch(self, request: Request, call_next: Callable) -> Any: - if request.state.user_info.info: - user_id = request.state.user_info.info[0] - user_groups = request.state.user_info.info[1] - else: - user_id = "anon" - user_groups = ['no_groups'] - endpoint = str(request.url).rsplit(self.server_host, maxsplit=1)[-1] body = await request.body() entry = { - 'user_username': user_id, + 'user_username': request.user.display_name, 'endpoint': endpoint, 'method': request.method, 'content': str(body) if body else "" @@ -84,7 +77,8 @@ async def dispatch(self, request: Request, call_next: Callable) -> Any: # Log timestamp = datetime.now().strftime("%I:%M%p on %B %d, %Y") History.svc.app.logger.info( - f'{timestamp}\t{user_id}\t{",".join(user_groups)}\t' + f'{timestamp}\t' + f'{request.user.display_name}\t{",".join(request.user.groups)}\t' f'{endpoint}\t-\t{request.method}' ) @@ -177,7 +171,7 @@ def __init__( # Middlewares -> Stack goes in reverse order. self.add_middleware(HistoryMiddleware, server_host=config.SERVER_HOST) self.add_middleware(AuthenticationMiddleware) - if self.scope is Scope.PROD: + if Scope.DEBUG not in self.scope: self.add_middleware(TimeoutMiddleware, timeout=config.SERVER_TIMEOUT) # CORS last (i.e. first). 
self.add_middleware( diff --git a/src/biodm/basics/rootcontroller.py b/src/biodm/basics/rootcontroller.py index 823a28a..03f776b 100644 --- a/src/biodm/basics/rootcontroller.py +++ b/src/biodm/basics/rootcontroller.py @@ -3,12 +3,14 @@ from starlette.requests import Request from starlette.responses import Response, PlainTextResponse -from starlette.routing import Route +# from starlette.routing import Route from biodm import config from biodm.components.controllers import Controller from biodm.utils.security import admin_required, login_required from biodm.utils.utils import json_response +from biodm.routing import Route, PublicRoute + from biodm import tables as bt @@ -16,11 +18,11 @@ class RootController(Controller): """Bundles Routes located at the root of the app i.e. '/'.""" def routes(self, **_): return [ - Route("/live", endpoint=self.live), - Route("/login", endpoint=self.login), - Route("/syn_ack", endpoint=self.syn_ack), + PublicRoute("/live", endpoint=self.live), + PublicRoute("/login", endpoint=self.login), + PublicRoute("/syn_ack", endpoint=self.syn_ack), + PublicRoute("/schema", endpoint=self.openapi_schema), Route("/authenticated", endpoint=self.authenticated), - Route("/schema", endpoint=self.openapi_schema), ] + ( [Route("/kc_sync", endpoint=self.keycloak_sync)] if hasattr(self.app, 'kc') else [] @@ -106,8 +108,8 @@ async def authenticated(self, request: Request) -> Response: description: Unauthorized. 
""" - assert request.state.user_info.info - user_id, groups = request.state.user_info.info + assert request.user.info + user_id, groups = request.user.info return PlainTextResponse(f"{user_id}, {groups}\n") @admin_required diff --git a/src/biodm/components/controllers/controller.py b/src/biodm/components/controllers/controller.py index 17c8a5c..e8416ea 100644 --- a/src/biodm/components/controllers/controller.py +++ b/src/biodm/components/controllers/controller.py @@ -6,7 +6,6 @@ from io import BytesIO from typing import Any, Iterable, List, Dict, TYPE_CHECKING, Optional -from marshmallow import RAISE from marshmallow.schema import Schema from marshmallow.exceptions import ValidationError from sqlalchemy.exc import MissingGreenlet @@ -17,7 +16,7 @@ from biodm import config from biodm.component import ApiComponent from biodm.exceptions import ( - PayloadJSONDecodingError, AsyncDBError, SchemaError + DataError, PayloadJSONDecodingError, AsyncDBError, SchemaError ) from biodm.utils.utils import json_response @@ -104,6 +103,9 @@ def validate( json_data = json.loads(data) # Accepts **kwargs in case support needed. 
return cls.schema.load(json_data, many=many, partial=partial) + except ValidationError as ve: + raise DataError(str(ve.messages)) + except json.JSONDecodeError as e: raise PayloadJSONDecodingError(cls.__name__) from e diff --git a/src/biodm/components/controllers/resourcecontroller.py b/src/biodm/components/controllers/resourcecontroller.py index 3275950..13eae1c 100644 --- a/src/biodm/components/controllers/resourcecontroller.py +++ b/src/biodm/components/controllers/resourcecontroller.py @@ -9,7 +9,7 @@ from marshmallow.schema import RAISE from marshmallow.class_registry import get_class from marshmallow.exceptions import RegistryError -from starlette.routing import Mount, Route, BaseRoute +from starlette.routing import Mount, BaseRoute from starlette.requests import Request from starlette.responses import Response @@ -32,6 +32,7 @@ from biodm.utils.utils import json_response from biodm.utils.apispec import register_runtime_schema, process_apispec_docstrings from biodm.components import Base +from biodm.routing import Route, PublicRoute from .controller import HttpMethod, EntityController if TYPE_CHECKING: @@ -199,7 +200,7 @@ def routes(self, **_) -> List[Mount | Route] | List[Mount] | List[BaseRoute]: Route(f"{self.prefix}", self.create, methods=[HttpMethod.POST]), Route(f"{self.prefix}", self.filter, methods=[HttpMethod.GET]), Mount(self.prefix, routes=[ - Route('/schema', self.openapi_schema, methods=[HttpMethod.GET]), + PublicRoute('/schema', self.openapi_schema, methods=[HttpMethod.GET]), Route(f'/{self.qp_id}', self.read, methods=[HttpMethod.GET]), Route(f'/{self.qp_id}/{{attribute}}', self.read, methods=[HttpMethod.GET]), Route(f'/{self.qp_id}', self.delete, methods=[HttpMethod.DELETE]), @@ -321,7 +322,7 @@ async def create(self, request: Request) -> Response: created = await self.svc.write( data=validated_data, stmt_only=False, - user_info=request.state.user_info, + user_info=request.user, serializer=partial(self.serialize, many=isinstance(validated_data, 
list)) ) return json_response(data=created, status_code=201) @@ -378,14 +379,14 @@ async def read(self, request: Request) -> Response: fields = ctrl._extract_fields( dict(request.query_params), - user_info=request.state.user_info + user_info=request.user ) return json_response( data=await self.svc.read( pk_val=self._extract_pk_val(request), fields=fields, nested_attribute=nested_attribute, - user_info=request.state.user_info, + user_info=request.user, serializer=partial(ctrl.serialize, many=many, only=fields), ), status_code=200, @@ -436,7 +437,7 @@ async def update(self, request: Request) -> Response: data=await self.svc.write( data=validated_data, stmt_only=False, - user_info=request.state.user_info, + user_info=request.user, serializer=partial(self.serialize, many=isinstance(validated_data, list)), ), status_code=201, @@ -464,7 +465,7 @@ async def delete(self, request: Request) -> Response: """ await self.svc.delete( pk_val=self._extract_pk_val(request), - user_info=request.state.user_info, + user_info=request.user, ) return json_response("Deleted.", status_code=200) @@ -496,12 +497,12 @@ async def filter(self, request: Request) -> Response: schema: Schema """ params = dict(request.query_params) - fields = self._extract_fields(params, user_info=request.state.user_info) + fields = self._extract_fields(params, user_info=request.user) return json_response( await self.svc.filter( fields=fields, params=params, - user_info=request.state.user_info, + user_info=request.user, serializer=partial(self.serialize, many=True, only=fields), ), status_code=200, @@ -542,14 +543,14 @@ async def release(self, request: Request) -> Response: fields = self._extract_fields( dict(request.query_params), - user_info=request.state.user_info + user_info=request.user ) return json_response( await self.svc.release( pk_val=self._extract_pk_val(request), update=validated_data, - user_info=request.state.user_info, + user_info=request.user, serializer=partial(self.serialize, many=False, 
only=fields), ), status_code=200 ) diff --git a/src/biodm/components/controllers/s3controller.py b/src/biodm/components/controllers/s3controller.py index 19d28dc..ab053a5 100644 --- a/src/biodm/components/controllers/s3controller.py +++ b/src/biodm/components/controllers/s3controller.py @@ -3,7 +3,7 @@ from typing import List, Type from marshmallow import Schema, RAISE -from starlette.routing import Route, Mount, BaseRoute +from starlette.routing import Mount, BaseRoute from starlette.requests import Request from starlette.responses import RedirectResponse @@ -14,6 +14,7 @@ from biodm.exceptions import ImplementionError from biodm.utils.security import UserInfo from biodm.utils.utils import json_response +from biodm.routing import PublicRoute, Route from .controller import HttpMethod from .resourcecontroller import ResourceController @@ -50,8 +51,8 @@ def routes(self, **_) -> List[Mount | Route] | List[Mount] | List[BaseRoute]: prefix = f'{self.prefix}/{self.qp_id}/' file_routes = [ Route(f'{prefix}download', self.download, methods=[HttpMethod.GET]), - Route(f'{prefix}post_success', self.post_success, methods=[HttpMethod.GET]), Route(f'{prefix}complete_multipart', self.complete_multipart, methods=[HttpMethod.PUT]), + PublicRoute(f'{prefix}post_success', self.post_success, methods=[HttpMethod.GET]), ] self.post_upload_callback = Path(file_routes[1].path) diff --git a/src/biodm/components/services/dbservice.py b/src/biodm/components/services/dbservice.py index 8fb773c..30b3e84 100644 --- a/src/biodm/components/services/dbservice.py +++ b/src/biodm/components/services/dbservice.py @@ -190,16 +190,14 @@ async def _check_permissions( if not user_info: return - if self._login_required(verb) and not user_info.info: - raise UnauthorizedError("Authentication required.") + if self._login_required(verb) and not user_info.is_authenticated: + raise UnauthorizedError() # Special admin case. 
if user_info.is_admin: return - groups = user_info.info[1] if user_info.info else [] - - if not self._group_required(verb, groups): + if not self._group_required(verb, user_info.groups): raise UnauthorizedError("Insufficient group privileges for this operation.") perms = self._get_permissions(verb) @@ -252,7 +250,7 @@ async def _check_permissions( # Empty perm list: public. continue - if not self._group_path_matching(set(g.path for g in allowed.groups), set(groups)): + if not self._group_path_matching(set(g.path for g in allowed.groups), set(user_info.groups)): raise UnauthorizedError(f"No {verb} access.") def _apply_read_permissions( @@ -287,15 +285,13 @@ def _apply_read_permissions( if user_info.is_admin: return stmt - groups = user_info.info[1] if user_info.info else [] - # Build nested query to filter permitted results. for permission in perms: lgverb = permission['table'].__table__.c[f'id_{verb}'] # public. perm_stmt = select(permission['table']).where(lgverb == None) - if groups: + if user_info.groups: protected = ( select(permission['table']) .join( @@ -307,7 +303,7 @@ def _apply_read_permissions( .where( or_(*[ # Group path matching. 
Group.path.like(upper_level + '%') - for upper_level in groups + for upper_level in user_info.groups ]), ) ) @@ -375,12 +371,10 @@ def check_allowed_nested(self, fields: List[str], user_info: UserInfo) -> None: nested, _ = partition(fields, lambda x: x in self.table.relationships) for name in nested: target_svc = self._svc_from_rel_name(name) - if target_svc._login_required("read") and not user_info.info: - raise UnauthorizedError("Authentication required.") - - groups = user_info.info[1] if user_info.info else [] + if target_svc._login_required("read") and not user_info.is_authenticated: + raise UnauthorizedError() - if not self._group_required("read", groups): + if not self._group_required("read", user_info.groups): raise UnauthorizedError(f"Insufficient group privileges to retrieve {name}.") def takeout_unallowed_nested(self, fields: List[str], user_info: UserInfo) -> List[str]: @@ -397,12 +391,10 @@ def takeout_unallowed_nested(self, fields: List[str], user_info: UserInfo) -> Li def ncheck(name): target_svc = self._svc_from_rel_name(name) - if target_svc._login_required("read") and not user_info.info: + if target_svc._login_required("read") and not user_info.is_authenticated: return False - groups = user_info.info[1] if user_info.info else [] - - if not self._group_required("read", groups): + if not self._group_required("read", user_info.groups): return False return True @@ -455,9 +447,9 @@ def gen_upsert_holder( # submitter_username special col elif missing_data == {'submitter_username'} and self.table.has_submitter_username: - if not user_info or not user_info.info: - raise UnauthorizedError("Requires authentication.") - data['submitter_username'] = user_info.info[0] + if not user_info or not user_info.is_authenticated: + raise UnauthorizedError() + data['submitter_username'] = user_info.display_name else: raise DataError(f"{self.table.__name__} missing the following: {missing_data}.") @@ -514,7 +506,7 @@ async def write( """ # SQLite support for composite 
primary keys, with leading id. if ( - 'sqlite' in config.DATABASE_URL and + 'sqlite' in str(config.DATABASE_URL) and hasattr(self.table, 'id') and len(list(self.table.pk)) > 1 ): @@ -603,7 +595,7 @@ def _restrict_select_on_fields( .svc ._apply_read_permissions(user_info, rel_stmt) ) - # stmt = stmt.join(rel_stmt.subquery(), isouter=True) + stmt = stmt.join_from( self.table, rel_stmt.subquery(), @@ -659,12 +651,10 @@ async def read_nested( # Special cases for nested, as endpoint protection is not enough. target_svc = self._svc_from_rel_name(attribute) - if target_svc._login_required("read") and not user_info.info: - raise UnauthorizedError("Authentication required.") - - groups = user_info.info[1] if user_info.info else [] + if target_svc._login_required("read") and not user_info.is_authenticated: + raise UnauthorizedError() - if not target_svc._group_required("read", groups): + if not target_svc._group_required("read", user_info.groups): raise UnauthorizedError("Insufficient group privileges for this operation.") # Dynamic permissions are covered by read. 
diff --git a/src/biodm/components/table.py b/src/biodm/components/table.py index dfa2fa6..db4a2bd 100644 --- a/src/biodm/components/table.py +++ b/src/biodm/components/table.py @@ -86,7 +86,7 @@ def is_autoincrement(cls, name: str) -> bool: - https://groups.google.com/g/sqlalchemy/c/o5YQNH5UUko """ # Enforced by DatabaseService.populate_ids_sqlite - if name == 'id' and 'sqlite' in config.DATABASE_URL: + if name == 'id' and 'sqlite' in str(config.DATABASE_URL): return True if cls.__table__.columns[name] is cls.__table__.autoincrement_column: diff --git a/src/biodm/config.py b/src/biodm/config.py index f94ef0b..c733e44 100644 --- a/src/biodm/config.py +++ b/src/biodm/config.py @@ -1,10 +1,15 @@ from starlette.config import Config +from databases import DatabaseURL try: config = Config('.env') except FileNotFoundError: config = Config() +# TODO: [prio medium - before release] +# Change credentials to Secret type +# Avoids leaking them in stacktraces + # Server. API_NAME = config("API_NAME", cast=str, default="biodm_instance") API_VERSION = config("API_VERSION", cast=str, default="0.1.0") @@ -14,6 +19,7 @@ SERVER_PORT = config("SERVER_PORT", cast=int, default=8000) SECRET_KEY = config("SECRET_KEY", cast=str, default="r4nD0m_p455") SERVER_TIMEOUT = config("SERVER_TIMEOUT", cast=int, default=30) +REQUIRE_AUTH = config("REQUIRE_AUTH", cast=bool, default=False) # Responses. 
INDENT = config('INDENT', cast=int, default=2) diff --git a/src/biodm/error.py b/src/biodm/error.py index 166d10c..383e3bb 100644 --- a/src/biodm/error.py +++ b/src/biodm/error.py @@ -1,8 +1,6 @@ import json from http import HTTPStatus -from marshmallow.exceptions import ValidationError - from biodm.utils.utils import json_response from .exceptions import ( EndpointError, @@ -52,7 +50,7 @@ async def onerror(_, exc): ) match exc: - case ValidationError() | FileTooLargeError(): + case FileTooLargeError(): status = 400 case DataError() | EndpointError() | PayloadJSONDecodingError(): status = 400 diff --git a/src/biodm/exceptions.py b/src/biodm/exceptions.py index 5c7a9fe..cc41c9d 100644 --- a/src/biodm/exceptions.py +++ b/src/biodm/exceptions.py @@ -86,6 +86,8 @@ class PartialIndex(RequestError): class UnauthorizedError(RequestError): """Raised when a request on a group restricted route is sent by an unauthorized user.""" + def __init__(self, detail: str="Authentication required.") -> None: + super().__init__(detail) class ManifestError(RequestError): diff --git a/src/biodm/managers/dbmanager.py b/src/biodm/managers/dbmanager.py index 7f76770..dc38f1b 100644 --- a/src/biodm/managers/dbmanager.py +++ b/src/biodm/managers/dbmanager.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager, AsyncExitStack from typing import AsyncGenerator, TYPE_CHECKING, Callable, Any +from databases import DatabaseURL from sqlalchemy import event from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker @@ -19,14 +20,14 @@ class DatabaseManager(ApiManager): """Manages DB side query execution.""" def __init__(self, app: Api) -> None: super().__init__(app=app) - self.database_url: str = self.async_database_url(config.DATABASE_URL) + self._database_url: DatabaseURL = self.async_database_url(config.DATABASE_URL) try: self.engine = create_async_engine( - self.database_url, + str(self._database_url), 
echo=Scope.DEBUG in app.scope, ) - if "sqlite" in self.database_url: + if "sqlite" in str(self._database_url): event.listens_for(self.engine.sync_engine, "connect")(self.sqlite_declare_strrev) self.async_session = async_sessionmaker( @@ -42,21 +43,23 @@ def endpoint(self): return f"{self.engine.url.host}:{self.engine.url.port}" @staticmethod - def async_database_url(url) -> str: + def async_database_url(url: DatabaseURL) -> str: """Adds a matching async driver to a database url.""" + url = str(url) match url.split("://"): case ["postgresql", _]: - return url.replace( # type: ignore [unreachable] + url = url.replace( # type: ignore [unreachable] "postgresql://", "postgresql+asyncpg://" ) case ["sqlite", _]: - return url.replace( # type: ignore [unreachable] + url = url.replace( # type: ignore [unreachable] "sqlite://", "sqlite+aiosqlite://" ) case _: raise DBError( "Only ['postgresql', 'sqlite'] backends are supported at the moment." ) + return DatabaseURL(url) @asynccontextmanager async def session(self) -> AsyncGenerator[AsyncSession, None]: diff --git a/src/biodm/routing.py b/src/biodm/routing.py new file mode 100644 index 0000000..e83753e --- /dev/null +++ b/src/biodm/routing.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING, Sequence, Callable, Awaitable, Coroutine, Any +import starlette.routing as sr + +from starlette.requests import Request +from starlette.responses import Response +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware + + +from biodm import config +from biodm.exceptions import UnauthorizedError + + +if TYPE_CHECKING: + from biodm.utils.security import UserInfo + + +class RequireAuthMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]] + ) -> Coroutine[Any, Any, Response]: + if not request.user.is_authenticated: + raise UnauthorizedError() + return await call_next(request) + + +class PublicRoute(sr.Route): + 
"""A route explicitely marked public. + So it is not checked for authentication even when server is run + in REQUIRE_AUTH mode.""" + + +class Route(sr.Route): + """Adds a middleware ensure user is authenticated when running server + in REQUIRE_AUTH mode.""" + def __init__( + self, + path: str, + endpoint: Callable[..., Any], + *, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + middleware: Sequence[Middleware] | None = None + ) -> None: + if config.REQUIRE_AUTH: + middleware = middleware or [] + middleware.append(Middleware(RequireAuthMiddleware)) + super().__init__( + path, + endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + middleware=middleware + ) diff --git a/src/biodm/tables/group.py b/src/biodm/tables/group.py index c67c0a5..b3359c7 100644 --- a/src/biodm/tables/group.py +++ b/src/biodm/tables/group.py @@ -41,7 +41,7 @@ def parent_path(self) -> str: @classmethod def _parent_path(cls) -> SQLColumnExpression[str]: sep = literal('__') - if "postgresql" in config.DATABASE_URL: + if "postgresql" in str(config.DATABASE_URL): return func.substring( cls.path, 0, @@ -52,7 +52,7 @@ def _parent_path(cls) -> SQLColumnExpression[str]: ) ) ) - if "sqlite" in config.DATABASE_URL: + if "sqlite" in str(config.DATABASE_URL): #  sqlite doesn't have reverse #  -> strrev declared in dbmanager #  postgres.position -> sqlite.instr diff --git a/src/biodm/utils/security.py b/src/biodm/utils/security.py index adb7b07..93fa7f0 100644 --- a/src/biodm/utils/security.py +++ b/src/biodm/utils/security.py @@ -4,12 +4,11 @@ from dataclasses import field as dc_field from functools import wraps from inspect import getmembers, ismethod -from typing import TYPE_CHECKING, List, Tuple, Callable, Awaitable, Set, ClassVar, Type, Any, Dict +from typing import TYPE_CHECKING, List, Tuple, Set, ClassVar, Type, Any, Dict from marshmallow import fields, Schema -from starlette.middleware.base import BaseHTTPMiddleware -from 
starlette.requests import Request -from starlette.responses import Response +from starlette.requests import HTTPConnection +from starlette.types import ASGIApp, Receive, Scope, Send from sqlalchemy import ForeignKeyConstraint, Column, ForeignKey from sqlalchemy.orm import ( relationship, Relationship, backref, ONETOMANY, mapped_column, MappedColumn @@ -24,6 +23,9 @@ from biodm.managers import KeycloakManager +# TODO: [prio: not urgent] +# possible improvement, would be to rewrite the following classes using +# starlette builtins from starlette.middleware.authentication. class UserInfo(aobject): """Hold user info for a given request. @@ -32,8 +34,8 @@ class UserInfo(aobject): kc: 'KeycloakManager' _info: Tuple[str, List, List] | None = None - async def __init__(self, request: Request) -> None: # type: ignore [misc] - self.token = self.auth_header(request) + async def __init__(self, conn: HTTPConnection) -> None: # type: ignore [misc] + self.token = self.auth_header(conn) if self.token: self._info = await self.decode_token(self.token) @@ -42,10 +44,18 @@ def info(self) -> Tuple[str, List, List] | None: """info getter. 
Returns user_info if the request is authenticated, else None.""" return self._info + @property + def display_name(self): + return self._info[0] if self._info else "anon" + + @property + def groups(self): + return self._info[1] if self._info else ["no_groups"] + @staticmethod - def auth_header(request) -> str | None: + def auth_header(conn: HTTPConnection) -> str | None: """Check and return token from headers if present else returns None.""" - header = request.headers.get("Authorization") + header = conn.headers.get("Authorization") if not header: return None return (header.split("Bearer")[-1] if "Bearer" in header else header).strip() @@ -64,6 +74,10 @@ async def decode_token( ] or ['no_groups'] return username, groups + @property + def is_authenticated(self): + return bool(self._info) + @property def is_admin(self): """token bearer is admin flag""" @@ -72,16 +86,19 @@ def is_admin(self): return 'admin' in self._info[1] -# pylint: disable=too-few-public-methods -class AuthenticationMiddleware(BaseHTTPMiddleware): +class AuthenticationMiddleware: """Handle token decoding for incoming requests, populate request object with result.""" - async def dispatch( - self, - request: Request, - call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - request.state.user_info = await UserInfo(request) - return await call_next(request) + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + scope["user"] = await UserInfo(conn) + await self.app(scope, receive, send) def login_required(f): @@ -98,9 +115,9 @@ async def lr_write_wrapper(controller, request, *args, **kwargs): # Else hardcheck here is enough. 
@wraps(f) async def lr_wrapper(controller, request, *args, **kwargs): - if request.state.user_info.info: + if request.user.info: return await f(controller, request, *args, **kwargs) - raise UnauthorizedError("Authentication required.") + raise UnauthorizedError() # Read is protected on its endpoint and is handled specifically for nested cases in codebase. if f.__name__ == "read": @@ -122,8 +139,8 @@ async def gr_write_wrapper(controller, request, *args, **kwargs): @wraps(f) async def gr_wrapper(controller, request, *args, **kwargs): - if request.state.user_info.info: - _, user_groups = request.state.user_info.info + if request.user.info: + _, user_groups = request.user.info if any((ug in groups for ug in user_groups)): return f(controller, request, *args, **kwargs) diff --git a/src/biodm/utils/sqla.py b/src/biodm/utils/sqla.py index e75e2e5..4bf3b97 100644 --- a/src/biodm/utils/sqla.py +++ b/src/biodm/utils/sqla.py @@ -18,10 +18,10 @@ def _backend_specific_insert() -> Callable[[_DMLTableArgument], Insert]: MariaDB/InnoDB have similar constructs, in case we want to support more backends the to_stmt method from the UpsertStmtValuesHolder class below should be tweaked as well. """ - if 'postgresql' in config.DATABASE_URL.lower(): + if 'postgresql' in str(config.DATABASE_URL).lower(): return postgresql.insert - if 'sqlite' in config.DATABASE_URL.lower(): + if 'sqlite' in str(config.DATABASE_URL).lower(): return sqlite.insert raise # Should not happen. Here to suppress linters. 
diff --git a/src/example/.env b/src/example/.env index 61c310c..8925c1d 100644 --- a/src/example/.env +++ b/src/example/.env @@ -1,4 +1,6 @@ API_NAME="DWARF_BIODM_PoC" +# REQUIRE_AUTH=True + # PG_USER="postgres" # PG_PASS="pass" # PG_HOST="postgres.local:5432" diff --git a/src/example/entities/controllers/file.py b/src/example/entities/controllers/file.py index a8f4846..6c11652 100644 --- a/src/example/entities/controllers/file.py +++ b/src/example/entities/controllers/file.py @@ -4,10 +4,11 @@ from biodm.components.controllers import S3Controller, HttpMethod from biodm.exceptions import UnauthorizedError from biodm.utils.security import UserInfo +from biodm.routing import Route from starlette.requests import Request from starlette.responses import Response, PlainTextResponse #, RedirectResponse -from starlette.routing import BaseRoute, Mount, Route +from starlette.routing import BaseRoute, Mount from entities import tables @@ -44,13 +45,11 @@ async def visualize(self, request: Request) -> Response: vis_data = {'file_id': int(request.path_params.get('id'))} - user_info = await UserInfo(request) + if not request.user.is_authenticated: + raise UnauthorizedError() - if not user_info.info: - raise UnauthorizedError("Visualizing requires authentication.") + vis_data["user_username"] = request.user.display_name - vis_data["user_username"] = user_info.info[0] - - vis = await vis_svc.write(data=vis_data, stmt_only=False, user_info=user_info) + vis = await vis_svc.write(data=vis_data, stmt_only=False, user_info=request.user) return PlainTextResponse(f"http://{config.K8_HOST}/{vis.name}/") diff --git a/src/example/entities/tables/dataset.py b/src/example/entities/tables/dataset.py index f4edbc1..8dd971a 100644 --- a/src/example/entities/tables/dataset.py +++ b/src/example/entities/tables/dataset.py @@ -15,7 +15,7 @@ class Dataset(Versioned, Base): - id = Column(Integer, primary_key=True, autoincrement=not 'sqlite' in config.DATABASE_URL) + id = Column(Integer, 
primary_key=True, autoincrement=not 'sqlite' in str(config.DATABASE_URL)) # data fields name: Mapped[str] = mapped_column(String(50), nullable=False) description: Mapped[str] = mapped_column(Text, nullable=True) diff --git a/src/requirements/common.txt b/src/requirements/common.txt index d7afb58..56bd7f1 100644 --- a/src/requirements/common.txt +++ b/src/requirements/common.txt @@ -3,6 +3,7 @@ apispec==6.6.1 asyncpg==0.29.0 boto3==1.34.65 botocore==1.34.65 +databases==0.9.0 marshmallow==3.20.2 python-keycloak==3.9.1 SQLAlchemy==2.0.30 diff --git a/src/tests/unit/test_resource.py b/src/tests/unit/test_resource.py index 7feba00..a5c1a0d 100644 --- a/src/tests/unit/test_resource.py +++ b/src/tests/unit/test_resource.py @@ -234,7 +234,7 @@ def test_update_unary_resource(client): cr_response = client.post('/cs', content=json_bytes(item)) item_id = json.loads(cr_response.text)['id'] - up_response = client.put(f'/cs/{item_id}', data=json_bytes({'data': 'modified'})) + up_response = client.put(f'/cs/{item_id}', content=json_bytes({'data': 'modified'})) json_response = json.loads(up_response.text) assert up_response.status_code == 201 @@ -248,7 +248,7 @@ def test_update_composite_resource(client): item_id = json.loads(cr_response.text)['id'] c_oracle = {'data': 'bop'} - up_response = client.put(f'/as/{item_id}', data=json_bytes( + up_response = client.put(f'/as/{item_id}', content=json_bytes( { 'x': 3, 'c': c_oracle