Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Dec 20, 2024
1 parent e4aa2fa commit 7ee63df
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
3 changes: 1 addition & 2 deletions strawberry/experimental/pydantic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pydantic
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION

from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -134,7 +133,7 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField
type_=field.annotation,
outer_type_=field.annotation,
default=field.default,
default_factory=field.default_factory, # type: ignore
default_factory=field.default_factory,
required=field.is_required(),
alias=field.alias,
# v2 doesn't have allow_none
Expand Down
19 changes: 12 additions & 7 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
overload,
Expand Down Expand Up @@ -117,14 +118,18 @@ class AsyncBaseHTTPView(
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
websocket_adapter_class: Callable[
[
"AsyncBaseHTTPView[Request, Response, SubResponse, WebSocketRequest, WebSocketResponse, Context, RootValue]",
"AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]",
WebSocketRequest,
WebSocketResponse,
],
AsyncWebSocketAdapter,
]
graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
graphql_ws_handler_class = BaseGraphQLWSHandler
graphql_transport_ws_handler_class: Type[
BaseGraphQLTransportWSHandler[Context, RootValue]
] = BaseGraphQLTransportWSHandler[Context, RootValue]
graphql_ws_handler_class: Type[BaseGraphQLWSHandler[Context, RootValue]] = (
BaseGraphQLWSHandler[Context, RootValue]
)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -285,8 +290,8 @@ async def run(
await self.graphql_transport_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
connection_init_wait_timeout=self.connection_init_wait_timeout,
Expand All @@ -295,8 +300,8 @@ async def run(
await self.graphql_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
keep_alive=self.keep_alive,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
Generic,
List,
Optional,
cast,
Expand All @@ -20,6 +22,7 @@
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionInitMessage,
Expand All @@ -44,15 +47,15 @@
from strawberry.schema.subscribe import SubscriptionResult


class BaseGraphQLTransportWSHandler:
class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
connection_init_wait_timeout: timedelta,
Expand All @@ -68,7 +71,7 @@ def __init__(
self.connection_init_received = False
self.connection_acknowledged = False
self.connection_timed_out = False
self.operations: Dict[str, Operation] = {}
self.operations: Dict[str, Operation[Context, RootValue]] = {}
self.completed_tasks: List[asyncio.Task] = []

async def handle(self) -> None:
Expand Down Expand Up @@ -184,6 +187,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError:
Expand Down Expand Up @@ -250,7 +255,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
operation.task = asyncio.create_task(self.run_operation(operation))
self.operations[message["id"]] = operation

async def run_operation(self, operation: Operation) -> None:
async def run_operation(self, operation: Operation[Context, RootValue]) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()

Expand Down Expand Up @@ -334,7 +339,7 @@ async def reap_completed_tasks(self) -> None:
await task


class Operation:
class Operation(Generic[Context, RootValue]):
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = [
Expand All @@ -350,7 +355,7 @@ class Operation:

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
handler: BaseGraphQLTransportWSHandler[Context, RootValue],
id: str,
operation_type: OperationType,
query: str,
Expand Down
16 changes: 12 additions & 4 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generic,
Optional,
cast,
)

from strawberry.exceptions import ConnectionRejectionError
from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage,
ConnectionTerminateMessage,
Expand All @@ -29,13 +32,16 @@
from strawberry.schema import BaseSchema


class BaseGraphQLWSHandler:
class BaseGraphQLWSHandler(Generic[Context, RootValue]):
context: Context
root_value: RootValue

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
keep_alive: bool,
Expand Down Expand Up @@ -100,6 +106,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError as e:
Expand Down

0 comments on commit 7ee63df

Please sign in to comment.