From c800cf84522d03759043c0488f260eab1077bdf2 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Wed, 8 Jan 2025 18:39:12 +0100 Subject: [PATCH] Add missing overrides in protos' dynamic handlers --- emmett/asgi/handlers.py | 52 +++++++++++++++++++++++++++++++++--- emmett/rsgi/handlers.py | 58 +++++++++++++++++++++++++++++++++++------ 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/emmett/asgi/handlers.py b/emmett/asgi/handlers.py index 0be20684..419ac673 100644 --- a/emmett/asgi/handlers.py +++ b/emmett/asgi/handlers.py @@ -15,12 +15,12 @@ from importlib import resources from typing import Awaitable, Callable -from emmett_core.http.response import HTTPBytesResponse, HTTPResponse -from emmett_core.protocols.asgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler +from emmett_core.http.response import HTTPBytesResponse, HTTPResponse, HTTPStringResponse +from emmett_core.protocols.asgi.handlers import HTTPHandler as _HTTPHandler, RequestCancelled, WSHandler as _WSHandler from emmett_core.protocols.asgi.typing import Receive, Scope, Send from emmett_core.utils import cachedprop -from ..ctx import current +from ..ctx import RequestContext, WSContext, current from ..debug import debug_handler, smart_traceback from ..libs.contenttype import contenttype from ..wrappers.response import Response @@ -70,7 +70,53 @@ async def _debug_handler(self) -> str: current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) + async def dynamic_handler(self, scope: Scope, receive: Receive, send: Send) -> HTTPResponse: + request = Request( + scope, + receive, + send, + max_content_length=self.app.config.request_max_content_length, + max_multipart_size=self.app.config.request_multipart_max_size, + body_timeout=self.app.config.request_body_timeout, + ) + response = Response() + ctx = RequestContext(self.app, request, response) + ctx_token = current._init_(ctx) + try: + http = await self.router.dispatch(request, response) + except HTTPResponse as http_exception: + http = http_exception + #: render error with handlers if in app + error_handler = self.app.error_handlers.get(http.status_code) + if error_handler: + http = HTTPStringResponse( + http.status_code, await error_handler(), headers=response.headers, cookies=response.cookies + ) + except RequestCancelled: + raise + except Exception: + self.app.log.exception("Application exception:") + http = HTTPStringResponse(500, await self.error_handler(), headers=response.headers) + finally: + current._close_(ctx_token) + return http + + async def _exception_handler(self) -> str: + current.response.headers._data["content-type"] = "text/plain" + return "Internal error" + class WSHandler(_WSHandler): __slots__ = [] wrapper_cls = Websocket + + async def dynamic_handler(self, scope: Scope, send: Send): + ctx = WSContext(self.app, Websocket(scope, scope["emt.input"].get, send)) + ctx_token = current._init_(ctx) + try: + await self.router.dispatch(ctx.websocket) + finally: + if not scope.get("emt._flow_cancel", False) and ctx.websocket._accepted: + await send({"type": "websocket.close", "code": 1000}) + scope["emt._ws_closed"] = True + current._close_(ctx_token) diff --git a/emmett/rsgi/handlers.py b/emmett/rsgi/handlers.py index de7373ec..83175873 100644 --- a/emmett/rsgi/handlers.py +++ b/emmett/rsgi/handlers.py @@ -11,18 +11,15 @@ from __future__ import annotations +import asyncio import os from typing import Awaitable, Callable -from emmett_core.http.response import HTTPResponse -from emmett_core.protocols.rsgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler +from emmett_core.http.response import HTTPResponse, HTTPStringResponse +from emmett_core.protocols.rsgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler, WSTransport from emmett_core.utils import cachedprop -from granian.rsgi import ( - HTTPProtocol, - Scope, -) -from ..ctx import current +from ..ctx import RequestContext, WSContext, current from ..debug import debug_handler, smart_traceback from ..wrappers.response import Response from .wrappers import Request, Websocket @@ -37,7 +34,7 @@ class HTTPHandler(_HTTPHandler): def error_handler(self) -> Callable[[], Awaitable[str]]: return self._debug_handler if self.app.debug else self.exception_handler - def _static_handler(self, scope: Scope, protocol: HTTPProtocol, path: str) -> Awaitable[HTTPResponse]: + def _static_handler(self, scope, protocol, path: str) -> Awaitable[HTTPResponse]: #: handle internal assets if path.startswith("/__emmett__"): file_name = path[12:] @@ -57,6 +54,51 @@ async def _debug_handler(self) -> str: current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) + async def dynamic_handler(self, scope, protocol, path: str) -> HTTPResponse: + request = Request( + scope, + path, + protocol, + max_content_length=self.app.config.request_max_content_length, + max_multipart_size=self.app.config.request_multipart_max_size, + body_timeout=self.app.config.request_body_timeout, + ) + response = Response() + ctx = RequestContext(self.app, request, response) + ctx_token = current._init_(ctx) + try: + http = await self.router.dispatch(request, response) + except HTTPResponse as http_exception: + http = http_exception + #: render error with handlers if in app + error_handler = self.app.error_handlers.get(http.status_code) + if error_handler: + http = HTTPStringResponse( + http.status_code, await error_handler(), headers=response.headers, cookies=response.cookies + ) + except Exception: + self.app.log.exception("Application exception:") + http = HTTPStringResponse(500, await self.error_handler(), headers=response.headers) + finally: + current._close_(ctx_token) + return http + class WSHandler(_WSHandler): wrapper_cls = Websocket + + async def dynamic_handler(self, scope, transport: WSTransport, path: str): + ctx = WSContext(self.app, Websocket(scope, path, transport)) + ctx_token = current._init_(ctx) + try: + await self.router.dispatch(ctx.websocket) + except HTTPResponse as http: + transport.status = http.status_code + except asyncio.CancelledError: + if not transport.interrupted: + self.app.log.exception("Application exception:") + except Exception: + transport.status = 500 + self.app.log.exception("Application exception:") + finally: + current._close_(ctx_token)