Skip to content

Commit

Permalink
typing: add middleware protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
jkmnt committed Oct 26, 2024
1 parent 6d7a45b commit f471050
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 25 deletions.
144 changes: 144 additions & 0 deletions falcon/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,147 @@ def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ...
DeserializeSync = Callable[[bytes], Any]

Responder = Union[ResponderMethod, AsgiResponderMethod]


# Middleware
class MiddlewareWithProcessRequest(Protocol):
"""WSGI Middleware with request handler"""

def process_request(self, req: Request, resp: Response) -> None: ...


class MiddlewareWithProcessResource(Protocol):
"""WSGI Middleware with resource handler"""

def process_resource(
self,
req: Request,
resp: Response,
resource: object,
params: Dict[str, Any],
) -> None: ...


class MiddlewareWithProcessResponse(Protocol):
"""WSGI Middleware with response handler"""

def process_response(
self, req: Request, resp: Response, resource: object, req_succeeded: bool
) -> None: ...


class AsgiMiddlewareWithProcessStartup(Protocol):
"""ASGI middleware with startup handler"""

async def process_startup(
self, scope: Mapping[str, Any], event: Mapping[str, Any]
) -> None: ...


class AsgiMiddlewareWithProcessShutdown(Protocol):
"""ASGI middleware with shutdown handler"""

async def process_shutdown(
self, scope: Mapping[str, Any], event: Mapping[str, Any]
) -> None: ...


class AsgiMiddlewareWithProcessRequest(Protocol):
"""ASGI middleware with request handler"""

async def process_request(self, req: AsgiRequest, resp: AsgiResponse) -> None: ...


class AsgiMiddlewareWithProcessResource(Protocol):
"""ASGI middleware with resource handler"""

