diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 7a5f776a4d..dc3c2fe08f 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -8,7 +8,6 @@ import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION - from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError if TYPE_CHECKING: @@ -134,7 +133,7 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField type_=field.annotation, outer_type_=field.annotation, default=field.default, - default_factory=field.default_factory, # type: ignore + default_factory=field.default_factory, required=field.is_required(), alias=field.alias, # v2 doesn't have allow_none diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index e3f5f469ff..d10b688003 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -13,6 +13,7 @@ Mapping, Optional, Tuple, + Type, Union, cast, overload, @@ -117,14 +118,18 @@ class AsyncBaseHTTPView( request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter] websocket_adapter_class: Callable[ [ - "AsyncBaseHTTPView[Request, Response, SubResponse, WebSocketRequest, WebSocketResponse, Context, RootValue]", + "AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]", WebSocketRequest, WebSocketResponse, ], AsyncWebSocketAdapter, ] - graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler - graphql_ws_handler_class = BaseGraphQLWSHandler + graphql_transport_ws_handler_class: Type[ + BaseGraphQLTransportWSHandler[Context, RootValue] + ] = BaseGraphQLTransportWSHandler[Context, RootValue] + graphql_ws_handler_class: Type[BaseGraphQLWSHandler[Context, RootValue]] = ( + BaseGraphQLWSHandler[Context, RootValue] + ) @property @abc.abstractmethod @@ -285,8 +290,8 @@ async def run( await self.graphql_transport_ws_handler_class( view=self, websocket=websocket, - context=context, - root_value=root_value, + context=context, # type: ignore + root_value=root_value, # type: ignore schema=self.schema, debug=self.debug, connection_init_wait_timeout=self.connection_init_wait_timeout, @@ -295,8 +300,8 @@ async def run( await self.graphql_ws_handler_class( view=self, websocket=websocket, - context=context, - root_value=root_value, + context=context, # type: ignore + root_value=root_value, # type: ignore schema=self.schema, debug=self.debug, keep_alive=self.keep_alive, diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 647ab7ab3c..6c29f9faeb 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -5,8 +5,10 @@ from contextlib import suppress from typing import ( TYPE_CHECKING, + Any, Awaitable, Dict, + Generic, List, Optional, cast, @@ -20,6 +22,7 @@ NonTextMessageReceived, WebSocketDisconnected, ) +from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionInitMessage, @@ -44,15 +47,15 @@ from strawberry.schema.subscribe import SubscriptionResult -class BaseGraphQLTransportWSHandler: +class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]): task_logger: logging.Logger = logging.getLogger("strawberry.ws.task") def __init__( self, - view: AsyncBaseHTTPView, + view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue], websocket: AsyncWebSocketAdapter, - context: object, - root_value: object, + context: Context, + root_value: RootValue, schema: BaseSchema, debug: bool, connection_init_wait_timeout: timedelta, @@ -68,7 +71,7 @@ def __init__( self.connection_init_received = False self.connection_acknowledged = False self.connection_timed_out = False - self.operations: Dict[str, Operation] = {} + self.operations: Dict[str, Operation[Context, RootValue]] = {} self.completed_tasks: List[asyncio.Task] = [] async def handle(self) -> None: @@ -184,6 +187,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: elif hasattr(self.context, "connection_params"): self.context.connection_params = payload + self.context = cast(Context, self.context) + try: connection_ack_payload = await self.view.on_ws_connect(self.context) except ConnectionRejectionError: @@ -250,7 +255,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: operation.task = asyncio.create_task(self.run_operation(operation)) self.operations[message["id"]] = operation - async def run_operation(self, operation: Operation) -> None: + async def run_operation(self, operation: Operation[Context, RootValue]) -> None: """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" # TODO: Handle errors in this method using self.handle_task_exception() @@ -334,7 +339,7 @@ async def reap_completed_tasks(self) -> None: await task -class Operation: +class Operation(Generic[Context, RootValue]): """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" __slots__ = [ @@ -350,7 +355,7 @@ class Operation: def __init__( self, - handler: BaseGraphQLTransportWSHandler, + handler: BaseGraphQLTransportWSHandler[Context, RootValue], id: str, operation_type: OperationType, query: str, diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 95536bc4fd..6f2dcb929d 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -4,14 +4,17 @@ from contextlib import suppress from typing import ( TYPE_CHECKING, + Any, AsyncGenerator, Dict, + Generic, Optional, cast, ) from strawberry.exceptions import ConnectionRejectionError from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected +from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions.protocols.graphql_ws.types import ( ConnectionInitMessage, ConnectionTerminateMessage, @@ -29,13 +32,16 @@ from strawberry.schema import BaseSchema -class BaseGraphQLWSHandler: +class BaseGraphQLWSHandler(Generic[Context, RootValue]): + context: Context + root_value: RootValue + def __init__( self, - view: AsyncBaseHTTPView, + view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue], websocket: AsyncWebSocketAdapter, - context: object, - root_value: object, + context: Context, + root_value: RootValue, schema: BaseSchema, debug: bool, keep_alive: bool, @@ -100,6 +106,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: elif hasattr(self.context, "connection_params"): self.context.connection_params = payload + self.context = cast(Context, self.context) + try: connection_ack_payload = await self.view.on_ws_connect(self.context) except ConnectionRejectionError as e: