Skip to content

Commit

Permalink
Fixes #90: Add flag to optionally disable fetching DB connection in h…
Browse files Browse the repository at this point in the history
…ttp and ws endpoints
  • Loading branch information
dolamroth committed Aug 19, 2024
1 parent 73bda08 commit 2865c64
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 9 deletions.
1 change: 1 addition & 0 deletions starlette_web/common/authorization/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class BaseAuthenticationBackend:
openapi_spec = None
openapi_name = "Base"
requires_database = False

def __init__(self, request: HTTPConnection, scope: Scope):
self.request: HTTPConnection = request
Expand Down
22 changes: 22 additions & 0 deletions starlette_web/common/authorization/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __call__(self, *args, **kwargs):
op1 = self.op1_class(*args, **kwargs)
return self.operator_class(op1)

@property
def requires_database(self):
return self.op1_class.requires_database


class OperandHolder(OperationHolderMixin):
def __init__(self, operator_class, op1_class, op2_class):
Expand All @@ -46,6 +50,10 @@ def __call__(self, *args, **kwargs):
op2 = self.op2_class(*args, **kwargs)
return self.operator_class(op1, op2)

@property
def requires_database(self):
return self.op1_class.requires_database or self.op2_class.requires_database


class AND:
def __init__(self, op1, op2):
Expand All @@ -57,6 +65,10 @@ async def has_permission(self, request: Request, scope: Scope):
await self.op2.has_permission(request, scope)
)

@property
def requires_database(self):
return self.op1.requires_database or self.op2.requires_database


class OR:
def __init__(self, op1, op2):
Expand All @@ -68,6 +80,10 @@ async def has_permission(self, request: Request, scope: Scope):
await self.op2.has_permission(request, scope)
)

@property
def requires_database(self):
return self.op1.requires_database or self.op2.requires_database


class NOT:
def __init__(self, op1):
Expand All @@ -76,12 +92,18 @@ def __init__(self, op1):
async def has_permission(self, request: Request, scope: Scope):
return not (await self.op1.has_permission(request, scope))

@property
def requires_database(self):
return self.op1.requires_database


class BasePermissionMetaclass(OperationHolderMixin, type):
pass


class BasePermission(metaclass=BasePermissionMetaclass):
requires_database = False

async def has_permission(self, request: Request, scope: Scope) -> bool:
raise NotImplementedError

Expand Down
27 changes: 23 additions & 4 deletions starlette_web/common/http/base_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import AsyncExitStack
from typing import (
Type, Union, Iterable, ClassVar, Optional, Mapping, List, Awaitable, Dict,
)
Expand Down Expand Up @@ -51,6 +52,7 @@ class BaseHTTPEndpoint(HTTPEndpoint):
response_renderer: ClassVar[Type[BaseRenderer]] = import_string(
settings.DEFAULT_RESPONSE_RENDERER
)
requires_database: ClassVar[bool] = True

async def dispatch(self) -> None:
"""
Expand All @@ -65,18 +67,25 @@ async def dispatch(self) -> None:
handler: Awaitable[Response] = getattr(self, handler_name, self.method_not_allowed)

try:
async with self.app.session_maker() as session:
try:
_requires_database = self._requires_database()

async with AsyncExitStack() as db_stack:
if _requires_database:
session = await db_stack.enter_async_context(self.app.session_maker())
self.request.state.db_session = session
self.db_session = session

try:
await self._authenticate()
await self._check_permissions()

response: Response = await handler(self.request) # noqa
await session.commit()

if _requires_database:
await session.commit()
except Exception as err:
await session.rollback()
if _requires_database:
await session.rollback()
raise err

except (BaseApplicationError, WebargsHTTPException, HTTPException) as err:
Expand Down Expand Up @@ -180,3 +189,13 @@ def _response(
headers=headers,
background=background,
)

def _requires_database(self):
return (
self.requires_database
or self.auth_backend.requires_database
or any([
permission_class.requires_database
for permission_class in self.permission_classes
])
)
29 changes: 24 additions & 5 deletions starlette_web/common/ws/base_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import ClassVar, Type, Any, Optional, List, Dict, Tuple
import sys
from contextlib import AsyncExitStack
from typing import ClassVar, Type, Any, Optional, List, Dict, Tuple

import anyio
from anyio._core._tasks import TaskGroup
Expand Down Expand Up @@ -44,8 +45,30 @@ async def dispatch(self) -> None:
async with anyio.create_task_group() as self.task_group:
await super().dispatch()

def _auth_requires_database(self):
return (
self.auth_backend.requires_database
or any([
permission_class.requires_database
for permission_class in self.permission_classes
])
)

async def _remove_auth_db_session(self, websocket: WebSocket):
del websocket.state.db_session

async def on_connect(self, websocket: WebSocket) -> None:
try:
auth_requires_db = self._auth_requires_database()
async with AsyncExitStack() as db_stack:
if auth_requires_db:
db_session = await db_stack.enter_async_context(self.app.session_maker())
websocket.state.db_session = db_session

# Explicitly clear db_session,
# so that user does not use it through lengthy websocket life-state
db_stack.push_async_callback(self._remove_auth_db_session, websocket)

async with self.app.session_maker() as db_session:
websocket.state.db_session = db_session
self.user = await self._authenticate(websocket)
Expand All @@ -56,10 +79,6 @@ async def on_connect(self, websocket: WebSocket) -> None:
# since accept()/close() have not been called
raise WebSocketDisconnect(code=1006, reason=str(exc)) from exc

# Explicitly clear db_session,
# so that user does not use it through lengthy websocket life-state
del websocket.state.db_session

if permitted:
await self.accept(websocket)
else:
Expand Down
1 change: 1 addition & 0 deletions starlette_web/contrib/auth/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class JWTAuthenticationBackend(BaseAuthenticationBackend):
keyword = "Bearer"
openapi_spec = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
openapi_name = "JWTAuth"
requires_database = True

async def authenticate(self, **kwargs) -> User:
request = self.request
Expand Down

0 comments on commit 2865c64

Please sign in to comment.