async def process_resource(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class AsgiMiddlewareWithProcessResponse(Protocol):
"""ASGI middleware with response handler"""

async def process_response(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None: ...


class MiddlewareWithAsyncProcessRequestWs(Protocol):
"""ASGI middleware with WebSocket request handler"""

async def process_request_ws(self, req: AsgiRequest, ws: WebSocket) -> None: ...


class MiddlewareWithAsyncProcessResourceWs(Protocol):
"""ASGI middleware with WebSocket resource handler"""

async def process_resource_ws(
self,
req: AsgiRequest,
ws: WebSocket,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class UniversalMiddlewareWithProcessRequest(MiddlewareWithProcessRequest, Protocol):
"""WSGI/ASGI middleware with request handler"""

async def process_request_async(
self, req: AsgiRequest, resp: AsgiResponse
) -> None: ...


class UniversalMiddlewareWithProcessResource(MiddlewareWithProcessResource, Protocol):
"""WSGI/ASGI middleware with resource handler"""

async def process_resource_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class UniversalMiddlewareWithProcessResponse(MiddlewareWithProcessResponse, Protocol):
"""WSGI/ASGI middleware with response handler"""

async def process_response_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None: ...


# NOTE(jkmnt): This typing is far from perfect due to the Python typing limitations,
# but better than nothing. Middleware conforming to any protocol of the union
# will pass the type check. Other protocols violations are not checked.
Middleware = Union[
MiddlewareWithProcessRequest,
MiddlewareWithProcessResource,
MiddlewareWithProcessResponse,
]

AsgiMiddleware = Union[
AsgiMiddlewareWithProcessRequest,
AsgiMiddlewareWithProcessResource,
AsgiMiddlewareWithProcessResponse,
AsgiMiddlewareWithProcessStartup,
AsgiMiddlewareWithProcessShutdown,
UniversalMiddlewareWithProcessRequest,
UniversalMiddlewareWithProcessResource,
UniversalMiddlewareWithProcessResponse,
]
29 changes: 16 additions & 13 deletions falcon/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from falcon._typing import ErrorHandler
from falcon._typing import ErrorSerializer
from falcon._typing import FindMethod
from falcon._typing import Middleware
from falcon._typing import ProcessResponseMethod
from falcon._typing import ResponderCallable
from falcon._typing import SinkCallable
Expand Down Expand Up @@ -286,7 +287,7 @@ def process_response(
_static_routes: List[
Tuple[routing.StaticRoute, routing.StaticRoute, Literal[False]]
]
_unprepared_middleware: List[object]
_unprepared_middleware: List[Middleware]

# Attributes
req_options: RequestOptions
Expand All @@ -305,7 +306,7 @@ def __init__(
media_type: str = constants.DEFAULT_MEDIA_TYPE,
request_type: Optional[Type[Request]] = None,
response_type: Optional[Type[Response]] = None,
middleware: Union[object, Iterable[object]] = None,
middleware: Optional[Union[Middleware, Iterable[Middleware]]] = None,
router: Optional[routing.CompiledRouter] = None,
independent_middleware: bool = True,
cors_enable: bool = False,
Expand All @@ -327,17 +328,17 @@ def __init__(
# NOTE(kgriffs): Check to see if middleware is an
# iterable, and if so, append the CORSMiddleware
# instance.
middleware = list(middleware) # type: ignore[arg-type]
middleware.append(cm) # type: ignore[arg-type]
middleware = list(cast(Iterable[Middleware], middleware))
middleware.append(cm)
except TypeError:
# NOTE(kgriffs): Assume the middleware kwarg references
# a single middleware component.
middleware = [middleware, cm]
middleware = [cast(Middleware, middleware), cm]

# set middleware
self._unprepared_middleware = []
self._independent_middleware = independent_middleware
self.add_middleware(middleware)
self.add_middleware(middleware or [])

self._router = router or routing.DefaultRouter()
self._router_search = self._router.find
Expand Down Expand Up @@ -524,7 +525,9 @@ def router_options(self) -> routing.CompiledRouterOptions:
"""
return self._router.options

def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
def add_middleware(
self, middleware: Union[Middleware, Iterable[Middleware]]
) -> None:
"""Add one or more additional middleware components.
Arguments:
Expand All @@ -535,20 +538,20 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
"""

# NOTE(kgriffs): Since this is called by the initializer, there is
# the chance that middleware may be None.
# the chance that middleware may be empty.
if middleware:
try:
middleware = list(middleware) # type: ignore[call-overload]
middleware = list(cast(Iterable[Middleware], middleware))
except TypeError:
# middleware is not iterable; assume it is just one bare component
middleware = [middleware]
middleware = [cast(Middleware, middleware)]

if (
self._cors_enable
and len(
[
mc
for mc in self._unprepared_middleware + middleware # type: ignore[operator]
for mc in self._unprepared_middleware + middleware
if isinstance(mc, CORSMiddleware)
]
)
Expand All @@ -559,7 +562,7 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
'cors_enable (which already constructs one instance)'
)

self._unprepared_middleware += middleware # type: ignore[arg-type]
self._unprepared_middleware += middleware

# NOTE(kgriffs): Even if middleware is None or an empty list, we still
# need to make sure self._middleware is initialized if this is the
Expand Down Expand Up @@ -1012,7 +1015,7 @@ def my_serializer(
# ------------------------------------------------------------------------

def _prepare_middleware(
self, middleware: List[object], independent_middleware: bool = False
self, middleware: List[Middleware], independent_middleware: bool = False
) -> helpers.PreparedMiddlewareResult:
return helpers.prepare_middleware(
middleware=middleware, independent_middleware=independent_middleware
Expand Down
19 changes: 14 additions & 5 deletions falcon/app_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union

from falcon import util
from falcon._typing import AsgiMiddleware
from falcon._typing import AsgiProcessRequestMethod as APRequest
from falcon._typing import AsgiProcessRequestWsMethod
from falcon._typing import AsgiProcessResourceMethod as APResource
from falcon._typing import AsgiProcessResourceWsMethod
from falcon._typing import AsgiProcessResponseMethod as APResponse
from falcon._typing import Middleware
from falcon._typing import ProcessRequestMethod as PRequest
from falcon._typing import ProcessResourceMethod as PResource
from falcon._typing import ProcessResponseMethod as PResponse
Expand Down Expand Up @@ -62,24 +64,31 @@

@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[False] = ...
middleware: Iterable[Middleware],
independent_middleware: bool = ...,
asgi: Literal[False] = ...,
) -> PreparedMiddlewareResult: ...


@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., *, asgi: Literal[True]
middleware: Iterable[AsgiMiddleware],
independent_middleware: bool = ...,
*,
asgi: Literal[True],
) -> AsyncPreparedMiddlewareResult: ...


@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., asgi: bool = ...
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
independent_middleware: bool = ...,
asgi: bool = ...,
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: ...


def prepare_middleware(
middleware: Iterable[object],
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
independent_middleware: bool = False,
asgi: bool = False,
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]:
Expand Down Expand Up @@ -214,7 +223,7 @@ def prepare_middleware(


def prepare_middleware_ws(
middleware: Iterable[object],
middleware: Iterable[AsgiMiddleware],
) -> AsyncPreparedMiddlewareWsResult:
"""Check middleware interfaces and prepare WebSocket methods for request handling.
Expand Down
8 changes: 5 additions & 3 deletions falcon/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from falcon import routing
from falcon._typing import _UNSET
from falcon._typing import AsgiErrorHandler
from falcon._typing import AsgiMiddleware
from falcon._typing import AsgiReceive
from falcon._typing import AsgiResponderCallable
from falcon._typing import AsgiResponderWsCallable
Expand Down Expand Up @@ -356,6 +357,7 @@ async def process_resource_ws(
_middleware_ws: AsyncPreparedMiddlewareWsResult
_request_type: Type[Request]
_response_type: Type[Response]
_unprepared_middleware: List[AsgiMiddleware] # type: ignore[assignment]

ws_options: WebSocketOptions
"""A set of behavioral options related to WebSocket connections.
Expand All @@ -368,7 +370,7 @@ def __init__(
media_type: str = constants.DEFAULT_MEDIA_TYPE,
request_type: Optional[Type[Request]] = None,
response_type: Optional[Type[Response]] = None,
middleware: Union[object, Iterable[object]] = None,
middleware: Optional[Union[AsgiMiddleware, Iterable[AsgiMiddleware]]] = None,
router: Optional[routing.CompiledRouter] = None,
independent_middleware: bool = True,
cors_enable: bool = False,
Expand All @@ -378,7 +380,7 @@ def __init__(
media_type,
request_type or Request,
response_type or Response,
middleware,
middleware, # type: ignore[arg-type]
router,
independent_middleware,
cors_enable,
Expand Down Expand Up @@ -1163,7 +1165,7 @@ async def _handle_websocket(
raise

def _prepare_middleware( # type: ignore[override]
self, middleware: List[object], independent_middleware: bool = False
self, middleware: List[AsgiMiddleware], independent_middleware: bool = False
) -> AsyncPreparedMiddlewareResult:
self._middleware_ws = prepare_middleware_ws(middleware)

Expand Down
17 changes: 13 additions & 4 deletions falcon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import Any, Iterable, Optional, Union
from typing import Iterable, Optional, Union

from ._typing import UniversalMiddlewareWithProcessResponse
from .asgi.request import Request as AsgiRequest
from .asgi.response import Response as AsgiResponse
from .request import Request
from .response import Response


class CORSMiddleware(object):
class CORSMiddleware(UniversalMiddlewareWithProcessResponse):
"""CORS Middleware.
This middleware provides a simple out-of-the box CORS policy, including handling
Expand Down Expand Up @@ -141,5 +144,11 @@ def process_response(
resp.set_header('Access-Control-Allow-Headers', allow_headers)
resp.set_header('Access-Control-Max-Age', '86400') # 24 hours

async def process_response_async(self, *args: Any) -> None:
self.process_response(*args)
async def process_response_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None:
self.process_response(req, resp, resource, req_succeeded)

0 comments on commit f471050

Please sign in to comment.