diff --git a/falcon/_typing.py b/falcon/_typing.py index d82a5bac5..a14a61f02 100644 --- a/falcon/_typing.py +++ b/falcon/_typing.py @@ -196,3 +196,147 @@ def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ... DeserializeSync = Callable[[bytes], Any] Responder = Union[ResponderMethod, AsgiResponderMethod] + + +# Middleware +class MiddlewareWithProcessRequest(Protocol): + """WSGI Middleware with request handler""" + + def process_request(self, req: Request, resp: Response) -> None: ... + + +class MiddlewareWithProcessResource(Protocol): + """WSGI Middleware with resource handler""" + + def process_resource( + self, + req: Request, + resp: Response, + resource: object, + params: Dict[str, Any], + ) -> None: ... + + +class MiddlewareWithProcessResponse(Protocol): + """WSGI Middleware with response handler""" + + def process_response( + self, req: Request, resp: Response, resource: object, req_succeeded: bool + ) -> None: ... + + +class AsgiMiddlewareWithProcessStartup(Protocol): + """ASGI middleware with startup handler""" + + async def process_startup( + self, scope: Mapping[str, Any], event: Mapping[str, Any] + ) -> None: ... + + +class AsgiMiddlewareWithProcessShutdown(Protocol): + """ASGI middleware with shutdown handler""" + + async def process_shutdown( + self, scope: Mapping[str, Any], event: Mapping[str, Any] + ) -> None: ... + + +class AsgiMiddlewareWithProcessRequest(Protocol): + """ASGI middleware with request handler""" + + async def process_request(self, req: AsgiRequest, resp: AsgiResponse) -> None: ... + + +class AsgiMiddlewareWithProcessResource(Protocol): + """ASGI middleware with resource handler""" + + async def process_resource( + self, + req: AsgiRequest, + resp: AsgiResponse, + resource: object, + params: Mapping[str, Any], + ) -> None: ... + + +class AsgiMiddlewareWithProcessResponse(Protocol): + """ASGI middleware with response handler""" + + async def process_response( + self, + req: AsgiRequest, + resp: AsgiResponse, + resource: object, + req_succeeded: bool, + ) -> None: ... + + +class MiddlewareWithAsyncProcessRequestWs(Protocol): + """ASGI middleware with WebSocket request handler""" + + async def process_request_ws(self, req: AsgiRequest, ws: WebSocket) -> None: ... + + +class MiddlewareWithAsyncProcessResourceWs(Protocol): + """ASGI middleware with WebSocket resource handler""" + + async def process_resource_ws( + self, + req: AsgiRequest, + ws: WebSocket, + resource: object, + params: Mapping[str, Any], + ) -> None: ... + + +class UniversalMiddlewareWithProcessRequest(MiddlewareWithProcessRequest, Protocol): + """WSGI/ASGI middleware with request handler""" + + async def process_request_async( + self, req: AsgiRequest, resp: AsgiResponse + ) -> None: ... + + +class UniversalMiddlewareWithProcessResource(MiddlewareWithProcessResource, Protocol): + """WSGI/ASGI middleware with resource handler""" + + async def process_resource_async( + self, + req: AsgiRequest, + resp: AsgiResponse, + resource: object, + params: Mapping[str, Any], + ) -> None: ... + + +class UniversalMiddlewareWithProcessResponse(MiddlewareWithProcessResponse, Protocol): + """WSGI/ASGI middleware with response handler""" + + async def process_response_async( + self, + req: AsgiRequest, + resp: AsgiResponse, + resource: object, + req_succeeded: bool, + ) -> None: ... + + +# NOTE(jkmnt): This typing is far from perfect due to the Python typing limitations, +# but better than nothing. Middleware conforming to any protocol of the union +# will pass the type check. Other protocols violations are not checked. +Middleware = Union[ + MiddlewareWithProcessRequest, + MiddlewareWithProcessResource, + MiddlewareWithProcessResponse, +] + +AsgiMiddleware = Union[ + AsgiMiddlewareWithProcessRequest, + AsgiMiddlewareWithProcessResource, + AsgiMiddlewareWithProcessResponse, + AsgiMiddlewareWithProcessStartup, + AsgiMiddlewareWithProcessShutdown, + UniversalMiddlewareWithProcessRequest, + UniversalMiddlewareWithProcessResource, + UniversalMiddlewareWithProcessResponse, +] diff --git a/falcon/app.py b/falcon/app.py index f74083693..a7752cffd 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -51,6 +51,7 @@ from falcon._typing import ErrorHandler from falcon._typing import ErrorSerializer from falcon._typing import FindMethod +from falcon._typing import Middleware from falcon._typing import ProcessResponseMethod from falcon._typing import ResponderCallable from falcon._typing import SinkCallable @@ -286,7 +287,7 @@ def process_response( _static_routes: List[ Tuple[routing.StaticRoute, routing.StaticRoute, Literal[False]] ] - _unprepared_middleware: List[object] + _unprepared_middleware: List[Middleware] # Attributes req_options: RequestOptions @@ -305,7 +306,7 @@ def __init__( media_type: str = constants.DEFAULT_MEDIA_TYPE, request_type: Optional[Type[Request]] = None, response_type: Optional[Type[Response]] = None, - middleware: Union[object, Iterable[object]] = None, + middleware: Optional[Union[Middleware, Iterable[Middleware]]] = None, router: Optional[routing.CompiledRouter] = None, independent_middleware: bool = True, cors_enable: bool = False, @@ -327,17 +328,17 @@ def __init__( # NOTE(kgriffs): Check to see if middleware is an # iterable, and if so, append the CORSMiddleware # instance. - middleware = list(middleware) # type: ignore[arg-type] - middleware.append(cm) # type: ignore[arg-type] + middleware = list(cast(Iterable[Middleware], middleware)) + middleware.append(cm) except TypeError: # NOTE(kgriffs): Assume the middleware kwarg references # a single middleware component. - middleware = [middleware, cm] + middleware = [cast(Middleware, middleware), cm] # set middleware self._unprepared_middleware = [] self._independent_middleware = independent_middleware - self.add_middleware(middleware) + self.add_middleware(middleware or []) self._router = router or routing.DefaultRouter() self._router_search = self._router.find @@ -524,7 +525,9 @@ def router_options(self) -> routing.CompiledRouterOptions: """ return self._router.options - def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None: + def add_middleware( + self, middleware: Union[Middleware, Iterable[Middleware]] + ) -> None: """Add one or more additional middleware components. Arguments: @@ -535,20 +538,20 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None: """ # NOTE(kgriffs): Since this is called by the initializer, there is - # the chance that middleware may be None. + # the chance that middleware may be empty. if middleware: try: - middleware = list(middleware) # type: ignore[call-overload] + middleware = list(cast(Iterable[Middleware], middleware)) except TypeError: # middleware is not iterable; assume it is just one bare component - middleware = [middleware] + middleware = [cast(Middleware, middleware)] if ( self._cors_enable and len( [ mc - for mc in self._unprepared_middleware + middleware # type: ignore[operator] + for mc in self._unprepared_middleware + middleware if isinstance(mc, CORSMiddleware) ] ) @@ -559,7 +562,7 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None: 'cors_enable (which already constructs one instance)' ) - self._unprepared_middleware += middleware # type: ignore[arg-type] + self._unprepared_middleware += middleware # NOTE(kgriffs): Even if middleware is None or an empty list, we still # need to make sure self._middleware is initialized if this is the @@ -1012,7 +1015,7 @@ def my_serializer( # ------------------------------------------------------------------------ def _prepare_middleware( - self, middleware: List[object], independent_middleware: bool = False + self, middleware: List[Middleware], independent_middleware: bool = False ) -> helpers.PreparedMiddlewareResult: return helpers.prepare_middleware( middleware=middleware, independent_middleware=independent_middleware diff --git a/falcon/app_helpers.py b/falcon/app_helpers.py index 1248d280b..7fbbdae15 100644 --- a/falcon/app_helpers.py +++ b/falcon/app_helpers.py @@ -20,11 +20,13 @@ from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union from falcon import util +from falcon._typing import AsgiMiddleware from falcon._typing import AsgiProcessRequestMethod as APRequest from falcon._typing import AsgiProcessRequestWsMethod from falcon._typing import AsgiProcessResourceMethod as APResource from falcon._typing import AsgiProcessResourceWsMethod from falcon._typing import AsgiProcessResponseMethod as APResponse +from falcon._typing import Middleware from falcon._typing import ProcessRequestMethod as PRequest from falcon._typing import ProcessResourceMethod as PResource from falcon._typing import ProcessResponseMethod as PResponse @@ -62,24 +64,31 @@ @overload def prepare_middleware( - middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[False] = ... + middleware: Iterable[Middleware], + independent_middleware: bool = ..., + asgi: Literal[False] = ..., ) -> PreparedMiddlewareResult: ... @overload def prepare_middleware( - middleware: Iterable, independent_middleware: bool = ..., *, asgi: Literal[True] + middleware: Iterable[AsgiMiddleware], + independent_middleware: bool = ..., + *, + asgi: Literal[True], ) -> AsyncPreparedMiddlewareResult: ... @overload def prepare_middleware( - middleware: Iterable, independent_middleware: bool = ..., asgi: bool = ... + middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]], + independent_middleware: bool = ..., + asgi: bool = ..., ) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: ... def prepare_middleware( - middleware: Iterable[object], + middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]], independent_middleware: bool = False, asgi: bool = False, ) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: @@ -214,7 +223,7 @@ def prepare_middleware( def prepare_middleware_ws( - middleware: Iterable[object], + middleware: Iterable[AsgiMiddleware], ) -> AsyncPreparedMiddlewareWsResult: """Check middleware interfaces and prepare WebSocket methods for request handling. diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index f3b637802..253694e3f 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -42,6 +42,7 @@ from falcon import routing from falcon._typing import _UNSET from falcon._typing import AsgiErrorHandler +from falcon._typing import AsgiMiddleware from falcon._typing import AsgiReceive from falcon._typing import AsgiResponderCallable from falcon._typing import AsgiResponderWsCallable @@ -356,6 +357,7 @@ async def process_resource_ws( _middleware_ws: AsyncPreparedMiddlewareWsResult _request_type: Type[Request] _response_type: Type[Response] + _unprepared_middleware: List[AsgiMiddleware] # type: ignore[assignment] ws_options: WebSocketOptions """A set of behavioral options related to WebSocket connections. @@ -368,7 +370,7 @@ def __init__( media_type: str = constants.DEFAULT_MEDIA_TYPE, request_type: Optional[Type[Request]] = None, response_type: Optional[Type[Response]] = None, - middleware: Union[object, Iterable[object]] = None, + middleware: Optional[Union[AsgiMiddleware, Iterable[AsgiMiddleware]]] = None, router: Optional[routing.CompiledRouter] = None, independent_middleware: bool = True, cors_enable: bool = False, @@ -378,7 +380,7 @@ def __init__( media_type, request_type or Request, response_type or Response, - middleware, + middleware, # type: ignore[arg-type] router, independent_middleware, cors_enable, @@ -1163,7 +1165,7 @@ async def _handle_websocket( raise def _prepare_middleware( # type: ignore[override] - self, middleware: List[object], independent_middleware: bool = False + self, middleware: List[AsgiMiddleware], independent_middleware: bool = False ) -> AsyncPreparedMiddlewareResult: self._middleware_ws = prepare_middleware_ws(middleware) diff --git a/falcon/middleware.py b/falcon/middleware.py index d457a44b8..e2b81ad33 100644 --- a/falcon/middleware.py +++ b/falcon/middleware.py @@ -1,12 +1,15 @@ from __future__ import annotations -from typing import Any, Iterable, Optional, Union +from typing import Iterable, Optional, Union +from ._typing import UniversalMiddlewareWithProcessResponse +from .asgi.request import Request as AsgiRequest +from .asgi.response import Response as AsgiResponse from .request import Request from .response import Response -class CORSMiddleware(object): +class CORSMiddleware(UniversalMiddlewareWithProcessResponse): """CORS Middleware. This middleware provides a simple out-of-the box CORS policy, including handling @@ -141,5 +144,11 @@ def process_response( resp.set_header('Access-Control-Allow-Headers', allow_headers) resp.set_header('Access-Control-Max-Age', '86400') # 24 hours - async def process_response_async(self, *args: Any) -> None: - self.process_response(*args) + async def process_response_async( + self, + req: AsgiRequest, + resp: AsgiResponse, + resource: object, + req_succeeded: bool, + ) -> None: + self.process_response(req, resp, resource, req_succeeded)