From 51eb183a317722480bc08cfc1032bfdad048e253 Mon Sep 17 00:00:00 2001 From: Alex Auritt Date: Wed, 17 May 2023 14:07:30 -0700 Subject: [PATCH 001/119] Ann003 (#2767) * Support ruff ann-003 checks * Delete RELEASE.md --------- Co-authored-by: Patrick Arminio --- federation-compatibility/schema.py | 10 +++--- pyproject.toml | 1 - strawberry/channels/handlers/http_handler.py | 2 +- strawberry/channels/handlers/ws_handler.py | 2 +- strawberry/channels/testing.py | 2 +- .../experimental/pydantic/conversion_types.py | 4 +-- .../experimental/pydantic/object_type.py | 2 +- strawberry/extensions/field_extension.py | 6 ++-- strawberry/permission.py | 2 +- strawberry/schema/schema_converter.py | 8 +++-- tests/aiohttp/app.py | 4 ++- tests/asgi/app.py | 4 +-- tests/experimental/pydantic/test_basic.py | 2 +- tests/fastapi/app.py | 2 +- tests/fields/test_permissions.py | 4 +-- tests/http/clients/aiohttp.py | 2 +- tests/http/clients/asgi.py | 2 +- tests/http/clients/chalice.py | 4 +-- tests/http/clients/channels.py | 2 +- tests/http/clients/django.py | 4 +-- tests/http/clients/flask.py | 6 ++-- tests/http/clients/sanic.py | 2 +- tests/schema/extensions/test_extensions.py | 2 +- .../extensions/test_field_extensions.py | 36 +++++++++++++------ tests/schema/test_permission.py | 32 +++++++++++------ tests/starlite/app.py | 4 ++- tests/starlite/schema.py | 4 +-- 27 files changed, 92 insertions(+), 63 deletions(-) diff --git a/federation-compatibility/schema.py b/federation-compatibility/schema.py index fec327eae6..8c08e1d47d 100644 --- a/federation-compatibility/schema.py +++ b/federation-compatibility/schema.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional import strawberry from strawberry.schema_directive import Location @@ -135,7 +135,7 @@ def average_products_created_per_year(self) -> Optional[int]: return None @classmethod - def resolve_reference(cls, **data) -> Optional["User"]: + def resolve_reference(cls, **data: Any) -> Optional["User"]: if email := data.get("email"): years_of_employment = data.get("yearsOfEmployment") @@ -183,7 +183,7 @@ def from_data(cls, data: dict) -> "ProductResearch": ) @classmethod - def resolve_reference(cls, **data) -> Optional["ProductResearch"]: + def resolve_reference(cls, **data: Any) -> Optional["ProductResearch"]: study = data.get("study") if not study: @@ -211,7 +211,7 @@ class DeprecatedProduct: created_by: Optional[User] @classmethod - def resolve_reference(cls, **data) -> Optional["DeprecatedProduct"]: + def resolve_reference(cls, **data: Any) -> Optional["DeprecatedProduct"]: if deprecated_product["sku"] == data.get("sku") and deprecated_product[ "package" ] == data.get("package"): @@ -271,7 +271,7 @@ def from_data(cls, data: dict) -> "Product": ) @classmethod - def resolve_reference(cls, **data) -> Optional["Product"]: + def resolve_reference(cls, **data: Any) -> Optional["Product"]: if "id" in data: return get_product_by_id(id=data["id"]) diff --git a/pyproject.toml b/pyproject.toml index 5418cb30bc..5ad175f66d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -194,7 +194,6 @@ ignore = [ "D", "ANN101", # missing annotation for self? # definitely enable these, maybe not in tests - "ANN003", "ANN102", "ANN202", "ANN204", diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 505be39a59..26e4a5bf02 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -73,7 +73,7 @@ def __init__( graphiql: bool = True, allow_queries_via_get: bool = True, subscriptions_enabled: bool = True, - **kwargs, + **kwargs: Any, ): self.schema = schema self.graphiql = graphiql diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 84b8cf26cc..d4fae0d22c 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -108,7 +108,7 @@ async def receive(self, *args: str, **kwargs: Any) -> None: except ValueError as e: await self._handler.handle_invalid_message(str(e)) - async def receive_json(self, content: Any, **kwargs) -> None: + async def receive_json(self, content: Any, **kwargs: Any) -> None: await self._handler.handle_message(content) async def disconnect(self, code: int) -> None: diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index 8fc890a39a..40ff31671f 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -66,7 +66,7 @@ def __init__( path: str, headers: Optional[List[Tuple[bytes, bytes]]] = None, protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL, - **kwargs, + **kwargs: Any, ): """ diff --git a/strawberry/experimental/pydantic/conversion_types.py b/strawberry/experimental/pydantic/conversion_types.py index aca9cdccd9..69ca6ae054 100644 --- a/strawberry/experimental/pydantic/conversion_types.py +++ b/strawberry/experimental/pydantic/conversion_types.py @@ -16,7 +16,7 @@ class StrawberryTypeFromPydantic(Protocol[PydanticModel]): """This class does not exist in runtime. It only makes the methods below visible for IDEs""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): ... @staticmethod @@ -25,7 +25,7 @@ def from_pydantic( ) -> StrawberryTypeFromPydantic[PydanticModel]: ... - def to_pydantic(self, **kwargs) -> PydanticModel: + def to_pydantic(self, **kwargs: Any) -> PydanticModel: ... @property diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index a469d24fac..5af20a331f 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -270,7 +270,7 @@ def from_pydantic_default( ret._original_model = instance return ret - def to_pydantic_default(self: Any, **kwargs) -> PydanticModel: + def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: instance_kwargs = { f.name: convert_strawberry_class_to_pydantic_model( getattr(self, f.name) diff --git a/strawberry/extensions/field_extension.py b/strawberry/extensions/field_extension.py index b0635d1822..b9156754ca 100644 --- a/strawberry/extensions/field_extension.py +++ b/strawberry/extensions/field_extension.py @@ -18,14 +18,14 @@ def apply(self, field: StrawberryField) -> None: # pragma: no cover pass def resolve( - self, next_: SyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: SyncExtensionResolver, source: Any, info: Info, **kwargs: Any ) -> Any: # pragma: no cover raise NotImplementedError( "Sync Resolve is not supported for this Field Extension" ) async def resolve_async( - self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs: Any ) -> Any: # pragma: no cover raise NotImplementedError( "Async Resolve is not supported for this Field Extension" @@ -45,7 +45,7 @@ class SyncToAsyncExtension(FieldExtension): Applied automatically""" async def resolve_async( - self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs: Any ) -> Any: return next_(source, info, **kwargs) diff --git a/strawberry/permission.py b/strawberry/permission.py index a933fe6d78..ccebfb3bc6 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -14,7 +14,7 @@ class BasePermission: message: Optional[str] = None def has_permission( - self, source: Any, info: Info, **kwargs + self, source: Any, info: Info, **kwargs: Any ) -> Union[bool, Awaitable[bool]]: raise NotImplementedError( "Permission classes should override has_permission method" diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index fdebb8745a..fa7a6357a6 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -551,7 +551,7 @@ def wrap_field_extensions() -> Callable[..., Any]: def extension_resolver( _source: Any, info: Info, - **kwargs, + **kwargs: Any, ): # parse field arguments into Strawberry input types and convert # field names to Python equivalents @@ -590,7 +590,7 @@ def wrapped_get_result(_source: Any, info: Info, **kwargs: Any): _get_result_with_extensions = wrap_field_extensions() - def _resolver(_source: Any, info: GraphQLResolveInfo, **kwargs): + def _resolver(_source: Any, info: GraphQLResolveInfo, **kwargs: Any): strawberry_info = _strawberry_info_from_graphql(info) _check_permissions(_source, strawberry_info, kwargs) @@ -600,7 +600,9 @@ def _resolver(_source: Any, info: GraphQLResolveInfo, **kwargs): **kwargs, ) - async def _async_resolver(_source: Any, info: GraphQLResolveInfo, **kwargs): + async def _async_resolver( + _source: Any, info: GraphQLResolveInfo, **kwargs: Any + ): strawberry_info = _strawberry_info_from_graphql(info) await _check_permissions_async(_source, strawberry_info, kwargs) diff --git a/tests/aiohttp/app.py b/tests/aiohttp/app.py index a2b628fea7..4703f822b8 100644 --- a/tests/aiohttp/app.py +++ b/tests/aiohttp/app.py @@ -1,3 +1,5 @@ +from typing import Any + from aiohttp import web from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.aiohttp.views import GraphQLView @@ -30,7 +32,7 @@ async def get_root_value(self, request: web.Request) -> Query: return Query() -def create_app(**kwargs) -> web.Application: +def create_app(**kwargs: Any) -> web.Application: app = web.Application() app.router.add_route("*", "/graphql", MyGraphQLView(schema=schema, **kwargs)) diff --git a/tests/asgi/app.py b/tests/asgi/app.py index dde84a03a8..a179a3210f 100644 --- a/tests/asgi/app.py +++ b/tests/asgi/app.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from starlette.requests import Request from starlette.responses import Response @@ -20,5 +20,5 @@ async def get_context( return {"request": request, "response": response, "custom_value": "Hi"} -def create_app(**kwargs) -> GraphQL: +def create_app(**kwargs: Any) -> GraphQL: return GraphQL(schema, **kwargs) diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 7f2622dd82..0f31d3036d 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -738,7 +738,7 @@ class IsAuthenticated(strawberry.BasePermission): message = "User is not authenticated" def has_permission( - self, source: Any, info: strawberry.types.Info, **kwargs + self, source: Any, info: strawberry.types.Info, **kwargs: Any ) -> bool: return False diff --git a/tests/fastapi/app.py b/tests/fastapi/app.py index f8dfb3e150..0ae009eebc 100644 --- a/tests/fastapi/app.py +++ b/tests/fastapi/app.py @@ -28,7 +28,7 @@ async def get_root_value( return request or ws -def create_app(schema=schema, **kwargs) -> FastAPI: +def create_app(schema=schema, **kwargs: Any) -> FastAPI: app = FastAPI() graphql_app = GraphQLRouter( diff --git a/tests/fields/test_permissions.py b/tests/fields/test_permissions.py index e6b9bd6ac4..dac1487976 100644 --- a/tests/fields/test_permissions.py +++ b/tests/fields/test_permissions.py @@ -9,7 +9,7 @@ def test_permission_classes_basic_fields(): class IsAuthenticated(BasePermission): message = "User is not authenticated" - def has_permission(self, source: Any, info: Info, **kwargs) -> bool: + def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return False @strawberry.type @@ -30,7 +30,7 @@ def test_permission_classes(): class IsAuthenticated(BasePermission): message = "User is not authenticated" - def has_permission(self, source: Any, info: Info, **kwargs) -> bool: + def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return False @strawberry.type diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 2dc4e56f9d..454b7b26b5 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -102,7 +102,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: async with TestClient(TestServer(self.app)) as client: body = self._build_body( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 2ac832532e..8207cd4a05 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -94,7 +94,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index 8802bd39cd..93b6df6ed7 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -3,7 +3,7 @@ import urllib.parse from io import BytesIO from json import dumps -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import Literal from chalice.app import Chalice @@ -70,7 +70,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 9e0e66d32d..54925dbde1 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -59,7 +59,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: raise NotImplementedError diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index bcc1448091..57e84091cd 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -2,7 +2,7 @@ from io import BytesIO from json import dumps -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import Literal from django.core.exceptions import BadRequest, SuspiciousOperation @@ -96,7 +96,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: headers = self._get_headers(method=method, headers=headers, files=files) additional_arguments = {**kwargs, **headers} diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 55af81e45a..68050b95a0 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -82,7 +82,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method @@ -112,7 +112,7 @@ def _do_request( url: str, method: Literal["get", "post", "patch", "put", "delete"], headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ): with self.app.test_client() as client: response = getattr(client, method)(url, headers=headers, **kwargs) @@ -128,7 +128,7 @@ async def request( url: str, method: Literal["get", "post", "patch", "put", "delete"], headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: loop = asyncio.get_running_loop() ctx = contextvars.copy_context() diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index edb08d7604..f4f7ff005c 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -72,7 +72,7 @@ async def _graphql_request( variables: Optional[Dict[str, object]] = None, files: Optional[Dict[str, BytesIO]] = None, headers: Optional[Dict[str, str]] = None, - **kwargs, + **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method diff --git a/tests/schema/extensions/test_extensions.py b/tests/schema/extensions/test_extensions.py index bfbcf0700e..23619d708c 100644 --- a/tests/schema/extensions/test_extensions.py +++ b/tests/schema/extensions/test_extensions.py @@ -192,7 +192,7 @@ class DefaultSchemaQuery: class ExampleExtension(SchemaExtension): - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) cls.called_hooks = set() diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index 6c484516f6..c5a3e61256 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -13,20 +13,24 @@ class UpperCaseExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): result = next_(source, info, **kwargs) return str(result).upper() class LowerCaseExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): result = next_(source, info, **kwargs) return str(result).lower() class AsyncUpperCaseExtension(FieldExtension): async def resolve_async( - self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs: Any ): result = await next_(source, info, **kwargs) return str(result).upper() @@ -34,12 +38,12 @@ async def resolve_async( class IdentityExtension(FieldExtension): def resolve( - self, next_: SyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: SyncExtensionResolver, source: Any, info: Info, **kwargs: Any ) -> Any: return next_(source, info, **kwargs) async def resolve_async( - self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs + self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs: Any ) -> Any: return await next_(source, info, **kwargs) @@ -167,13 +171,15 @@ def string(self) -> str: def test_fail_on_missing_async_extensions(): class LowerCaseExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): result = next_(source, info, **kwargs) return str(result).lower() class UpperCaseExtension(FieldExtension): async def resolve_async( - self, next_: Callable[..., Any], source: Any, info: Info, **kwargs + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any ): result = await next_(source, info, **kwargs) return str(result).upper() @@ -196,12 +202,16 @@ async def string(self) -> str: def test_extension_order_respected(): class LowerCaseExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): result = next_(source, info, **kwargs) return str(result).lower() class UpperCaseExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): result = next_(source, info, **kwargs) return str(result).upper() @@ -232,7 +242,9 @@ class StringInput: field_kwargs = {} class CustomExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): nonlocal field_kwargs field_kwargs = kwargs result = next_(source, info, **kwargs) @@ -259,7 +271,9 @@ def string(self, some_input: StringInput) -> str: def test_extension_mutate_arguments(): class CustomExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): kwargs["some_input"] += 10 result = next_(source, info, **kwargs) return result diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 071813366b..fe8a23ac07 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -32,7 +32,9 @@ def test_raises_graphql_error_when_permission_is_denied(): class IsAuthenticated(BasePermission): message = "User is not authenticated" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission( + self, source: typing.Any, info: Info, **kwargs: typing.Any + ) -> bool: return False @strawberry.type @@ -54,7 +56,9 @@ async def test_raises_permission_error_for_subscription(): class IsAdmin(BasePermission): message = "You are not authorized" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission( + self, source: typing.Any, info: Info, **kwargs: typing.Any + ) -> bool: return False @strawberry.type @@ -81,7 +85,7 @@ async def test_sync_permissions_work_with_async_resolvers(): class IsAuthorized(BasePermission): message = "User is not authorized" - def has_permission(self, source, info, **kwargs) -> bool: + def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return info.context["user"] == "Patrick" @strawberry.type @@ -110,7 +114,9 @@ def test_can_use_source_when_testing_permission(): class CanSeeEmail(BasePermission): message = "Cannot see email for this user" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission( + self, source: typing.Any, info: Info, **kwargs: typing.Any + ) -> bool: return source.name.lower() == "patrick" @strawberry.type @@ -144,7 +150,9 @@ def test_can_use_args_when_testing_permission(): class CanSeeEmail(BasePermission): message = "Cannot see email for this user" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission( + self, source: typing.Any, info: Info, **kwargs: typing.Any + ) -> bool: return kwargs.get("secure", False) @strawberry.type @@ -178,7 +186,9 @@ def test_can_use_on_simple_fields(): class CanSeeEmail(BasePermission): message = "Cannot see email for this user" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission( + self, source: typing.Any, info: Info, **kwargs: typing.Any + ) -> bool: return source.name.lower() == "patrick" @strawberry.type @@ -210,7 +220,7 @@ async def test_dataclass_field_with_async_permission_class(): class CanSeeEmail(BasePermission): message = "Cannot see email for this user" - async def has_permission(self, source, info, **kwargs) -> bool: + async def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return source.name.lower() == "patrick" @strawberry.type @@ -240,7 +250,7 @@ async def test_async_resolver_with_async_permission_class(): class IsAuthorized(BasePermission): message = "User is not authorized" - async def has_permission(self, source, info, **kwargs) -> bool: + async def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return info.context["user"] == "Patrick" @strawberry.type @@ -270,7 +280,7 @@ async def test_sync_resolver_with_async_permission_class(): class IsAuthorized(BasePermission): message = "User is not authorized" - async def has_permission(self, source, info, **kwargs) -> bool: + async def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return info.context["user"] == "Patrick" @strawberry.type @@ -300,13 +310,13 @@ async def test_mixed_sync_and_async_permission_classes(): class IsAuthorizedAsync(BasePermission): message = "User is not authorized (async)" - async def has_permission(self, source, info, **kwargs) -> bool: + async def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return info.context.get("passAsync", False) class IsAuthorizedSync(BasePermission): message = "User is not authorized (sync)" - def has_permission(self, source, info, **kwargs) -> bool: + def has_permission(self, source, info, **kwargs: typing.Any) -> bool: return info.context.get("passSync", False) @strawberry.type diff --git a/tests/starlite/app.py b/tests/starlite/app.py index 924821dfb1..8e7afe8429 100644 --- a/tests/starlite/app.py +++ b/tests/starlite/app.py @@ -1,3 +1,5 @@ +from typing import Any + from starlite import Provide, Request, Starlite from strawberry.starlite import make_graphql_controller from tests.starlite.schema import schema @@ -18,7 +20,7 @@ async def get_context(app_dependency: str, request: Request = None): } -def create_app(schema=schema, **kwargs): +def create_app(schema=schema, **kwargs: Any): GraphQLController = make_graphql_controller( schema, path="/graphql", diff --git a/tests/starlite/schema.py b/tests/starlite/schema.py index 86ac8bf400..15da28e9dd 100644 --- a/tests/starlite/schema.py +++ b/tests/starlite/schema.py @@ -1,7 +1,7 @@ import asyncio import typing from enum import Enum -from typing import Optional +from typing import Any, Optional from graphql import GraphQLError @@ -15,7 +15,7 @@ class AlwaysFailPermission(BasePermission): message = "You are not authorized" - def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool: + def has_permission(self, source: Any, info: Info, **kwargs: typing.Any) -> bool: return False From aee5ec94b31bfa5e0e2be895f086670c2b3ac95b Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 18 May 2023 16:02:28 -0300 Subject: [PATCH 002/119] fix: fix optional scalars using the "or" notation on python 3.10 (#2774) --- RELEASE.md | 18 ++++++++++++++++++ strawberry/custom_scalar.py | 4 ++-- tests/test_forward_references.py | 14 ++++++++++++-- tests/utils/test_typing_forward_refs.py | 2 ++ 4 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..1afc844e4a --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,18 @@ +Release type: patch + +This release fixes an issue with optional scalars using the `or` +notation with forward references on python 3.10. + +The following code would previously raise `TypeError` on python 3.10: + +```python +from __future__ import annotations + +import strawberry +from strawberry.scalars import JSON + + +@strawberry.type +class SomeType: + an_optional_json: JSON | None +``` diff --git a/strawberry/custom_scalar.py b/strawberry/custom_scalar.py index 89fd0f4706..6b7d2ad121 100644 --- a/strawberry/custom_scalar.py +++ b/strawberry/custom_scalar.py @@ -17,7 +17,7 @@ ) from strawberry.exceptions import InvalidUnionTypeError -from strawberry.type import StrawberryOptional, StrawberryType +from strawberry.type import StrawberryType from .utils.str_converters import to_camel_case @@ -76,7 +76,7 @@ def __call__(self, *args: str, **kwargs: Any): def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: if other is None: # Return the correct notation when using `StrawberryUnion | None`. - return StrawberryOptional(of_type=self) + return Optional[self] # Raise an error in any other case. # There is Work in progress to deal with more merging cases, see: diff --git a/tests/test_forward_references.py b/tests/test_forward_references.py index dd0ed764c8..7a102a857f 100644 --- a/tests/test_forward_references.py +++ b/tests/test_forward_references.py @@ -10,6 +10,7 @@ import strawberry from strawberry.printer import print_schema +from strawberry.scalars import JSON from strawberry.type import StrawberryList, StrawberryOptional @@ -23,16 +24,25 @@ class Query: @strawberry.type class MyType: id: strawberry.ID + scalar: JSON + optional_scalar: JSON | None + + expected_representation = ''' + """ + The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf). + """ + scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf") - expected_representation = """ type MyType { id: ID! + scalar: JSON! + optionalScalar: JSON } type Query { myself: MyType! } - """ + ''' schema = strawberry.Schema(Query) diff --git a/tests/utils/test_typing_forward_refs.py b/tests/utils/test_typing_forward_refs.py index 807ef95cb4..ff36b74bde 100644 --- a/tests/utils/test_typing_forward_refs.py +++ b/tests/utils/test_typing_forward_refs.py @@ -5,6 +5,7 @@ import pytest +from strawberry.scalars import JSON from strawberry.utils.typing import eval_type @@ -26,6 +27,7 @@ class Foo: eval_type(ForwardRef("List[Foo | str] | None | int"), globals(), locals()) == Union[List[Union[Foo, str]], int, None] ) + assert eval_type(ForwardRef("JSON | None"), globals(), locals()) == Optional[JSON] @pytest.mark.skipif( From efc684a0611d927bb29b856c32171ea8e15a52d2 Mon Sep 17 00:00:00 2001 From: Botberry Date: Thu, 18 May 2023 19:03:34 +0000 Subject: [PATCH 003/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.177.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 23 +++++++++++++++++++++++ RELEASE.md | 18 ------------------ pyproject.toml | 2 +- 3 files changed, 24 insertions(+), 19 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 67b572004b..bae367f008 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,29 @@ CHANGELOG ========= +0.177.2 - 2023-05-18 +-------------------- + +This release fixes an issue with optional scalars using the `or` +notation with forward references on python 3.10. + +The following code would previously raise `TypeError` on python 3.10: + +```python +from __future__ import annotations + +import strawberry +from strawberry.scalars import JSON + + +@strawberry.type +class SomeType: + an_optional_json: JSON | None +``` + +Contributed by [Thiago Bellini Ribeiro](https://github.com/bellini666) via [PR #2774](https://github.com/strawberry-graphql/strawberry/pull/2774/) + + 0.177.1 - 2023-05-09 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 1afc844e4a..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,18 +0,0 @@ -Release type: patch - -This release fixes an issue with optional scalars using the `or` -notation with forward references on python 3.10. - -The following code would previously raise `TypeError` on python 3.10: - -```python -from __future__ import annotations - -import strawberry -from strawberry.scalars import JSON - - -@strawberry.type -class SomeType: - an_optional_json: JSON | None -``` diff --git a/pyproject.toml b/pyproject.toml index 5ad175f66d..625d5e8c25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.177.1" +version = "0.177.2" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From f0ddaa2b097ccba86992bbf32d69563f002d7467 Mon Sep 17 00:00:00 2001 From: Ronald Williams Date: Fri, 19 May 2023 10:45:25 -0500 Subject: [PATCH 004/119] Allow additional tags in DatadogTracingExtension (#2773) Co-authored-by: Patrick Arminio --- .github/workflows/mypy.yml | 1 - RELEASE.md | 23 ++++++++ docs/extensions/datadog.md | 32 ++++++++++ strawberry/extensions/__init__.py | 3 +- strawberry/extensions/base_extension.py | 8 +++ strawberry/extensions/tracing/datadog.py | 59 +++++++++++++++---- .../extensions/tracing/opentelemetry.py | 41 ++++++------- tests/schema/extensions/test_datadog.py | 37 ++++++++++++ 8 files changed, 168 insertions(+), 36 deletions(-) create mode 100644 RELEASE.md diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index a8dcb8c4bd..d3476cca0a 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -24,7 +24,6 @@ jobs: - uses: actions/checkout@v2 - run: pip install poetry - - run: poetry config experimental.new-installer false - name: "Python dependencies cache" id: cache-poetry-dependencies diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..0fd18e716c --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,23 @@ +Release type: patch + +This release adds a method on the DatadogTracingExtension class called `create_span` that can be overridden to create a custom span or add additional tags to the span. + +```python +from ddtrace import Span + +from strawberry.extensions import LifecycleStep +from strawberry.extensions.tracing import DatadogTracingExtension + + +class DataDogExtension(DatadogTracingExtension): + def create_span( + self, + lifecycle_step: LifecycleStep, + name: str, + **kwargs, + ) -> Span: + span = super().create_span(lifecycle_step, name, **kwargs) + if lifecycle_step == LifeCycleStep.OPERATION: + span.set_tag("graphql.query", self.execution_context.query) + return span +``` diff --git a/docs/extensions/datadog.md b/docs/extensions/datadog.md index e083bb8ee0..3c77c52edd 100644 --- a/docs/extensions/datadog.md +++ b/docs/extensions/datadog.md @@ -50,3 +50,35 @@ schema = strawberry.Schema( ``` + +## API reference: + +_No arguments_ + +## Extending the extension + +### Overriding the `create_span` method + +You can customize any of the spans or add tags to them by overriding the `create_span` method. + +Example: + +```python +from ddtrace import Span + +from strawberry.extensions import LifecycleStep +from strawberry.extensions.tracing import DatadogTracingExtension + + +class DataDogExtension(DatadogTracingExtension): + def create_span( + self, + lifecycle_step: LifecycleStep, + name: str, + **kwargs, + ) -> Span: + span = super().create_span(lifecycle_step, name, **kwargs) + if lifecycle_step == LifeCycleStep.OPERATION: + span.set_tag("graphql.query", self.execution_context.query) + return span +``` diff --git a/strawberry/extensions/__init__.py b/strawberry/extensions/__init__.py index 2ff75c9a4f..3877626bf5 100644 --- a/strawberry/extensions/__init__.py +++ b/strawberry/extensions/__init__.py @@ -1,7 +1,7 @@ import warnings from .add_validation_rules import AddValidationRules -from .base_extension import SchemaExtension +from .base_extension import LifecycleStep, SchemaExtension from .disable_validation import DisableValidation from .field_extension import FieldExtension from .mask_errors import MaskErrors @@ -30,6 +30,7 @@ def __getattr__(name: str): __all__ = [ "FieldExtension", "SchemaExtension", + "LifecycleStep", "AddValidationRules", "DisableValidation", "ParserCache", diff --git a/strawberry/extensions/base_extension.py b/strawberry/extensions/base_extension.py index 2de89b870d..f10b7db797 100644 --- a/strawberry/extensions/base_extension.py +++ b/strawberry/extensions/base_extension.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Set from strawberry.utils.await_maybe import AsyncIteratorOrIterator, AwaitableOrValue @@ -10,6 +11,13 @@ from strawberry.types import ExecutionContext +class LifecycleStep(Enum): + OPERATION = "operation" + VALIDATION = "validation" + PARSE = "parse" + RESOLVE = "resolve" + + class SchemaExtension: execution_context: ExecutionContext diff --git a/strawberry/extensions/tracing/datadog.py b/strawberry/extensions/tracing/datadog.py index 0e3cbef423..11bb71b3da 100644 --- a/strawberry/extensions/tracing/datadog.py +++ b/strawberry/extensions/tracing/datadog.py @@ -4,9 +4,9 @@ from inspect import isawaitable from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional -from ddtrace import tracer +from ddtrace import Span, tracer -from strawberry.extensions import SchemaExtension +from strawberry.extensions import LifecycleStep, SchemaExtension from strawberry.extensions.tracing.utils import should_skip_tracing from strawberry.utils.cached_property import cached_property @@ -36,6 +36,31 @@ def _resource_name(self): return query_hash + def create_span( + self, + lifecycle_step: LifecycleStep, + name: str, + **kwargs: Any, + ) -> Span: + """ + Create a span with the given name and kwargs. + You can override this if you want to add more tags to the span. + + Example: + + class CustomExtension(DatadogTracingExtension): + def create_span(self, lifecycle_step, name, **kwargs): + span = super().create_span(lifecycle_step, name, **kwargs) + if lifecycle_step == LifeCycleStep.OPERATION: + span.set_tag("graphql.query", self.execution_context.query) + return span + """ + return tracer.trace( + name, + span_type="graphql", + **kwargs, + ) + def hash_query(self, query: str) -> str: return hashlib.md5(query.encode("utf-8")).hexdigest() @@ -45,34 +70,38 @@ def on_operation(self) -> Iterator[None]: f"{self._operation_name}" if self._operation_name else "Anonymous Query" ) - self.request_span = tracer.trace( + self.request_span = self.create_span( + LifecycleStep.OPERATION, span_name, resource=self._resource_name, - span_type="graphql", service="strawberry", ) self.request_span.set_tag("graphql.operation_name", self._operation_name) - operation_type = "query" - assert self.execution_context.query + operation_type = "query" if self.execution_context.query.strip().startswith("mutation"): operation_type = "mutation" - if self.execution_context.query.strip().startswith("subscription"): + elif self.execution_context.query.strip().startswith("subscription"): operation_type = "subscription" - self.request_span.set_tag("graphql.operation_type", operation_type) yield self.request_span.finish() def on_validate(self) -> Generator[None, None, None]: - self.validation_span = tracer.trace("Validation", span_type="graphql") + self.validation_span = self.create_span( + lifecycle_step=LifecycleStep.VALIDATION, + name="Validation", + ) yield self.validation_span.finish() def on_parse(self) -> Generator[None, None, None]: - self.parsing_span = tracer.trace("Parsing", span_type="graphql") + self.parsing_span = self.create_span( + lifecycle_step=LifecycleStep.PARSE, + name="Parsing", + ) yield self.parsing_span.finish() @@ -94,7 +123,10 @@ async def resolve( field_path = f"{info.parent_type}.{info.field_name}" - with tracer.trace(f"Resolving: {field_path}", span_type="graphql") as span: + with self.create_span( + lifecycle_step=LifecycleStep.RESOLVE, + name=f"Resolving: {field_path}", + ) as span: span.set_tag("graphql.field_name", info.field_name) span.set_tag("graphql.parent_type", info.parent_type.name) span.set_tag("graphql.field_path", field_path) @@ -122,7 +154,10 @@ def resolve( field_path = f"{info.parent_type}.{info.field_name}" - with tracer.trace(f"Resolving: {field_path}", span_type="graphql") as span: + with self.create_span( + lifecycle_step=LifecycleStep.RESOLVE, + name=f"Resolving: {field_path}", + ) as span: span.set_tag("graphql.field_name", info.field_name) span.set_tag("graphql.parent_type", info.parent_type.name) span.set_tag("graphql.field_path", field_path) diff --git a/strawberry/extensions/tracing/opentelemetry.py b/strawberry/extensions/tracing/opentelemetry.py index a724863454..8051d23bf9 100644 --- a/strawberry/extensions/tracing/opentelemetry.py +++ b/strawberry/extensions/tracing/opentelemetry.py @@ -1,6 +1,5 @@ from __future__ import annotations -import enum from copy import deepcopy from inspect import isawaitable from typing import ( @@ -19,7 +18,7 @@ from opentelemetry import trace from opentelemetry.trace import SpanKind -from strawberry.extensions import SchemaExtension +from strawberry.extensions import LifecycleStep, SchemaExtension from strawberry.extensions.utils import get_path_from_info from .utils import should_skip_tracing @@ -36,15 +35,9 @@ ArgFilter = Callable[[Dict[str, Any], "GraphQLResolveInfo"], Dict[str, Any]] -class RequestStage(enum.Enum): - REQUEST = enum.auto() - PARSING = enum.auto() - VALIDATION = enum.auto() - - class OpenTelemetryExtension(SchemaExtension): _arg_filter: Optional[ArgFilter] - _span_holder: Dict[RequestStage, Span] = dict() + _span_holder: Dict[LifecycleStep, Span] = dict() _tracer: Tracer def __init__( @@ -66,13 +59,13 @@ def on_operation(self) -> Generator[None, None, None]: else "GraphQL Query" ) - self._span_holder[RequestStage.REQUEST] = self._tracer.start_span( + self._span_holder[LifecycleStep.OPERATION] = self._tracer.start_span( span_name, kind=SpanKind.SERVER ) - self._span_holder[RequestStage.REQUEST].set_attribute("component", "graphql") + self._span_holder[LifecycleStep.OPERATION].set_attribute("component", "graphql") if self.execution_context.query: - self._span_holder[RequestStage.REQUEST].set_attribute( + self._span_holder[LifecycleStep.OPERATION].set_attribute( "query", self.execution_context.query ) @@ -84,26 +77,26 @@ def on_operation(self) -> Generator[None, None, None]: # useful name in our trace. if not self._operation_name and self.execution_context.operation_name: span_name = f"GraphQL Query: {self.execution_context.operation_name}" - self._span_holder[RequestStage.REQUEST].update_name(span_name) - self._span_holder[RequestStage.REQUEST].end() + self._span_holder[LifecycleStep.OPERATION].update_name(span_name) + self._span_holder[LifecycleStep.OPERATION].end() def on_validate(self) -> Generator[None, None, None]: - ctx = trace.set_span_in_context(self._span_holder[RequestStage.REQUEST]) - self._span_holder[RequestStage.VALIDATION] = self._tracer.start_span( + ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION]) + self._span_holder[LifecycleStep.VALIDATION] = self._tracer.start_span( "GraphQL Validation", context=ctx, ) yield - self._span_holder[RequestStage.VALIDATION].end() + self._span_holder[LifecycleStep.VALIDATION].end() def on_parse(self) -> Generator[None, None, None]: - ctx = trace.set_span_in_context(self._span_holder[RequestStage.REQUEST]) - self._span_holder[RequestStage.PARSING] = self._tracer.start_span( + ctx = trace.set_span_in_context(self._span_holder[LifecycleStep.OPERATION]) + self._span_holder[LifecycleStep.PARSE] = self._tracer.start_span( "GraphQL Parsing", context=ctx ) yield - self._span_holder[RequestStage.PARSING].end() + self._span_holder[LifecycleStep.PARSE].end() def filter_resolver_args( self, args: Dict[str, Any], info: GraphQLResolveInfo @@ -178,7 +171,9 @@ async def resolve( with self._tracer.start_as_current_span( f"GraphQL Resolving: {info.field_name}", - context=trace.set_span_in_context(self._span_holder[RequestStage.REQUEST]), + context=trace.set_span_in_context( + self._span_holder[LifecycleStep.OPERATION] + ), ) as span: self.add_tags(span, info, kwargs) result = _next(root, info, *args, **kwargs) @@ -205,7 +200,9 @@ def resolve( with self._tracer.start_as_current_span( f"GraphQL Resolving: {info.field_name}", - context=trace.set_span_in_context(self._span_holder[RequestStage.REQUEST]), + context=trace.set_span_in_context( + self._span_holder[LifecycleStep.OPERATION] + ), ) as span: self.add_tags(span, info, kwargs) result = _next(root, info, *args, **kwargs) diff --git a/tests/schema/extensions/test_datadog.py b/tests/schema/extensions/test_datadog.py index ab881d34e2..936151b2f8 100644 --- a/tests/schema/extensions/test_datadog.py +++ b/tests/schema/extensions/test_datadog.py @@ -252,3 +252,40 @@ def test_uses_operation_type_sync(datadog_extension_sync): schema.execute_sync(query, operation_name="MyMutation") mock.tracer.trace().set_tag.assert_any_call("graphql.operation_type", "mutation") + + +@pytest.mark.asyncio +async def test_create_span_override(datadog_extension): + from strawberry.extensions.tracing.datadog import LifecycleStep + + extension, mock = datadog_extension + + class CustomExtension(extension): + def create_span( + self, + lifecycle_step: LifecycleStep, + name: str, + **kwargs, # noqa: ANN003 + ): + span = super().create_span(lifecycle_step, name, **kwargs) + if lifecycle_step == LifecycleStep.OPERATION: + span.set_tag("graphql.query", self.execution_context.query) + return span + + schema = strawberry.Schema( + query=Query, + mutation=Mutation, + extensions=[CustomExtension], + ) + + query = """ + query { + personAsync { + name + } + } + """ + + await schema.execute(query) + + mock.tracer.trace().set_tag.assert_any_call("graphql.query", query) From c443d9c0d3893c4b8e38934631ec09b3a34877fd Mon Sep 17 00:00:00 2001 From: Botberry Date: Fri, 19 May 2023 15:46:26 +0000 Subject: [PATCH 005/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.177.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 28 ++++++++++++++++++++++++++++ RELEASE.md | 23 ----------------------- pyproject.toml | 2 +- 3 files changed, 29 insertions(+), 24 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index bae367f008..a1fe98fe9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,34 @@ CHANGELOG ========= +0.177.3 - 2023-05-19 +-------------------- + +This release adds a method on the DatadogTracingExtension class called `create_span` that can be overridden to create a custom span or add additional tags to the span. + +```python +from ddtrace import Span + +from strawberry.extensions import LifecycleStep +from strawberry.extensions.tracing import DatadogTracingExtension + + +class DataDogExtension(DatadogTracingExtension): + def create_span( + self, + lifecycle_step: LifecycleStep, + name: str, + **kwargs, + ) -> Span: + span = super().create_span(lifecycle_step, name, **kwargs) + if lifecycle_step == LifeCycleStep.OPERATION: + span.set_tag("graphql.query", self.execution_context.query) + return span +``` + +Contributed by [Ronald Williams](https://github.com/ronaldnwilliams) via [PR #2773](https://github.com/strawberry-graphql/strawberry/pull/2773/) + + 0.177.2 - 2023-05-18 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 0fd18e716c..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,23 +0,0 @@ -Release type: patch - -This release adds a method on the DatadogTracingExtension class called `create_span` that can be overridden to create a custom span or add additional tags to the span. - -```python -from ddtrace import Span - -from strawberry.extensions import LifecycleStep -from strawberry.extensions.tracing import DatadogTracingExtension - - -class DataDogExtension(DatadogTracingExtension): - def create_span( - self, - lifecycle_step: LifecycleStep, - name: str, - **kwargs, - ) -> Span: - span = super().create_span(lifecycle_step, name, **kwargs) - if lifecycle_step == LifeCycleStep.OPERATION: - span.set_tag("graphql.query", self.execution_context.query) - return span -``` diff --git a/pyproject.toml b/pyproject.toml index 625d5e8c25..8153596a07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.177.2" +version = "0.177.3" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From c2d50c0cd9a94518478c9ff3eb25be041338adb3 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 22 May 2023 14:46:38 +0100 Subject: [PATCH 006/119] Extend query depth limiter allowing for more detailed rules (#2505) * add FieldAttributesRule class for specifying name, args, and keys * add tests to cover new class * fix failing type check no field_args * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add RELEASE.md file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix inconspicuous bug with multiple rules * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove field_keys, it was not working as intended * update RELEASE.md * fix missed field_keys references * remove redundant if None condition * add FieldArgumentsRuleType to appease mypy * rename FieldAttributesRule -> FieldRule, FieldNameRuleType -> FieldAttributeRuleType, specify that FieldArgumentsRuleType is a dict of FieldAttributeRuleType * update docs to describe code changes * add back breaking line * add comment on blacken-docs issue * update RELEASE.md * arbitrary change to rerun tests * introduce should_ignore and deprecate ignore options * update RELEASE.md * update docs * appease ruff * Update release notes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- RELEASE.md | 66 ++++ docs/extensions/query-depth-limiter.md | 127 +++++++- strawberry/extensions/__init__.py | 3 +- strawberry/extensions/query_depth_limiter.py | 233 +++++++++++++-- .../extensions/test_query_depth_limiter.py | 234 +++++++++++++-- .../test_query_depth_limiter_deprecated.py | 282 ++++++++++++++++++ 6 files changed, 900 insertions(+), 45 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/schema/extensions/test_query_depth_limiter_deprecated.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8cc251582d --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,66 @@ +Release type: minor + +This release introduces the new `should_ignore` argument to the `QueryDepthLimiter` extension that provides +a more general and more verbose way of specifying the rules by which a query's depth should be limited. + +The `should_ignore` argument should be a function that accepts a single argument of type `IgnoreContext`. +The `IgnoreContext` class has the following attributes: +- `field_name` of type `str`: the name of the field to be compared against +- `field_args` of type `strawberry.extensions.query_depth_limiter.FieldArgumentsType`: the arguments of the field to be compared against +- `query` of type `graphql.language.Node`: the query string +- `context` of type `graphql.validation.ValidationContext`: the context passed to the query +and returns `True` if the field should be ignored and `False` otherwise. +This argument is injected, regardless of name, by the `QueryDepthLimiter` class and should not be passed by the user. + +Instead, the user should write business logic to determine whether a field should be ignored or not by +the attributes of the `IgnoreContext` class. + +For example, the following query: +```python +""" + query { + matt: user(name: "matt") { + email + } + andy: user(name: "andy") { + email + address { + city + } + pets { + name + owner { + name + } + } + } + } +""" +``` +can have its depth limited by the following `should_ignore`: +```python +from strawberry.extensions import IgnoreContext + + +def should_ignore(ignore: IgnoreContext): + return ignore.field_args.get("name") == "matt" + + +query_depth_limiter = QueryDepthLimiter(should_ignore=should_ignore) +``` +so that it *effectively* becomes: +```python +""" + query { + andy: user(name: "andy") { + email + pets { + name + owner { + name + } + } + } + } +""" +``` diff --git a/docs/extensions/query-depth-limiter.md b/docs/extensions/query-depth-limiter.md index a159cc43d6..6d327ca5eb 100644 --- a/docs/extensions/query-depth-limiter.md +++ b/docs/extensions/query-depth-limiter.md @@ -25,7 +25,7 @@ schema = strawberry.Schema( ## API reference: ```python -class QueryDepthLimiter(max_depth, ignore=None, callback=None): +class QueryDepthLimiter(max_depth, ignore=None, callback=None, should_ignore=None): ... ``` @@ -39,12 +39,135 @@ Stops recursive depth checking based on a field name. Either a string or regexp to match the name, or a function that returns a boolean. +This variable has been deprecated in favour of the `should_ignore` argument +as documented below. + #### `callback: Optional[Callable[[Dict[str, int]], None]` Called each time validation runs. Receives a dictionary which is a map of the depths for each operation. -## More examples: +#### `should_ignore: Optional[Callable[[IgnoreContext], bool]]` + +Called at each field to determine whether the field should be ignored or not. +Must be implemented by the user and returns `True` if the field should be ignored +and `False` otherwise. + +The `IgnoreContext` class has the following attributes: + +- `field_name` of type `str`: the name of the field to be compared against +- `field_args` of type `strawberry.extensions.query_depth_limiter.FieldArgumentsType`: the arguments of the field to be compared against +- `query` of type `graphql.language.Node`: the query string +- `context` of type `graphql.validation.ValidationContext`: the context passed to the query + +This argument is injected, regardless of name, by the `QueryDepthLimiter` class and should not be passed by the user. + +Instead, the user should write business logic to determine whether a field should be ignored or not by +the attributes of the `IgnoreContext` class. + +## Example with field_name: + +```python +import strawberry +from strawberry.extensions import QueryDepthLimiter + + +def should_ignore(ignore: IgnoreContext): + return ignore.field_name == "user" + + +schema = strawberry.Schema( + Query, + extensions=[ + QueryDepthLimiter(max_depth=2, should_ignore=should_ignore), + ], +) + +# This query fails +schema.execute( + """ + query TooDeep { + book { + author { + publishedBooks { + title + } + } + } + } +""" +) + +# This query succeeds because the `user` field is ignored +schema.execute( + """ + query NotTooDeep { + user { + favouriteBooks { + author { + publishedBooks { + title + } + } + } + } + } +""" +) +``` + +## Example with field_args: + +```python +import strawberry +from strawberry.extensions import QueryDepthLimiter + + +def should_ignore(ignore: IgnoreContext): + return ignore.field_args.get("name") == "matt" + + +schema = strawberry.Schema( + Query, + extensions=[ + QueryDepthLimiter(max_depth=2, should_ignore=should_ignore), + ], +) + +# This query fails +schema.execute( + """ + query TooDeep { + book { + author { + publishedBooks { + title + } + } + } + } +""" +) + +# This query succeeds because the `user` field is ignored +schema.execute( + """ + query NotTooDeep { + user(name:"matt") { + favouriteBooks { + author { + publishedBooks { + title + } + } + } + } + } +""" +) +``` + +## More examples for deprecated `ignore` argument:
Ignoring fields diff --git a/strawberry/extensions/__init__.py b/strawberry/extensions/__init__.py index 3877626bf5..04024eabf9 100644 --- a/strawberry/extensions/__init__.py +++ b/strawberry/extensions/__init__.py @@ -8,7 +8,7 @@ from .max_aliases import MaxAliasesLimiter from .max_tokens import MaxTokensLimiter from .parser_cache import ParserCache -from .query_depth_limiter import QueryDepthLimiter +from .query_depth_limiter import IgnoreContext, QueryDepthLimiter from .validation_cache import ValidationCache @@ -35,6 +35,7 @@ def __getattr__(name: str): "DisableValidation", "ParserCache", "QueryDepthLimiter", + "IgnoreContext", "ValidationCache", "MaskErrors", "MaxAliasesLimiter", diff --git a/strawberry/extensions/query_depth_limiter.py b/strawberry/extensions/query_depth_limiter.py index c217950910..8838279dc2 100644 --- a/strawberry/extensions/query_depth_limiter.py +++ b/strawberry/extensions/query_depth_limiter.py @@ -28,27 +28,57 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Type, Union +import warnings +from dataclasses import dataclass +from typing import ( + Callable, + Dict, + Iterable, + List, + Optional, + Type, + Union, +) from graphql import GraphQLError from graphql.language import ( + BooleanValueNode, + DefinitionNode, FieldNode, + FloatValueNode, FragmentDefinitionNode, FragmentSpreadNode, InlineFragmentNode, + IntValueNode, + ListValueNode, + Node, + ObjectValueNode, OperationDefinitionNode, + StringValueNode, + ValueNode, ) -from graphql.validation import ValidationRule +from graphql.validation import ValidationContext, ValidationRule from strawberry.extensions import AddValidationRules from strawberry.extensions.utils import is_introspection_key -if TYPE_CHECKING: - from graphql.language import DefinitionNode, Node - from graphql.validation import ValidationContext +IgnoreType = Union[Callable[[str], bool], re.Pattern, str] + +FieldArgumentType = Union[ + bool, int, float, str, List["FieldArgumentType"], Dict[str, "FieldArgumentType"] +] +FieldArgumentsType = Dict[str, FieldArgumentType] -IgnoreType = Union[Callable[[str], bool], re.Pattern, str] +@dataclass +class IgnoreContext: + field_name: str + field_args: FieldArgumentsType + node: Node + context: ValidationContext + + +ShouldIgnoreType = Callable[[IgnoreContext], bool] class QueryDepthLimiter(AddValidationRules): @@ -71,28 +101,47 @@ class QueryDepthLimiter(AddValidationRules): `max_depth: int` The maximum allowed depth for any operation in a GraphQL document. - `ignore: Optional[List[IgnoreType]]` + `ignore: Optional[List[IgnoreType]] DEPRECATED` Stops recursive depth checking based on a field name. Either a string or regexp to match the name, or a function that returns a boolean. `callback: Optional[Callable[[Dict[str, int]], None]` Called each time validation runs. Receives an Object which is a map of the depths for each operation. + `should_ignore: Optional[ShouldIgnoreType]` + Stops recursive depth checking based on a field name and arguments. + A function that returns a boolean and conforms to the ShouldIgnoreType + function signature. """ def __init__( self, max_depth: int, - ignore: Optional[List[IgnoreType]] = None, + ignore: Optional[List[IgnoreType]] = None, # DEPRECATED callback: Optional[Callable[[Dict[str, int]], None]] = None, + should_ignore: Optional[ShouldIgnoreType] = None, ): - validator = create_validator(max_depth, ignore, callback) + if should_ignore is not None: + if not callable(should_ignore): + raise TypeError( + "The `should_ignore` argument to " + "`QueryDepthLimiter` must be a callable." + ) + validator = create_validator(max_depth, should_ignore, callback) + else: + warnings.warn( + "The `ignore` argument to `QueryDepthLimiter` is deprecated. " + "Please use `should_ignore` instead.", + DeprecationWarning, + stacklevel=1, + ) + validator = create_validator_deprecated(max_depth, ignore, callback) super().__init__([validator]) def create_validator( max_depth: int, - ignore: Optional[List[IgnoreType]] = None, + should_ignore: Optional[ShouldIgnoreType], callback: Optional[Callable[[Dict[str, int]], None]] = None, ) -> Type[ValidationRule]: class DepthLimitValidator(ValidationRule): @@ -104,15 +153,15 @@ def __init__(self, validation_context: ValidationContext): queries = get_queries_and_mutations(definitions) query_depths = {} - for name in queries: - query_depths[name] = determine_depth( - node=queries[name], + for query in queries: + query_depths[query] = determine_depth( + node=queries[query], fragments=fragments, depth_so_far=0, max_depth=max_depth, context=validation_context, - operation_name=name, - ignore=ignore, + operation_name=query, + should_ignore=should_ignore, ) if callable(callback): @@ -148,7 +197,155 @@ def get_queries_and_mutations( return operations +def get_field_name( + node: FieldNode, +) -> str: + return node.alias.value if node.alias else node.name.value + + +def resolve_field_value( + value: ValueNode, +) -> FieldArgumentType: + if isinstance(value, StringValueNode): + return value.value + elif isinstance(value, IntValueNode): + return int(value.value) + elif isinstance(value, FloatValueNode): + return float(value.value) + elif isinstance(value, BooleanValueNode): + return value.value + elif isinstance(value, ListValueNode): + return [resolve_field_value(v) for v in value.values] + elif isinstance(value, ObjectValueNode): + return {v.name.value: resolve_field_value(v.value) for v in value.fields} + else: + return {} + + +def get_field_arguments( + node: FieldNode, +) -> FieldArgumentsType: + args_dict: FieldArgumentsType = {} + for arg in node.arguments: + args_dict[arg.name.value] = resolve_field_value(arg.value) + return args_dict + + def determine_depth( + node: Node, + fragments: Dict[str, FragmentDefinitionNode], + depth_so_far: int, + max_depth: int, + context: ValidationContext, + operation_name: str, + should_ignore: Optional[ShouldIgnoreType], +) -> int: + if depth_so_far > max_depth: + context.report_error( + GraphQLError( + f"'{operation_name}' exceeds maximum operation depth of {max_depth}", + [node], + ) + ) + return depth_so_far + + if isinstance(node, FieldNode): + # by default, ignore the introspection fields which begin + # with double underscores + should_ignore_field = is_introspection_key(node.name.value) or ( + should_ignore( + IgnoreContext( + get_field_name(node), + get_field_arguments(node), + node, + context, + ) + ) + if should_ignore is not None + else False + ) + + if should_ignore_field or not node.selection_set: + return 0 + + return 1 + max( + map( + lambda selection: determine_depth( + node=selection, + fragments=fragments, + depth_so_far=depth_so_far + 1, + max_depth=max_depth, + context=context, + operation_name=operation_name, + should_ignore=should_ignore, + ), + node.selection_set.selections, + ) + ) + elif isinstance(node, FragmentSpreadNode): + return determine_depth( + node=fragments[node.name.value], + fragments=fragments, + depth_so_far=depth_so_far, + max_depth=max_depth, + context=context, + operation_name=operation_name, + should_ignore=should_ignore, + ) + elif isinstance( + node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode) + ): + return max( + map( + lambda selection: determine_depth( + node=selection, + fragments=fragments, + depth_so_far=depth_so_far, + max_depth=max_depth, + context=context, + operation_name=operation_name, + should_ignore=should_ignore, + ), + node.selection_set.selections, + ) + ) + else: + raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover + + +def create_validator_deprecated( + max_depth: int, + ignore: Optional[List[IgnoreType]] = None, + callback: Optional[Callable[[Dict[str, int]], None]] = None, +) -> Type[ValidationRule]: + class DepthLimitValidator(ValidationRule): + def __init__(self, validation_context: ValidationContext): + document = validation_context.document + definitions = document.definitions + + fragments = get_fragments(definitions) + queries = get_queries_and_mutations(definitions) + query_depths = {} + + for name in queries: + query_depths[name] = determine_depth_deprecated( + node=queries[name], + fragments=fragments, + depth_so_far=0, + max_depth=max_depth, + context=validation_context, + operation_name=name, + ignore=ignore, + ) + + if callable(callback): + callback(query_depths) + super().__init__(validation_context) + + return DepthLimitValidator + + +def determine_depth_deprecated( node: Node, fragments: Dict[str, FragmentDefinitionNode], depth_so_far: int, @@ -178,7 +375,7 @@ def determine_depth( return 1 + max( map( - lambda selection: determine_depth( + lambda selection: determine_depth_deprecated( node=selection, fragments=fragments, depth_so_far=depth_so_far + 1, @@ -191,7 +388,7 @@ def determine_depth( ) ) elif isinstance(node, FragmentSpreadNode): - return determine_depth( + return determine_depth_deprecated( node=fragments[node.name.value], fragments=fragments, depth_so_far=depth_so_far, @@ -205,7 +402,7 @@ def determine_depth( ): return max( map( - lambda selection: determine_depth( + lambda selection: determine_depth_deprecated( node=selection, fragments=fragments, depth_so_far=depth_so_far, diff --git a/tests/schema/extensions/test_query_depth_limiter.py b/tests/schema/extensions/test_query_depth_limiter.py index 683eb6e87d..22ed5be6c0 100644 --- a/tests/schema/extensions/test_query_depth_limiter.py +++ b/tests/schema/extensions/test_query_depth_limiter.py @@ -1,4 +1,3 @@ -import re from typing import Dict, List, Optional, Tuple, Union import pytest @@ -12,7 +11,11 @@ import strawberry from strawberry.extensions import QueryDepthLimiter -from strawberry.extensions.query_depth_limiter import create_validator +from strawberry.extensions.query_depth_limiter import ( + IgnoreContext, + ShouldIgnoreType, + create_validator, +) @strawberry.interface @@ -47,10 +50,30 @@ class Human: pets: List[Pet] +@strawberry.input +class Biography: + name: str + owner_name: str + + @strawberry.type class Query: @strawberry.field - def user(self, name: Optional[str]) -> Human: + def user( + self, + name: Optional[str], + id: Optional[int], + age: Optional[float], + is_cool: Optional[bool], + ) -> Human: + pass + + @strawberry.field + def users(self, names: Optional[List[str]]) -> List[Human]: + pass + + @strawberry.field + def cat(bio: Biography) -> Cat: pass version: str @@ -63,7 +86,7 @@ def user(self, name: Optional[str]) -> Human: def run_query( - query: str, max_depth: int, ignore=None + query: str, max_depth: int, should_ignore: ShouldIgnoreType = None ) -> Tuple[List[GraphQLError], Union[Dict[str, int], None]]: document = parse(query) @@ -73,7 +96,7 @@ def callback(query_depths): nonlocal result result = query_depths - validation_rule = create_validator(max_depth, ignore, callback) + validation_rule = create_validator(max_depth, should_ignore, callback) errors = validate( schema._schema, @@ -221,7 +244,22 @@ def test_should_catch_query_thats_too_deep(): assert errors[0].message == "'anonymous' exceeds maximum operation depth of 4" -def test_should_ignore_field(): +def test_should_raise_invalid_ignore(): + query = """ + query read1 { + user { address { city } } + } + """ + with pytest.raises( + TypeError, + match="The `should_ignore` argument to `QueryDepthLimiter` must be a callable.", + ): + strawberry.Schema( + Query, extensions=[QueryDepthLimiter(max_depth=10, should_ignore=True)] + ) + + +def test_should_ignore_field_by_name(): query = """ query read1 { user { address { city } } @@ -233,33 +271,175 @@ def test_should_ignore_field(): } """ - errors, result = run_query( - query, - 10, - ignore=[ - "user1", - re.compile("user2"), - lambda field_name: field_name == "user3", - ], - ) + def should_ignore(ignore: IgnoreContext) -> bool: + return ( + ignore.field_name == "user1" + or ignore.field_name == "user2" + or ignore.field_name == "user3" + ) + + errors, result = run_query(query, 10, should_ignore=should_ignore) expected = {"read1": 2, "read2": 0} assert not errors assert result == expected -def test_should_raise_invalid_ignore(): +def test_should_ignore_field_by_str_argument(): query = """ query read1 { - user { address { city } } + user(name:"matt") { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } } """ - with pytest.raises(TypeError, match="Invalid ignore option:"): - run_query( - query, - 10, - ignore=[True], - ) + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("name") == "matt" + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_int_argument(): + query = """ + query read1 { + user(id:1) { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("id") == 1 + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_float_argument(): + query = """ + query read1 { + user(age:10.5) { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("age") == 10.5 + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_bool_argument(): + query = """ + query read1 { + user(isCool:false) { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("isCool") is False + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_name_and_str_argument(): + query = """ + query read1 { + user(name:"matt") { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("name") == "matt" + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_list_argument(): + query = """ + query read1 { + users(names:["matt","andy"]) { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return "matt" in ignore.field_args.get("names", []) + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected + + +def test_should_ignore_field_by_object_argument(): + query = """ + query read1 { + cat(bio:{ + name:"Momo", + ownerName:"Tommy" + }) { name } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + def should_ignore(ignore: IgnoreContext) -> bool: + return ignore.field_args.get("bio", {}).get("name") == "Momo" + + errors, result = run_query(query, 10, should_ignore=should_ignore) + + expected = {"read1": 0, "read2": 2} + assert not errors + assert result == expected def test_should_work_as_extension(): @@ -279,7 +459,13 @@ def test_should_work_as_extension(): } } """ - schema = strawberry.Schema(Query, extensions=[QueryDepthLimiter(max_depth=4)]) + + def should_ignore(ignore: IgnoreContext) -> bool: + return False + + schema = strawberry.Schema( + Query, extensions=[QueryDepthLimiter(max_depth=4, should_ignore=should_ignore)] + ) result = schema.execute_sync(query) diff --git a/tests/schema/extensions/test_query_depth_limiter_deprecated.py b/tests/schema/extensions/test_query_depth_limiter_deprecated.py new file mode 100644 index 0000000000..aa528f41b8 --- /dev/null +++ b/tests/schema/extensions/test_query_depth_limiter_deprecated.py @@ -0,0 +1,282 @@ +import re +from typing import List, Optional + +import pytest +from graphql import get_introspection_query, parse, specified_rules, validate + +import strawberry +from strawberry.extensions import QueryDepthLimiter +from strawberry.extensions.query_depth_limiter import create_validator_deprecated + + +@strawberry.interface +class Pet: + name: str + owner: "Human" + + +@strawberry.type +class Cat(Pet): + pass + + +@strawberry.type +class Dog(Pet): + pass + + +@strawberry.type +class Address: + street: str + number: int + city: str + country: str + + +@strawberry.type +class Human: + name: str + email: str + address: Address + pets: List[Pet] + + +@strawberry.type +class Query: + @strawberry.field + def user(self, name: Optional[str]) -> Human: + pass + + version: str + user1: Human + user2: Human + user3: Human + + +schema = strawberry.Schema(Query) + + +def run_query(query: str, max_depth: int, ignore=None): + document = parse(query) + + result = None + + def callback(query_depths): + nonlocal result + result = query_depths + + validation_rule = create_validator_deprecated(max_depth, ignore, callback) + + errors = validate( + schema._schema, + document, + rules=(*specified_rules, validation_rule), + ) + + return errors, result + + +def test_should_count_depth_without_fragment(): + query = """ + query read0 { + version + } + query read1 { + version + user { + name + } + } + query read2 { + matt: user(name: "matt") { + email + } + andy: user(name: "andy") { + email + address { + city + } + } + } + query read3 { + matt: user(name: "matt") { + email + } + andy: user(name: "andy") { + email + address { + city + } + pets { + name + owner { + name + } + } + } + } + """ + + expected = {"read0": 0, "read1": 1, "read2": 2, "read3": 3} + + errors, result = run_query(query, 10) + assert not errors + assert result == expected + + +def test_should_count_with_fragments(): + query = """ + query read0 { + ... on Query { + version + } + } + query read1 { + version + user { + ... on Human { + name + } + } + } + fragment humanInfo on Human { + email + } + fragment petInfo on Pet { + name + owner { + name + } + } + query read2 { + matt: user(name: "matt") { + ...humanInfo + } + andy: user(name: "andy") { + ...humanInfo + address { + city + } + } + } + query read3 { + matt: user(name: "matt") { + ...humanInfo + } + andy: user(name: "andy") { + ... on Human { + email + } + address { + city + } + pets { + ...petInfo + } + } + } + """ + + expected = {"read0": 0, "read1": 1, "read2": 2, "read3": 3} + + errors, result = run_query(query, 10) + assert not errors + assert result == expected + + +def test_should_ignore_the_introspection_query(): + errors, result = run_query(get_introspection_query(), 10) + assert not errors + assert result == {"IntrospectionQuery": 0} + + +def test_should_catch_query_thats_too_deep(): + query = """{ + user { + pets { + owner { + pets { + owner { + pets { + name + } + } + } + } + } + } + } + """ + errors, result = run_query(query, 4) + + assert len(errors) == 1 + assert errors[0].message == "'anonymous' exceeds maximum operation depth of 4" + + +def test_should_ignore_field_simple(): + query = """ + query read1 { + user { address { city } } + } + query read2 { + user1 { address { city } } + user2 { address { city } } + user3 { address { city } } + } + """ + + errors, result = run_query( + query, + 10, + ignore=[ + "user1", + re.compile("user2"), + lambda field_name: field_name == "user3", + ], + ) + + expected = {"read1": 2, "read2": 0} + assert not errors + assert result == expected + + +def test_should_raise_invalid_ignore(): + query = """ + query read1 { + user { address { city } } + } + """ + with pytest.raises(TypeError, match="Invalid ignore option:"): + run_query( + query, + 10, + ignore=[True], + ) + + +def test_should_work_as_extension(): + query = """{ + user { + pets { + owner { + pets { + owner { + pets { + name + } + } + } + } + } + } + } + """ + with pytest.deprecated_call(): + schema = strawberry.Schema(Query, extensions=[QueryDepthLimiter(max_depth=4)]) + + result = schema.execute_sync(query) + + assert len(result.errors) == 1 + assert ( + result.errors[0].message == "'anonymous' exceeds maximum operation depth of 4" + ) From 4519c399d3407ae321c4abfbbd4d4cbc7978816a Mon Sep 17 00:00:00 2001 From: Botberry Date: Mon, 22 May 2023 13:47:49 +0000 Subject: [PATCH 007/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.178.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++ RELEASE.md | 66 ---------------------------------------------- pyproject.toml | 2 +- 3 files changed, 72 insertions(+), 67 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index a1fe98fe9a..45917973a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,77 @@ CHANGELOG ========= +0.178.0 - 2023-05-22 +-------------------- + +This release introduces the new `should_ignore` argument to the `QueryDepthLimiter` extension that provides +a more general and more verbose way of specifying the rules by which a query's depth should be limited. + +The `should_ignore` argument should be a function that accepts a single argument of type `IgnoreContext`. +The `IgnoreContext` class has the following attributes: +- `field_name` of type `str`: the name of the field to be compared against +- `field_args` of type `strawberry.extensions.query_depth_limiter.FieldArgumentsType`: the arguments of the field to be compared against +- `query` of type `graphql.language.Node`: the query string +- `context` of type `graphql.validation.ValidationContext`: the context passed to the query +and returns `True` if the field should be ignored and `False` otherwise. +This argument is injected, regardless of name, by the `QueryDepthLimiter` class and should not be passed by the user. + +Instead, the user should write business logic to determine whether a field should be ignored or not by +the attributes of the `IgnoreContext` class. + +For example, the following query: +```python +""" + query { + matt: user(name: "matt") { + email + } + andy: user(name: "andy") { + email + address { + city + } + pets { + name + owner { + name + } + } + } + } +""" +``` +can have its depth limited by the following `should_ignore`: +```python +from strawberry.extensions import IgnoreContext + + +def should_ignore(ignore: IgnoreContext): + return ignore.field_args.get("name") == "matt" + + +query_depth_limiter = QueryDepthLimiter(should_ignore=should_ignore) +``` +so that it *effectively* becomes: +```python +""" + query { + andy: user(name: "andy") { + email + pets { + name + owner { + name + } + } + } + } +""" +``` + +Contributed by [Tommy Smith](https://github.com/tsmith023) via [PR #2505](https://github.com/strawberry-graphql/strawberry/pull/2505/) + + 0.177.3 - 2023-05-19 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 8cc251582d..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,66 +0,0 @@ -Release type: minor - -This release introduces the new `should_ignore` argument to the `QueryDepthLimiter` extension that provides -a more general and more verbose way of specifying the rules by which a query's depth should be limited. - -The `should_ignore` argument should be a function that accepts a single argument of type `IgnoreContext`. -The `IgnoreContext` class has the following attributes: -- `field_name` of type `str`: the name of the field to be compared against -- `field_args` of type `strawberry.extensions.query_depth_limiter.FieldArgumentsType`: the arguments of the field to be compared against -- `query` of type `graphql.language.Node`: the query string -- `context` of type `graphql.validation.ValidationContext`: the context passed to the query -and returns `True` if the field should be ignored and `False` otherwise. -This argument is injected, regardless of name, by the `QueryDepthLimiter` class and should not be passed by the user. - -Instead, the user should write business logic to determine whether a field should be ignored or not by -the attributes of the `IgnoreContext` class. - -For example, the following query: -```python -""" - query { - matt: user(name: "matt") { - email - } - andy: user(name: "andy") { - email - address { - city - } - pets { - name - owner { - name - } - } - } - } -""" -``` -can have its depth limited by the following `should_ignore`: -```python -from strawberry.extensions import IgnoreContext - - -def should_ignore(ignore: IgnoreContext): - return ignore.field_args.get("name") == "matt" - - -query_depth_limiter = QueryDepthLimiter(should_ignore=should_ignore) -``` -so that it *effectively* becomes: -```python -""" - query { - andy: user(name: "andy") { - email - pets { - name - owner { - name - } - } - } - } -""" -``` diff --git a/pyproject.toml b/pyproject.toml index 8153596a07..c1631e48a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.177.3" +version = "0.178.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 141e03149676c411ba8b4c1231bd858e43a4cbb8 Mon Sep 17 00:00:00 2001 From: Nick Butlin Date: Tue, 30 May 2023 21:56:49 +0100 Subject: [PATCH 008/119] 2772 match annotations in pydantic schema (#2782) * test: add tests for updated default behaviour * fix: 2772 correct type annotation for pydantic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * chore: add release.md * Update RELEASE.md * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update RELEASE.md --------- Co-authored-by: Nick Butlin <1270349-nickbutlin@users.noreply.gitlab.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: James Chua <30519287+thejaminator@users.noreply.github.com> Co-authored-by: Patrick Arminio --- RELEASE.md | 19 +++++++ .../experimental/pydantic/object_type.py | 6 +- .../pydantic/schema/test_basic.py | 55 +++++++++++++++++++ .../pydantic/schema/test_defaults.py | 2 + tests/experimental/pydantic/test_basic.py | 49 +++++++++++++++++ 5 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..a89e71f685 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,19 @@ +Release type: patch + +This release fixes a bug in experimental.pydantic whereby `Optional` type annotations weren't exactly aligned between strawberry type and pydantic model. + +Previously this would have caused the series field to be non-nullable in graphql. +```python +from typing import Optional +from pydantic import BaseModel, Field +import strawberry + + +class VehicleModel(BaseModel): + series: Optional[str] = Field(default="") + + +@strawberry.experimental.pydantic.type(model=VehicleModel, all_fields=True) +class VehicleModelType: + pass +``` diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index 5af20a331f..217547dc11 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -43,11 +43,7 @@ def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201 outer_type = field.outer_type_ replaced_type = replace_types_recursively(outer_type, is_input) - - default_defined: bool = ( - field.default_factory is not None or field.default is not None - ) - should_add_optional: bool = not (field.required or default_defined) + should_add_optional: bool = field.allow_none if should_add_optional: return Optional[replaced_type] else: diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 4c006a21e6..83eed46b6c 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -465,3 +465,58 @@ def user(self) -> UserType: assert not result.errors assert result.data["user"]["interfaceField"]["baseField"] == "abc" assert result.data["user"]["interfaceField"]["fieldB"] == 10 + + +def test_basic_type_with_optional_and_default(): + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] = pydantic.Field(default="ABC") + + @strawberry.experimental.pydantic.type(UserModel, all_fields=True) + class User: + pass + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1) + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age password } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + assert result.data["user"]["password"] == "ABC" + + @strawberry.type + class QueryNone: + @strawberry.field + def user(self) -> User: + return User(age=1, password=None) + + schema = strawberry.Schema(query=QueryNone) + + query = "{ user { age password } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + assert result.data["user"]["password"] is None diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index dc7dd3935c..0917e85ec9 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -10,6 +10,7 @@ def test_field_type_default(): class User(pydantic.BaseModel): name: str = "James" + nickname: Optional[str] = "Jim" @strawberry.experimental.pydantic.type(User, all_fields=True) class PydanticUser: @@ -35,6 +36,7 @@ def b(self) -> StrawberryUser: expected = """ type PydanticUser { name: String! + nickname: String } type Query { diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 0f31d3036d..8d8474cd76 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -343,6 +343,55 @@ class UserType4: assert UserType4().to_pydantic().friend is None +def test_optional_and_default(): + class UserModel(pydantic.BaseModel): + age: int + name: str = pydantic.Field("Michael", description="The user name") + password: Optional[str] = pydantic.Field(default="ABC") + passwordtwo: Optional[str] = None + some_list: Optional[List[str]] = pydantic.Field(default_factory=list) + check: Optional[bool] = False + + @strawberry.experimental.pydantic.type(UserModel, all_fields=True) + class User: + pass + + definition: TypeDefinition = User._type_definition + assert definition.name == "User" + + [ + age_field, + name_field, + password_field, + passwordtwo_field, + some_list_field, + check_field, + ] = definition.fields + + assert age_field.python_name == "age" + assert age_field.type is int + + assert name_field.python_name == "name" + assert name_field.type is str + + assert password_field.python_name == "password" + assert isinstance(password_field.type, StrawberryOptional) + assert password_field.type.of_type is str + + assert passwordtwo_field.python_name == "passwordtwo" + assert isinstance(passwordtwo_field.type, StrawberryOptional) + assert passwordtwo_field.type.of_type is str + + assert some_list_field.python_name == "some_list" + assert isinstance(some_list_field.type, StrawberryOptional) + assert isinstance(some_list_field.type.of_type, StrawberryList) + assert some_list_field.type.of_type.of_type is str + + assert check_field.python_name == "check" + assert isinstance(check_field.type, StrawberryOptional) + assert check_field.type.of_type is bool + + def test_type_with_fields_mutable_default(): empty_list = [] From 6cf8238c9cb11eda69a912e493c6a477b1d2a6be Mon Sep 17 00:00:00 2001 From: Botberry Date: Tue, 30 May 2023 20:58:04 +0000 Subject: [PATCH 009/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.178.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 24 ++++++++++++++++++++++++ RELEASE.md | 19 ------------------- pyproject.toml | 2 +- 3 files changed, 25 insertions(+), 20 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 45917973a6..d7128fa7de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,30 @@ CHANGELOG ========= +0.178.1 - 2023-05-30 +-------------------- + +This release fixes a bug in experimental.pydantic whereby `Optional` type annotations weren't exactly aligned between strawberry type and pydantic model. + +Previously this would have caused the series field to be non-nullable in graphql. +```python +from typing import Optional +from pydantic import BaseModel, Field +import strawberry + + +class VehicleModel(BaseModel): + series: Optional[str] = Field(default="") + + +@strawberry.experimental.pydantic.type(model=VehicleModel, all_fields=True) +class VehicleModelType: + pass +``` + +Contributed by [Nick Butlin](https://github.com/nicholasbutlin) via [PR #2782](https://github.com/strawberry-graphql/strawberry/pull/2782/) + + 0.178.0 - 2023-05-22 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index a89e71f685..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,19 +0,0 @@ -Release type: patch - -This release fixes a bug in experimental.pydantic whereby `Optional` type annotations weren't exactly aligned between strawberry type and pydantic model. - -Previously this would have caused the series field to be non-nullable in graphql. -```python -from typing import Optional -from pydantic import BaseModel, Field -import strawberry - - -class VehicleModel(BaseModel): - series: Optional[str] = Field(default="") - - -@strawberry.experimental.pydantic.type(model=VehicleModel, all_fields=True) -class VehicleModelType: - pass -``` diff --git a/pyproject.toml b/pyproject.toml index c1631e48a8..a3b0e91518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.178.0" +version = "0.178.1" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 8842778470b669716414d7d59c566ab103566097 Mon Sep 17 00:00:00 2001 From: Matt Gilson Date: Wed, 31 May 2023 15:47:14 -0400 Subject: [PATCH 010/119] Get mutations to work with codegen (#2795) * Add regression tests for mutation codegen. * Fix codegen mutation bug. --------- Co-authored-by: Matt Gilson --- RELEASE.md | 3 +++ strawberry/codegen/query_codegen.py | 4 +++- tests/codegen/conftest.py | 13 ++++++++++++- tests/codegen/queries/mutation.graphql | 5 +++++ tests/codegen/snapshots/python/mutation.py | 8 ++++++++ tests/codegen/snapshots/typescript/mutation.ts | 11 +++++++++++ 6 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/codegen/queries/mutation.graphql create mode 100644 tests/codegen/snapshots/python/mutation.py create mode 100644 tests/codegen/snapshots/typescript/mutation.ts diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..74a3aa732c --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Prevent AssertionError when using `strawberry codegen` on a query file that contains a mutation. diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 23970782ec..c2b825f989 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -255,7 +255,9 @@ def _convert_directives( def _convert_operation( self, operation_definition: OperationDefinitionNode ) -> GraphQLOperation: - query_type = self.schema.get_type_by_name("Query") + query_type = self.schema.get_type_by_name( + operation_definition.operation.value.title() + ) assert isinstance(query_type, TypeDefinition) assert operation_definition.name is not None diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index 3196d63f2e..bf1021e60a 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -46,6 +46,10 @@ class Node: class BlogPost(Node): title: str + def __init__(self, id: str, title: str) -> None: + self.id = id + self.title = title + @strawberry.type class Image(Node): @@ -97,6 +101,13 @@ def with_inputs(self, id: Optional[strawberry.ID], input: ExampleInput) -> bool: return True +@strawberry.type +class Mutation: + @strawberry.mutation + def add_book(self, name: str) -> BlogPost: + return BlogPost(id="c6f1c3ce-5249-4570-9182-c2836b836d14", name=name) + + @pytest.fixture def schema() -> strawberry.Schema: - return strawberry.Schema(query=Query, types=[BlogPost, Image]) + return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image]) diff --git a/tests/codegen/queries/mutation.graphql b/tests/codegen/queries/mutation.graphql new file mode 100644 index 0000000000..701c20f3a0 --- /dev/null +++ b/tests/codegen/queries/mutation.graphql @@ -0,0 +1,5 @@ +mutation addBook($input: String!) { + addBook(input: $input) { + id + } +} diff --git a/tests/codegen/snapshots/python/mutation.py b/tests/codegen/snapshots/python/mutation.py new file mode 100644 index 0000000000..e2b1d3562f --- /dev/null +++ b/tests/codegen/snapshots/python/mutation.py @@ -0,0 +1,8 @@ +class addBookResultAddBook: + id: str + +class addBookResult: + add_book: addBookResultAddBook + +class addBookVariables: + input: str diff --git a/tests/codegen/snapshots/typescript/mutation.ts b/tests/codegen/snapshots/typescript/mutation.ts new file mode 100644 index 0000000000..773bbd235e --- /dev/null +++ b/tests/codegen/snapshots/typescript/mutation.ts @@ -0,0 +1,11 @@ +type addBookResultAddBook = { + id: string +} + +type addBookResult = { + add_book: addBookResultAddBook +} + +type addBookVariables = { + input: string +} From 31c37cb80b342856988400aa82daeea6d0bfe48b Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 31 May 2023 19:48:13 +0000 Subject: [PATCH 011/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.178.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 8 ++++++++ RELEASE.md | 3 --- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index d7128fa7de..7cdbe99064 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ CHANGELOG ========= +0.178.2 - 2023-05-31 +-------------------- + +Prevent AssertionError when using `strawberry codegen` on a query file that contains a mutation. + +Contributed by [Matt Gilson](https://github.com/mgilson) via [PR #2795](https://github.com/strawberry-graphql/strawberry/pull/2795/) + + 0.178.1 - 2023-05-30 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 74a3aa732c..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,3 +0,0 @@ -Release type: patch - -Prevent AssertionError when using `strawberry codegen` on a query file that contains a mutation. diff --git a/pyproject.toml b/pyproject.toml index a3b0e91518..bfc8e5b088 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.178.1" +version = "0.178.2" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 6b499f02d6d4b3d94d5cc40e5e4621846042e440 Mon Sep 17 00:00:00 2001 From: Matt Gilson Date: Wed, 31 May 2023 15:57:55 -0400 Subject: [PATCH 012/119] Do not choke on __typename in codegen. (#2797) * Do not choke on __typename in codegen. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update RELEASE.md --------- Co-authored-by: Matt Gilson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- RELEASE.md | 4 +++ strawberry/codegen/plugins/python.py | 6 ++++- strawberry/codegen/query_codegen.py | 2 ++ .../queries/union_with_typename.graphql | 19 ++++++++++++++ .../snapshots/python/union_with_typename.py | 21 ++++++++++++++++ .../typescript/union_with_typename.ts | 25 +++++++++++++++++++ 6 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 RELEASE.md create mode 100644 tests/codegen/queries/union_with_typename.graphql create mode 100644 tests/codegen/snapshots/python/union_with_typename.py create mode 100644 tests/codegen/snapshots/typescript/union_with_typename.ts diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..1bf8656e7a --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: patch + +In this release codegen no longer chokes on queries that have a `__typename` in them. +Python generated types will not have the `__typename` included in the fields. diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index 701debe3fe..58a261fbab 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -110,7 +110,11 @@ def _print_enum_value(self, value: str) -> str: return f'{value} = "{value}"' def _print_object_type(self, type_: GraphQLObjectType) -> str: - fields = "\n".join(self._print_field(field) for field in type_.fields) + fields = "\n".join( + self._print_field(field) + for field in type_.fields + if field.name != "__typename" + ) return "\n".join( [ diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index c2b825f989..f84207cd43 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -418,6 +418,8 @@ def _collect_type_from_variable( def _field_from_selection( self, selection: FieldNode, parent_type: TypeDefinition ) -> GraphQLField: + if selection.name.value == "__typename": + return GraphQLField("__typename", None, GraphQLScalar("String", None)) field = self.schema.get_field_for_type(selection.name.value, parent_type.name) assert field diff --git a/tests/codegen/queries/union_with_typename.graphql b/tests/codegen/queries/union_with_typename.graphql new file mode 100644 index 0000000000..b89bf55245 --- /dev/null +++ b/tests/codegen/queries/union_with_typename.graphql @@ -0,0 +1,19 @@ +query OperationName { + __typename + union { + ... on Animal { + age + } + ... on Person { + name + } + } + optionalUnion { + ... on Animal { + age + } + ... on Person { + name + } + } +} diff --git a/tests/codegen/snapshots/python/union_with_typename.py b/tests/codegen/snapshots/python/union_with_typename.py new file mode 100644 index 0000000000..4c92cf5184 --- /dev/null +++ b/tests/codegen/snapshots/python/union_with_typename.py @@ -0,0 +1,21 @@ +from typing import Optional, Union + +class OperationNameResultUnionAnimal: + age: int + +class OperationNameResultUnionPerson: + name: str + +OperationNameResultUnion = Union[OperationNameResultUnionAnimal, OperationNameResultUnionPerson] + +class OperationNameResultOptionalUnionAnimal: + age: int + +class OperationNameResultOptionalUnionPerson: + name: str + +OperationNameResultOptionalUnion = Union[OperationNameResultOptionalUnionAnimal, OperationNameResultOptionalUnionPerson] + +class OperationNameResult: + union: OperationNameResultUnion + optional_union: Optional[OperationNameResultOptionalUnion] diff --git a/tests/codegen/snapshots/typescript/union_with_typename.ts b/tests/codegen/snapshots/typescript/union_with_typename.ts new file mode 100644 index 0000000000..aef8a9a93f --- /dev/null +++ b/tests/codegen/snapshots/typescript/union_with_typename.ts @@ -0,0 +1,25 @@ +type OperationNameResultUnionAnimal = { + age: number +} + +type OperationNameResultUnionPerson = { + name: string +} + +type OperationNameResultUnion = OperationNameResultUnionAnimal | OperationNameResultUnionPerson + +type OperationNameResultOptionalUnionAnimal = { + age: number +} + +type OperationNameResultOptionalUnionPerson = { + name: string +} + +type OperationNameResultOptionalUnion = OperationNameResultOptionalUnionAnimal | OperationNameResultOptionalUnionPerson + +type OperationNameResult = { + __typename: string + union: OperationNameResultUnion + optional_union: OperationNameResultOptionalUnion | undefined +} From b4c53b9408ee0e5890a62012ae6faf1a2da8f468 Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 31 May 2023 19:59:06 +0000 Subject: [PATCH 013/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.178.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 9 +++++++++ RELEASE.md | 4 ---- pyproject.toml | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cdbe99064..2a175e7dda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ CHANGELOG ========= +0.178.3 - 2023-05-31 +-------------------- + +In this release codegen no longer chokes on queries that have a `__typename` in them. +Python generated types will not have the `__typename` included in the fields. + +Contributed by [Matt Gilson](https://github.com/mgilson) via [PR #2797](https://github.com/strawberry-graphql/strawberry/pull/2797/) + + 0.178.2 - 2023-05-31 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 1bf8656e7a..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,4 +0,0 @@ -Release type: patch - -In this release codegen no longer chokes on queries that have a `__typename` in them. -Python generated types will not have the `__typename` included in the fields. diff --git a/pyproject.toml b/pyproject.toml index bfc8e5b088..0e87bd6185 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.178.2" +version = "0.178.3" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 79fdbd347924f1bcfaa405bd7e85066bdaf72579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 20:08:48 +0000 Subject: [PATCH 014/119] make graphql_ws test less timing-sensitive. (#2785) * make graphql_ws test less timing-sensitive. * Mark a timing dependent test as flaky --- tests/websockets/test_graphql_transport_ws.py | 3 +++ tests/websockets/test_graphql_ws.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 78a3f3a0e3..17b9ca0af9 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -787,6 +787,9 @@ async def test_rejects_connection_params_not_unset(ws_raw: WebSocketClient): ws.assert_reason("Invalid connection init payload") +# timings can sometimes fail currently. Until this test is rewritten when +# generator based subscriptions are implemented, mark it as flaky +@pytest.mark.flaky async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): # Test that when we cancel a subscription, the websocket isn't blocked # while some complex finalization takes place. diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 80cf07dced..6a07fb6ac4 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -119,13 +119,17 @@ async def test_sends_keep_alive(aiohttp_app_client: HttpClient): response = await ws.receive_json() assert response["type"] == GQL_CONNECTION_ACK - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_KEEP_ALIVE - - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_KEEP_ALIVE + # we can't be sure how many keep-alives exactly we + # get but they should be more than one. + keepalive_count = 0 + while True: + response = await ws.receive_json() + if response["type"] == GQL_CONNECTION_KEEP_ALIVE: + keepalive_count += 1 + else: + break + assert keepalive_count >= 1 - response = await ws.receive_json() assert response["type"] == GQL_DATA assert response["id"] == "demo" assert response["payload"]["data"] == {"echo": "Hi"} From 5bf44b345e06601b96cd5d4dd0013ebd7e5f445d Mon Sep 17 00:00:00 2001 From: Jonathan Kim Date: Wed, 31 May 2023 21:47:10 +0100 Subject: [PATCH 015/119] Allow passing metadata to Strawberry arguments (#2755) * Allow passing metadata to Strawberry arguments This is groundwork to allow implementation of a validation extensions that uses field and argument metadata to define validators. Example: ```python import strawberry @strawberry.type class Query: @strawberry.field def hello( self, info, input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], ) -> str: argument_definition = info.get_argument_definition("input") assert argument_definition.metadata["test"] == "foo" return f"Hi {input}" ``` Diff-Id: 7439a * Add release notes Diff-Id: e1987 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- RELEASE.md | 23 ++++++++++ strawberry/arguments.py | 8 ++++ .../extensions/test_field_extensions.py | 46 ++++++++++++++++++- tests/schema/test_arguments.py | 33 +++++++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..a5759eb7ec --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,23 @@ +Release type: minor + +This PR allows passing metadata to Strawberry arguments. + +Example: + +```python +import strawberry + + +@strawberry.type +class Query: + @strawberry.field + def hello( + self, + info, + input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], + ) -> str: + argument_definition = info.get_argument_definition("input") + assert argument_definition.metadata["test"] == "foo" + + return f"Hi {input}" +``` diff --git a/strawberry/arguments.py b/strawberry/arguments.py index e4526b739c..160174df66 100644 --- a/strawberry/arguments.py +++ b/strawberry/arguments.py @@ -46,6 +46,7 @@ class StrawberryArgumentAnnotation: name: Optional[str] deprecation_reason: Optional[str] directives: Iterable[object] + metadata: Mapping[Any, Any] def __init__( self, @@ -53,11 +54,13 @@ def __init__( name: Optional[str] = None, deprecation_reason: Optional[str] = None, directives: Iterable[object] = (), + metadata: Optional[Mapping[Any, Any]] = None, ): self.description = description self.name = name self.deprecation_reason = deprecation_reason self.directives = directives + self.metadata = metadata or {} class StrawberryArgument: @@ -71,6 +74,7 @@ def __init__( default: object = _deprecated_UNSET, deprecation_reason: Optional[str] = None, directives: Iterable[object] = (), + metadata: Optional[Mapping[Any, Any]] = None, ) -> None: self.python_name = python_name self.graphql_name = graphql_name @@ -80,6 +84,7 @@ def __init__( self.type_annotation = type_annotation self.deprecation_reason = deprecation_reason self.directives = directives + self.metadata = metadata or {} # TODO: Consider moving this logic to a function self.default = ( @@ -121,6 +126,7 @@ def _parse_annotated(self): self.graphql_name = arg.name self.deprecation_reason = arg.deprecation_reason self.directives = arg.directives + self.metadata = arg.metadata if isinstance(arg, StrawberryLazyReference): self.type_annotation = StrawberryAnnotation( @@ -224,12 +230,14 @@ def argument( name: Optional[str] = None, deprecation_reason: Optional[str] = None, directives: Iterable[object] = (), + metadata: Optional[Mapping[Any, Any]] = None, ) -> StrawberryArgumentAnnotation: return StrawberryArgumentAnnotation( description=description, name=name, deprecation_reason=deprecation_reason, directives=directives, + metadata=metadata, ) diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index c5a3e61256..e16ae28f68 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -1,5 +1,6 @@ import re -from typing import Any, Callable +from typing import Any, Callable, Optional +from typing_extensions import Annotated import pytest @@ -290,3 +291,46 @@ def string(self, some_input: int) -> str: result = schema.execute_sync(query) assert result.data, result.errors assert result.data["string"] == "This is a test!! 13" + + +def test_extension_access_argument_metadata(): + field_kwargs = {} + argument_metadata = {} + + class CustomExtension(FieldExtension): + def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + nonlocal field_kwargs + field_kwargs = kwargs + + for key in kwargs: + argument_def = info.get_argument_definition(key) + assert argument_def is not None + argument_metadata[key] = argument_def.metadata + + result = next_(source, info, **kwargs) + return result + + @strawberry.type + class Query: + @strawberry.field(extensions=[CustomExtension()]) + def string( + self, + some_input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], + another_input: Optional[str] = None, + ) -> str: + return f"This is a test!! {some_input}" + + schema = strawberry.Schema(query=Query) + query = 'query { string(someInput: "foo") }' + + result = schema.execute_sync(query) + assert result.data, result.errors + assert result.data["string"] == "This is a test!! foo" + + assert isinstance(field_kwargs["some_input"], str) + assert argument_metadata == { + "some_input": { + "test": "foo", + }, + "another_input": {}, + } diff --git a/tests/schema/test_arguments.py b/tests/schema/test_arguments.py index 984b4c6d8d..0497c8ddda 100644 --- a/tests/schema/test_arguments.py +++ b/tests/schema/test_arguments.py @@ -166,3 +166,36 @@ def hello(self, input: TestInput) -> str: ) assert not result.errors assert result.data == {"hello": "Hi there"} + + +def test_setting_metadata_on_argument(): + field_definition = None + + @strawberry.type + class Query: + @strawberry.field + def hello( + self, + info, + input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], + ) -> str: + nonlocal field_definition + field_definition = info._field + return f"Hi {input}" + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + hello(input: "there") + } + """ + ) + assert not result.errors + assert result.data == {"hello": "Hi there"} + + assert field_definition + assert field_definition.arguments[0].metadata == { + "test": "foo", + } From 0c6cb6540faac4c432068674ab74f90b1235b5cb Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 31 May 2023 20:48:21 +0000 Subject: [PATCH 016/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.179.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 28 ++++++++++++++++++++++++++++ RELEASE.md | 23 ----------------------- pyproject.toml | 2 +- 3 files changed, 29 insertions(+), 24 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a175e7dda..507739ad14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,34 @@ CHANGELOG ========= +0.179.0 - 2023-05-31 +-------------------- + +This PR allows passing metadata to Strawberry arguments. + +Example: + +```python +import strawberry + + +@strawberry.type +class Query: + @strawberry.field + def hello( + self, + info, + input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], + ) -> str: + argument_definition = info.get_argument_definition("input") + assert argument_definition.metadata["test"] == "foo" + + return f"Hi {input}" +``` + +Contributed by [Jonathan Kim](https://github.com/jkimbo) via [PR #2755](https://github.com/strawberry-graphql/strawberry/pull/2755/) + + 0.178.3 - 2023-05-31 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index a5759eb7ec..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,23 +0,0 @@ -Release type: minor - -This PR allows passing metadata to Strawberry arguments. - -Example: - -```python -import strawberry - - -@strawberry.type -class Query: - @strawberry.field - def hello( - self, - info, - input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], - ) -> str: - argument_definition = info.get_argument_definition("input") - assert argument_definition.metadata["test"] == "foo" - - return f"Hi {input}" -``` diff --git a/pyproject.toml b/pyproject.toml index 0e87bd6185..b2e8f16f2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.178.3" +version = "0.179.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From a46f9cb210921662727dcd7933b998855fcd1229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Dr=C3=B6ge?= Date: Wed, 31 May 2023 23:19:51 +0200 Subject: [PATCH 017/119] Update Channels integration (#2775) * Wip channels view integration * Wip channels integration * Allow to query with variables in the test channels client * Channels: Set correct status code when creating a response * Channels: Fix post() of test client * Channels: Use byte encoded headers and body * Channels: Handle json in post requests correctly when using the test client * Channels: Fix context and root value tests * Channels: Fix setting and sending headers * Channels: Fix test for custom process_result method * Channels: Fix graphiql tests * Channels: Add support for file upload and fix file upload tests * Channels: Fix websocket handlers * Channels: Fix type annotations * Channels: Replace Any annotations with correct types * Channels: Rename TestGraphQLHTTPConsumer to DebuggableGraphQLHTTPConsumer * Channels: Remove unused exceptions * Channels: Simplify response headers handling * Channels: Migrate SyncGraphQLHTTPConsumer * Channels: Cleanup tests and fix flakyness of SyncChannelsHttpClient * Channels: Fix types * Channels: Remove unused StrawberryChannelsContext * Channels: Simplify create_response * Channels: Remove redundant definitions in SyncGraphQLHTTPConsumer * Channels: Set request.consumer correctly and expose the consumer in info.context["ws"] (for ws connections) * Channels: Update file upload documentation * Channels: Use same context in SyncGraphQLHTTPConsumer and GraphQLHTTPConsumer * Channels: Update Channels integration documentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Channels: Ignore coverage of get_root_value from base class * Channels: Expose ChannelsConsumer, GraphQLHTTPConsumer, ChannelsWSConsumer in strawberry.channels * Channels: Update documentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Channels: Cleanup documentation * Channels: Cleanup test client * Update breaking changes --------- Co-authored-by: Patrick Arminio Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- RELEASE.md | 28 ++ docs/breaking-changes.md | 1 + docs/breaking-changes/0.189.0.md | 21 + docs/guides/file-upload.md | 3 +- docs/integrations/channels.md | 158 +++++++- strawberry/channels/__init__.py | 13 +- strawberry/channels/context.py | 22 -- strawberry/channels/handlers/base.py | 21 - strawberry/channels/handlers/http_handler.py | 389 +++++++++++-------- strawberry/channels/handlers/ws_handler.py | 15 +- tests/channels/test_http_handler.py | 296 -------------- tests/http/clients/__init__.py | 3 +- tests/http/clients/channels.py | 187 ++++++++- tests/http/conftest.py | 9 + tests/http/test_upload.py | 10 +- tests/views/schema.py | 7 +- 16 files changed, 635 insertions(+), 548 deletions(-) create mode 100644 RELEASE.md create mode 100644 docs/breaking-changes/0.189.0.md delete mode 100644 strawberry/channels/context.py delete mode 100644 tests/channels/test_http_handler.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..fd679d7867 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,28 @@ +Release type: minor + +This release updates the Django Channels integration so that it uses the same base +classes used by all other integrations. + +**New features:** + +The Django Channels integration supports two new features: + +* Setting headers in a response +* File uploads via `multipart/form-data` POST requests + +**Breaking changes:** + +This release contains a breaking change for the Channels integration. The context +object is now a `dict` and it contains different keys depending on the connection +protocol: + +1. HTTP: `request` and `response`. The `request` object contains the full + request (including the body). Previously, `request` was the `GraphQLHTTPConsumer` + instance of the current connection. The consumer is now available via + `request.consumer`. +2. WebSockets: `request`, `ws` and `response`. `request` and `ws` are the same + `GraphQLWSConsumer` instance of the current connection. + +If you want to use a dataclass for the context object (like in previous releases), +you can still use them by overriding the `get_context` methods. See the Channels +integration documentation for an example. diff --git a/docs/breaking-changes.md b/docs/breaking-changes.md index 8f25b83dda..87579fe6b0 100644 --- a/docs/breaking-changes.md +++ b/docs/breaking-changes.md @@ -4,6 +4,7 @@ title: List of breaking changes # List of breaking changes +- [Version 0.180.0 - 31 May 2023](./breaking-changes/0.180.0.md) - [Version 0.169.0 - 5 April 2023](./breaking-changes/0.169.0.md) - [Version 0.159.0 - 22 February 2023](./breaking-changes/0.159.0.md) - [Version 0.146.0 - 5 December 2022](./breaking-changes/0.146.0.md) diff --git a/docs/breaking-changes/0.189.0.md b/docs/breaking-changes/0.189.0.md new file mode 100644 index 0000000000..59af4f8aae --- /dev/null +++ b/docs/breaking-changes/0.189.0.md @@ -0,0 +1,21 @@ +--- +title: 0.180.0 Breaking cahnges +--- + +# v0.180.0 introduces a breaking change for the Django Channels HTTP integration + +The context object is now a `dict`. This means that you should access the context +value using the `["key"]` syntax instead of the `.key` syntax. + +For the HTTP integration, there is also no `ws` key anymore and `request` is a custom +request object containing the full request instead of a `GraphQLHTTPConsumer` instance. +If you need to access the `GraphQLHTTPConsumer` instance in a HTTP connection, you can +access it via `info.context["request"].consumer`. + +For the WebSockets integration, the context keys did not change, e.g. the values for +`info.context["ws"]`, `info.context["request"]` and `info.context["connection_params"]` +are the same as before. + +If you still want to use the `.key` syntax, you can override `get_context()` +to return a custom dataclass there. See the Channels integration documentation +for an example. diff --git a/docs/guides/file-upload.md b/docs/guides/file-upload.md index ed042d981c..20abe15699 100644 --- a/docs/guides/file-upload.md +++ b/docs/guides/file-upload.md @@ -15,6 +15,7 @@ The type passed at runtime depends on the integration: | ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- | | [AIOHTTP](/docs/integrations/aiohttp) | [`io.BytesIO`](https://docs.python.org/3/library/io.html#io.BytesIO) | | [ASGI](/docs/integrations/asgi) | [`starlette.datastructures.UploadFile`](https://www.starlette.io/requests/#request-files) | +| [Channels](/docs/integrations/channels) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | | [Django](/docs/integrations/django) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | | [FastAPI](/docs/integrations/fastapi) | [`fastapi.UploadFile`](https://fastapi.tiangolo.com/tutorial/request-files/#file-parameters-with-uploadfile) | | [Flask](/docs/integrations/flask) | [`werkzeug.datastructures.FileStorage`](https://werkzeug.palletsprojects.com/en/2.0.x/datastructures/#werkzeug.datastructures.FileStorage) | @@ -66,7 +67,7 @@ class Mutation: return contents ``` -## Sanic / Flask / Django / AIOHTTP +## Sanic / Flask / Django / Channels / AIOHTTP Example: diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index cf401a4beb..ab3aa1cf49 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -56,7 +56,7 @@ pip install 'strawberry-graphql[channels]' _The following example will pick up where the Channels tutorials left off._ -By the end of This tutorial, You will have a graphql chat subscription that will +By the end of this tutorial, You will have a graphql chat subscription that will be able to talk with the channels chat consumer from the tutorial. ### Types setup @@ -116,7 +116,7 @@ class Subscription: user: str, ) -> AsyncGenerator[ChatRoomMessage, None]: """Join and subscribe to message sent to the given rooms.""" - ws = info.context.ws + ws = info.context["ws"] channel_layer = ws.channel_layer room_ids = [f"chat_{room.room_name}" for room in rooms] @@ -145,7 +145,7 @@ class Subscription: ) ``` -Explanation: `Info.context.ws` or `Info.context.request` is a pointer to the +Explanation: `Info.context["ws"]` or `Info.context["request"]` is a pointer to the [`ChannelsConsumer`](#channelsconsumer) instance. Here we have first sent a message to all the channel_layer groups (specified in the subscription argument `rooms`) that we have joined the chat. @@ -191,7 +191,7 @@ class Mutation: room: ChatRoom, message: str, ) -> None: - ws = info.context.ws + ws = info.context["ws"] channel_layer = ws.channel_layer await channel_layer.group_send( @@ -443,6 +443,147 @@ def test_send_message_via_channels_chat_joinChatRooms_recieves(self): --- +The HTTP and WebSockets protocol are handled by different base classes. HTTP uses +`GraphQLHTTPConsumer` and WebSockets uses `GraphQLWSConsumer`. Both of them can +be extended: + +## GraphQLHTTPConsumer (HTTP) + +### Options + +`GraphQLHTTPConsumer` supports the same options as all other integrations: + +- `schema`: mandatory, the schema created by `strawberry.Schema`. +- `graphiql`: optional, defaults to `True`, whether to enable the GraphiQL + interface. +- `allow_queries_via_get`: optional, defaults to `True`, whether to enable + queries via `GET` requests +- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in + the GraphiQL interface, defaults to `True` + +### Extending the consumer + +We allow to extend `GraphQLHTTPConsumer`, by overriding the following methods: + +- `async def get_context(self, request: ChannelsRequest, response: TemporalResponse) -> Context` +- `async def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]` +- `async def process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse:`. + +### Context + +The default context returned by `get_context()` is a `dict` that includes the following keys by default: + +- `request`: A `ChannelsRequest` object with the following fields and methods: + - `consumer`: The `GraphQLHTTPConsumer` instance for this connection + - `body`: The request body + - `headers`: A dict containing the headers of the request + - `method`: The HTTP method of the request + - `content_type`: The content type of the request +- `response` A `TemporalResponse` object, that can be used to influence the HTTP response: + - `status_code`: The status code of the response, if there are no execution errors (defaults to `200`) + - `headers`: Any additional headers that should be send with the response + +## GraphQLWSConsumer (WebSockets / Subscriptions) + +### Options + +- `schema`: mandatory, the schema created by `strawberry.Schema`. +- `debug`: optional, defaults to `False`, whether to enable debug mode. +- `keep_alive`: optional, defaults to `False`, whether to enable keep alive mode + for websockets. +- `keep_alive_interval`: optional, defaults to `1`, the interval in seconds for + keep alive messages. + +### Extending the consumer + +We allow to extend `GraphQLWSConsumer`, by overriding the following methods: + +- `async def get_context(self, request: ChannelsConsumer, connection_params: Any) -> Context` +- `async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]` + +### Context + +The default context returned by `get_context()` is a `dict` and it includes the following keys by default: + +- `request`: The `GraphQLWSConsumer` instance of the current connection. It can be used to access the connection + scope, e.g. `info.context["ws"].headers` allows access to any headers. +- `ws`: The same as `request` +- `connection_params`: Any `connection_params`, see [Authenticating Subscriptions](/docs/general/subscriptions#authenticating-subscriptions) + +## Example for defining a custom context + +Here is an example for extending the base classes to offer a different context object in your resolvers. +For the HTTP integration, you can also have properties to access the current user and the +session. Both properties depend on the `AuthMiddlewareStack` wrapper. + +```python +from django.contrib.auth.models import AnonymousUser + +from strawberry.channels import ChannelsConsumer, ChannelsRequest +from strawberry.channels import GraphQLHTTPConsumer as BaseGraphQLHTTPConsumer +from strawberry.channels import GraphQLWSConsumer as BaseGraphQLWSConsumer +from strawberry.http.temporal_response import TemporalResponse + + +@dataclass +class ChannelsContext: + request: ChannelsRequest + response: TemporalResponse + + @property + def user(self): + # Depends on Channels' AuthMiddlewareStack + if "user" in self.request.consumer.scope: + return self.request.consumer.scope["user"] + + return AnonymousUser() + + @property + def session(self): + # Depends on Channels' SessionMiddleware / AuthMiddlewareStack + if "session" in self.request.consumer.scope: + return self.request.consumer.scope["session"] + + return None + + +@dataclass +class ChannelsWSContext: + request: ChannelsConsumer + connection_params: Optional[Dict[str, Any]] = None + + @property + def ws(self) -> ChannelsConsumer: + return self.request + + +class GraphQLHTTPConsumer(BaseGraphQLHTTPConsumer): + @override + async def get_context( + self, request: ChannelsRequest, response: TemporalResponse + ) -> ChannelsContext: + return ChannelsContext( + request=request, + response=response, + ) + + +class GraphQLWSConsumer(BaseGraphQLWSConsumer): + @override + async def get_context( + self, request: ChannelsConsumer, connection_params: Any + ) -> ChannelsWSContext: + return ChannelsWSContext( + request=request, + connection_params=connection_params, + ) +``` + +You can import and use the extended `GraphQLHTTPConsumer` and `GraphQLWSConsumer` classes in your +`myproject.asgi.py` file as shown before. + +--- + ## API ### GraphQLProtocolTypeRouter @@ -472,17 +613,12 @@ everything else to the Django application. ### ChannelsConsumer -Strawberries extended -[`AsyncConsumer`](https://channels.readthedocs.io/en/stable/topics/consumers.html#consumers). +Strawberries extended [`AsyncConsumer`](https://channels.readthedocs.io/en/stable/topics/consumers.html#consumers). -#### \*\*Every graphql session will have an instance of this class inside - -`info.ws` which is actually the `info.context.request`.\*\* +Every graphql session will have an instance of this class inside `info.context["ws"]` (WebSockets) or `info.context["request"].consumer` (HTTP). #### properties -- `ws.headers: dict` returns a map of the headers from `scope['headers']`. - ```python async def channel_listen( self, diff --git a/strawberry/channels/__init__.py b/strawberry/channels/__init__.py index bad0e48144..455513babb 100644 --- a/strawberry/channels/__init__.py +++ b/strawberry/channels/__init__.py @@ -1,15 +1,22 @@ -from .context import StrawberryChannelsContext +from .handlers.base import ChannelsConsumer, ChannelsWSConsumer from .handlers.graphql_transport_ws_handler import GraphQLTransportWSHandler from .handlers.graphql_ws_handler import GraphQLWSHandler -from .handlers.http_handler import GraphQLHTTPConsumer +from .handlers.http_handler import ( + ChannelsRequest, + GraphQLHTTPConsumer, + SyncGraphQLHTTPConsumer, +) from .handlers.ws_handler import GraphQLWSConsumer from .router import GraphQLProtocolTypeRouter __all__ = [ + "ChannelsConsumer", + "ChannelsRequest", + "ChannelsWSConsumer", "GraphQLProtocolTypeRouter", "GraphQLWSHandler", "GraphQLTransportWSHandler", "GraphQLHTTPConsumer", "GraphQLWSConsumer", - "StrawberryChannelsContext", + "SyncGraphQLHTTPConsumer", ] diff --git a/strawberry/channels/context.py b/strawberry/channels/context.py deleted file mode 100644 index 188f8b5647..0000000000 --- a/strawberry/channels/context.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional - -if TYPE_CHECKING: - from strawberry.channels.handlers.base import ChannelsConsumer - - -@dataclass -class StrawberryChannelsContext: - """ - A Channels context for GraphQL - """ - - request: "ChannelsConsumer" - connection_params: Optional[Dict[str, Any]] = None - - @property - def ws(self) -> "ChannelsConsumer": - return self.request - - def __getitem__(self, item: str) -> Any: - return getattr(self, item) diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index b21eaeb12b..457b1d8d36 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -7,7 +7,6 @@ Awaitable, Callable, DefaultDict, - Dict, List, Optional, Sequence, @@ -17,7 +16,6 @@ from channels.consumer import AsyncConsumer from channels.generic.websocket import AsyncJsonWebsocketConsumer -from strawberry.channels.context import StrawberryChannelsContext class ChannelsMessage(TypedDict, total=False): @@ -75,25 +73,6 @@ def __init__(self, *args: str, **kwargs: Any): ) super().__init__(*args, **kwargs) - @property - def headers(self) -> Dict[str, str]: - return { - header_name.decode().lower(): header_value.decode() - for header_name, header_value in self.scope["headers"] - } - - async def get_root_value(self, request: Optional["ChannelsConsumer"] = None) -> Any: - return None - - async def get_context( - self, - request: Optional["ChannelsConsumer"] = None, - connection_params: Optional[Dict[str, Any]] = None, - ) -> StrawberryChannelsContext: - return StrawberryChannelsContext( - request=request or self, connection_params=connection_params - ) - async def dispatch(self, message: ChannelsMessage) -> None: # AsyncConsumer will try to get a function for message["type"] to handle # for both http/websocket types and also for layers communication. diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 26e4a5bf02..b0b7477aa3 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -6,67 +6,143 @@ import dataclasses import json -from typing import TYPE_CHECKING, Any, Optional +from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from urllib.parse import parse_qs +from django.conf import settings +from django.core.files import uploadhandler +from django.http.multipartparser import MultiPartParser + from channels.db import database_sync_to_async from channels.generic.http import AsyncHttpConsumer -from strawberry.channels.context import StrawberryChannelsContext -from strawberry.exceptions import MissingQueryError -from strawberry.http import ( - parse_query_params, - parse_request_data, - process_result, -) -from strawberry.schema.exceptions import InvalidOperationTypeError -from strawberry.types.graphql import OperationType +from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter +from strawberry.http.exceptions import HTTPException +from strawberry.http.sync_base_view import SyncBaseHTTPView, SyncHTTPRequestAdapter +from strawberry.http.temporal_response import TemporalResponse +from strawberry.http.types import FormData +from strawberry.http.typevars import Context, RootValue +from strawberry.unset import UNSET +from strawberry.utils.cached_property import cached_property from strawberry.utils.graphiql import get_graphiql_html from .base import ChannelsConsumer if TYPE_CHECKING: - from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData + from strawberry.http import GraphQLHTTPResponse + from strawberry.http.types import HTTPMethod, QueryParams from strawberry.schema import BaseSchema - from strawberry.types import ExecutionResult -class MethodNotAllowed(Exception): - ... +@dataclasses.dataclass +class ChannelsResponse: + content: bytes + status: int = 200 + content_type: str = "application/json" + headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) -class ExecutionError(Exception): - ... +@dataclasses.dataclass +class ChannelsRequest: + consumer: ChannelsConsumer + body: bytes + + @property + def query_params(self) -> QueryParams: + query_params_str = self.consumer.scope["query_string"].decode() + + query_params = {} + for key, value in parse_qs(query_params_str, keep_blank_values=True).items(): + # Only one argument per key is expected here + query_params[key] = value[0] + + return query_params + + @property + def headers(self) -> Mapping[str, str]: + return { + header_name.decode().lower(): header_value.decode() + for header_name, header_value in self.consumer.scope["headers"] + } + + @property + def method(self) -> HTTPMethod: + return self.consumer.scope["method"].upper() + + @property + def content_type(self) -> Optional[str]: + return self.headers.get("content-type", None) + + @cached_property + def form_data(self) -> FormData: + upload_handlers = [ + uploadhandler.load_handler(handler) + for handler in settings.FILE_UPLOAD_HANDLERS + ] + + parser = MultiPartParser( + { + "CONTENT_TYPE": self.headers.get("content-type"), + "CONTENT_LENGTH": self.headers.get("content-length", "0"), + }, + BytesIO(self.body), + upload_handlers, + ) + querydict, files = parser.parse() -@dataclasses.dataclass -class Result: - response: bytes - status: int = 200 - content_type: str = "application/json" + form = { + "operations": querydict.get("operations", "{}"), + "map": querydict.get("map", "{}"), + } + return FormData(files=files, form=form) -class GraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer): - """A consumer to provide a view for GraphQL over HTTP. - To use this, place it in your ProtocolTypeRouter for your channels project: +class BaseChannelsRequestAdapter: + def __init__(self, request: ChannelsRequest): + self.request = request - ``` - from strawberry.channels import GraphQLHttpRouter - from channels.routing import ProtocolTypeRouter - from django.core.asgi import get_asgi_application + @property + def query_params(self) -> QueryParams: + return self.request.query_params + + @property + def method(self) -> HTTPMethod: + return self.request.method + + @property + def headers(self) -> Mapping[str, str]: + return self.request.headers + + @property + def content_type(self) -> Optional[str]: + return self.request.content_type + + +class ChannelsRequestAdapter(BaseChannelsRequestAdapter, AsyncHTTPRequestAdapter): + async def get_body(self) -> bytes: + return self.request.body + + async def get_form_data(self) -> FormData: + return self.request.form_data - application = ProtocolTypeRouter({ - "http": URLRouter([ - re_path("^graphql", GraphQLHTTPRouter(schema=schema)), - re_path("^", get_asgi_application()), - ]), - "websocket": URLRouter([ - re_path("^ws/graphql", GraphQLWebSocketRouter(schema=schema)), - ]), - }) - ``` - """ +class SyncChannelsRequestAdapter(BaseChannelsRequestAdapter, SyncHTTPRequestAdapter): + @property + def body(self) -> bytes: + return self.request.body + + @property + def post_data(self) -> Mapping[str, Union[str, bytes]]: + return self.request.form_data["form"] + + @property + def files(self) -> Mapping[str, Any]: + return self.request.form_data["files"] + + +class BaseGraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer): def __init__( self, schema: BaseSchema, @@ -81,156 +157,127 @@ def __init__( self.subscriptions_enabled = subscriptions_enabled super().__init__(**kwargs) - async def handle(self, body: bytes) -> None: - try: - if self.scope["method"] == "GET": - result = await self.get(body) - elif self.scope["method"] == "POST": - result = await self.post(body) - else: - raise MethodNotAllowed() - except MethodNotAllowed: - await self.send_response( - 405, - b"Method not allowed", - headers=[(b"Allow", b"GET, POST")], - ) - except InvalidOperationTypeError as e: - error_str = e.as_http_error_reason(self.scope["method"]) - await self.send_response( - 406, - error_str.encode(), - ) - except ExecutionError as e: - await self.send_response( - 500, - str(e).encode(), - ) - else: - await self.send_response( - result.status, - result.response, - headers=[(b"Content-Type", result.content_type.encode())], - ) - - async def get(self, body: bytes) -> Result: - if self.should_render_graphiql(): - return await self.render_graphiql(body) - elif self.scope.get("query_string"): - params = parse_query_params( - { - k: v[0] - for k, v in parse_qs(self.scope["query_string"].decode()).items() - } - ) - - try: - result = await self.execute(parse_request_data(params)) - except MissingQueryError as e: - raise ExecutionError("No GraphQL query found in the request") from e - - return Result(response=json.dumps(result).encode()) - else: - raise MethodNotAllowed() - - async def post(self, body: bytes) -> Result: - request_data = await self.parse_body(body) + def render_graphiql(self, request: ChannelsRequest) -> ChannelsResponse: + html = get_graphiql_html(self.subscriptions_enabled) + return ChannelsResponse(content=html.encode(), content_type="text/html") + + def create_response( + self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse + ) -> ChannelsResponse: + return ChannelsResponse( + content=json.dumps(response_data).encode(), + status=sub_response.status_code, + headers={k.encode(): v.encode() for k, v in sub_response.headers.items()}, + ) + async def handle(self, body: bytes) -> None: + request = ChannelsRequest(consumer=self, body=body) try: - result = await self.execute(request_data) - except MissingQueryError as e: - raise ExecutionError("No GraphQL query found in the request") from e - - return Result(response=json.dumps(result).encode()) - - async def parse_body(self, body: bytes) -> GraphQLRequestData: - if self.headers.get("content-type", "").startswith("multipart/form-data"): - return await self.parse_multipart_body(body) + response: ChannelsResponse = await self.run(request) - try: - data = json.loads(body) - except json.JSONDecodeError as e: - raise ExecutionError("Unable to parse request body as JSON") from e - - return parse_request_data(data) - - async def parse_multipart_body(self, body: bytes) -> GraphQLRequestData: - raise ExecutionError("Unable to parse the multipart body") - - async def execute(self, request_data: GraphQLRequestData) -> GraphQLHTTPResponse: - context = await self.get_context() - root_value = await self.get_root_value() - - method = self.scope["method"] - allowed_operation_types = OperationType.from_http(method) - if not self.allow_queries_via_get and method == "GET": - allowed_operation_types = allowed_operation_types - {OperationType.QUERY} - - result = await self.schema.execute( - query=request_data.query, - root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - allowed_operation_types=allowed_operation_types, - ) - return await self.process_result(result) + if b"Content-Type" not in response.headers: + response.headers[b"Content-Type"] = response.content_type.encode() - async def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse: - return process_result(result) + await self.send_response( + response.status, + response.content, + headers=response.headers, + ) + except HTTPException as e: + await self.send_response(e.status_code, e.reason.encode()) + + +class GraphQLHTTPConsumer( + BaseGraphQLHTTPConsumer, + AsyncBaseHTTPView[ + ChannelsRequest, + ChannelsResponse, + TemporalResponse, + Context, + RootValue, + ], +): + """A consumer to provide a view for GraphQL over HTTP. - async def render_graphiql(self, body: bytes) -> Result: - html = get_graphiql_html(self.subscriptions_enabled) - return Result(response=html.encode(), content_type="text/html") + To use this, place it in your ProtocolTypeRouter for your channels project: - def should_render_graphiql(self) -> bool: - accept_list = self.headers.get("accept", "").split(",") - return self.graphiql and any( - accepted in accept_list for accepted in ["text/html", "*/*"] - ) + ``` + from strawberry.channels import GraphQLHttpRouter + from channels.routing import ProtocolTypeRouter + from django.core.asgi import get_asgi_application + application = ProtocolTypeRouter({ + "http": URLRouter([ + re_path("^graphql", GraphQLHTTPRouter(schema=schema)), + re_path("^", get_asgi_application()), + ]), + "websocket": URLRouter([ + re_path("^ws/graphql", GraphQLWebSocketRouter(schema=schema)), + ]), + }) + ``` + """ -class SyncGraphQLHTTPConsumer(GraphQLHTTPConsumer): + allow_queries_via_get: bool = True + request_adapter_class = ChannelsRequestAdapter + + async def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]: + return None # pragma: no cover + + async def get_context( + self, request: ChannelsRequest, response: TemporalResponse + ) -> Context: + return { + "request": request, + "response": response, + } # type: ignore + + async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: + return TemporalResponse() + + +class SyncGraphQLHTTPConsumer( + BaseGraphQLHTTPConsumer, + SyncBaseHTTPView[ + ChannelsRequest, + ChannelsResponse, + TemporalResponse, + Context, + RootValue, + ], +): """Synchronous version of the HTTPConsumer. This is the same as `GraphQLHTTPConsumer`, but it can be used with - synchronous schemas (i.e. the schema's resolvers are espected to be + synchronous schemas (i.e. the schema's resolvers are expected to be synchronous and not asynchronous). """ - def get_root_value(self, request: Optional[ChannelsConsumer] = None) -> Any: - return None + allow_queries_via_get: bool = True + request_adapter_class = SyncChannelsRequestAdapter - def get_context( # type: ignore[override] - self, - request: Optional[ChannelsConsumer] = None, - ) -> StrawberryChannelsContext: - return StrawberryChannelsContext(request=request or self) + def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]: + return None # pragma: no cover + + def get_context( + self, request: ChannelsRequest, response: TemporalResponse + ) -> Context: + return { + "request": request, + "response": response, + } # type: ignore - def process_result( # type:ignore [override] - self, result: ExecutionResult - ) -> GraphQLHTTPResponse: - return process_result(result) + def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: + return TemporalResponse() # Sync channels is actually async, but it uses database_sync_to_async to call # handlers in a threadpool. Check SyncConsumer's documentation for more info: # https://github.com/django/channels/blob/main/channels/consumer.py#L104 @database_sync_to_async - def execute(self, request_data: GraphQLRequestData) -> GraphQLHTTPResponse: - context = self.get_context(self) - root_value = self.get_root_value(self) - - method = self.scope["method"] - allowed_operation_types = OperationType.from_http(method) - if not self.allow_queries_via_get and method == "GET": - allowed_operation_types = allowed_operation_types - {OperationType.QUERY} - - result = self.schema.execute_sync( - query=request_data.query, - root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - allowed_operation_types=allowed_operation_types, - ) - return self.process_result(result) + def run( + self, + request: ChannelsRequest, + context: Optional[Context] = UNSET, + root_value: Optional[RootValue] = UNSET, + ) -> ChannelsResponse: + return super().run(request, context, root_value) diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index d4fae0d22c..1adf21e5ea 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -5,11 +5,12 @@ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from .base import ChannelsWSConsumer +from .base import ChannelsConsumer, ChannelsWSConsumer from .graphql_transport_ws_handler import GraphQLTransportWSHandler from .graphql_ws_handler import GraphQLWSHandler if TYPE_CHECKING: + from strawberry.http.typevars import Context, RootValue from strawberry.schema import BaseSchema @@ -113,3 +114,15 @@ async def receive_json(self, content: Any, **kwargs: Any) -> None: async def disconnect(self, code: int) -> None: await self._handler.handle_disconnect(code) + + async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + return None + + async def get_context( + self, request: ChannelsConsumer, connection_params: Any + ) -> Context: + return { + "request": request, + "connection_params": connection_params, + "ws": request, + } # type: ignore diff --git a/tests/channels/test_http_handler.py b/tests/channels/test_http_handler.py deleted file mode 100644 index fee37b878f..0000000000 --- a/tests/channels/test_http_handler.py +++ /dev/null @@ -1,296 +0,0 @@ -import json -from typing import Any, Dict, Optional - -import pytest - -from channels.testing import HttpCommunicator -from strawberry.channels import GraphQLHTTPConsumer -from strawberry.channels.handlers.http_handler import SyncGraphQLHTTPConsumer -from tests.views.schema import schema - -pytestmark = pytest.mark.xfail( - reason=( - "Some of these tests seems to crash due to usage of database_sync_to_async" - ), -) - - -def generate_body(query: str, variables: Optional[Dict[str, Any]] = None) -> bytes: - body: Dict[str, Any] = {"query": query} - if variables is not None: - body["variables"] = variables - - return json.dumps(body).encode() - - -def generate_get_path( - path, query: str, variables: Optional[Dict[str, Any]] = None -) -> str: - body: Dict[str, Any] = {"query": query} - if variables is not None: - body["variables"] = json.dumps(variables) - - parts = [f"{k}={v}" for k, v in body.items()] - return f"{path}?{'&'.join(parts)}" - - -def assert_response( - response: Dict[str, Any], expected: Any, errors: Optional[Any] = None -): - assert response["status"] == 200 - body = json.loads(response["body"]) - assert "errors" not in body - assert body["data"] == expected - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphiql_view(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - "/graphql", - headers=[(b"accept", b"text/html")], - ) - response = await client.get_response() - assert response["headers"] == [(b"Content-Type", b"text/html")] - assert response["status"] == 200 - assert b"GraphiQL" in response["body"] - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphiql_view_disabled(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema, graphiql=False), - "GET", - "/graphql", - headers=[(b"accept", b"text/html")], - ) - response = await client.get_response() - assert response == { - "headers": [(b"Allow", b"GET, POST")], - "status": 405, - "body": b"Method not allowed", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphiql_view_not_allowed(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - "/graphql", - ) - response = await client.get_response() - assert response == { - "headers": [(b"Allow", b"GET, POST")], - "status": 405, - "body": b"Method not allowed", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -@pytest.mark.parametrize("method", ["DELETE", "HEAD", "PUT", "PATCH"]) -async def test_disabled_methods(consumer, method: str): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - method, - "/graphql", - headers=[(b"accept", b"text/html")], - ) - response = await client.get_response() - assert response == { - "headers": [(b"Allow", b"GET, POST")], - "status": 405, - "body": b"Method not allowed", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_fails_on_multipart_body(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=generate_body("{ hello }"), - headers=[(b"content-type", b"multipart/form-data")], - ) - response = await client.get_response() - assert response == { - "status": 500, - "headers": [], - "body": b"Unable to parse the multipart body", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -@pytest.mark.parametrize("body", [b"{}", b'{"foo": "bar"}']) -async def test_fails_on_missing_query(consumer, body: bytes): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=body, - ) - response = await client.get_response() - assert response == { - "status": 500, - "headers": [], - "body": b"No GraphQL query found in the request", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -@pytest.mark.parametrize("body", [b"", b"definitely-not-json-string"]) -async def test_fails_on_invalid_query(consumer, body: bytes): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=body, - ) - response = await client.get_response() - assert response == { - "status": 500, - "headers": [], - "body": b"Unable to parse request body as JSON", - } - - -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_post_query_fails_using_params(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - "/graphql?foo=bar", - ) - response = await client.get_response() - assert response == { - "status": 500, - "headers": [], - "body": b"No GraphQL query found in the request", - } - - -# FIXME: All the tests bellow runs fine if running tests in this file only, -# but fail for Sync when running the whole testsuite, unless using. -# @pytest.mark.django_db. Probably because of the `database_sync_to_async`? - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_query(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=generate_body("{ hello }"), - ) - assert_response( - await client.get_response(), - {"hello": "Hello world"}, - ) - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_can_pass_variables(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=generate_body( - "query Hello($name: String!) { hello(name: $name) }", - variables={"name": "James"}, - ), - ) - assert_response( - await client.get_response(), - {"hello": "Hello James"}, - ) - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_get_query_using_params(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - generate_get_path("/graphql", "{ hello }"), - ) - assert_response( - await client.get_response(), - {"hello": "Hello world"}, - ) - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_can_pass_variables_using_params(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - generate_get_path( - "/graphql", - "query Hello($name: String!) { hello(name: $name) }", - variables={"name": "James"}, - ), - ) - assert_response( - await client.get_response(), - {"hello": "Hello James"}, - ) - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_returns_errors_and_data(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "POST", - "/graphql", - body=generate_body("{ hello, alwaysFail }"), - ) - response = await client.get_response() - assert response["status"] == 200 - - body = json.loads(response["body"]) - assert body["data"] == {"alwaysFail": None, "hello": "Hello world"} - assert body["errors"] == [ - { - "locations": [{"column": 10, "line": 1}], - "message": "You are not authorized", - "path": ["alwaysFail"], - } - ] - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_get_does_not_allow_mutation(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema), - "GET", - generate_get_path("/graphql", "mutation { hello }"), - ) - response = await client.get_response() - assert response == { - "status": 406, - "headers": [], - "body": b"mutations are not allowed when using GET", - } - - -@pytest.mark.django_db -@pytest.mark.parametrize("consumer", [GraphQLHTTPConsumer, SyncGraphQLHTTPConsumer]) -async def test_graphql_get_not_allowed(consumer): - client = HttpCommunicator( - consumer.as_asgi(schema=schema, allow_queries_via_get=False), - "GET", - generate_get_path("/graphql", "query { hello }"), - ) - response = await client.get_response() - assert response == { - "status": 406, - "headers": [], - "body": b"queries are not allowed when using GET", - } diff --git a/tests/http/clients/__init__.py b/tests/http/clients/__init__.py index 3f0cd8669e..cf808da33f 100644 --- a/tests/http/clients/__init__.py +++ b/tests/http/clients/__init__.py @@ -10,7 +10,7 @@ from .async_flask import AsyncFlaskHttpClient from .base import HttpClient, WebSocketClient from .chalice import ChaliceHttpClient -from .channels import ChannelsHttpClient +from .channels import ChannelsHttpClient, SyncChannelsHttpClient from .django import DjangoHttpClient from .fastapi import FastAPIHttpClient from .flask import FlaskHttpClient @@ -34,5 +34,6 @@ "HttpClient", "SanicHttpClient", "StarliteHttpClient", + "SyncChannelsHttpClient", "WebSocketClient", ] diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 54925dbde1..7a554c7298 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -1,14 +1,23 @@ from __future__ import annotations import contextlib -import json +import json as json_module from io import BytesIO from typing import Any, AsyncGenerator, Dict, List, Optional from typing_extensions import Literal -from channels.testing import WebsocketCommunicator -from strawberry.channels import GraphQLWSConsumer -from tests.views.schema import schema +from urllib3 import encode_multipart_formdata + +from channels.testing import HttpCommunicator, WebsocketCommunicator +from strawberry.channels import ( + GraphQLHTTPConsumer, + GraphQLWSConsumer, + SyncGraphQLHTTPConsumer, +) +from strawberry.channels.handlers.base import ChannelsConsumer +from strawberry.http import GraphQLHTTPResponse +from strawberry.http.typevars import Context, RootValue +from tests.views.schema import Query, schema from ..context import get_context from .base import ( @@ -21,18 +30,99 @@ ) +def generate_get_path( + path, query: str, variables: Optional[Dict[str, Any]] = None +) -> str: + body: Dict[str, Any] = {"query": query} + if variables is not None: + body["variables"] = json_module.dumps(variables) + + parts = [f"{k}={v}" for k, v in body.items()] + return f"{path}?{'&'.join(parts)}" + + +def create_multipart_request_body( + body: Dict[str, object], files: Dict[str, BytesIO] +) -> tuple[list[tuple[str, str]], bytes]: + fields = { + "operations": body["operations"], + "map": body["map"], + } + + for filename, data in files.items(): + fields[filename] = (filename, data.read().decode(), "text/plain") + + request_body, content_type_header = encode_multipart_formdata(fields) + + headers = [ + ("Content-Type", content_type_header), + ("Content-Length", f"{len(request_body)}"), + ] + + return headers, request_body + + class DebuggableGraphQLTransportWSConsumer(GraphQLWSConsumer): async def get_context(self, *args: str, **kwargs: Any) -> object: context = await super().get_context(*args, **kwargs) - context.tasks = self._handler.tasks - context.connectionInitTimeoutTask = getattr( + context["ws"] = self._handler._ws + context["tasks"] = self._handler.tasks + context["connectionInitTimeoutTask"] = getattr( self._handler, "connection_init_timeout_task", None ) for key, val in get_context({}).items(): - setattr(context, key, val) + context[key] = val return context +class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer): + result_override: ResultOverrideFunction = None + + def __init__(self, *args: Any, **kwargs: Any): + self.result_override = kwargs.pop("result_override") + super().__init__(*args, **kwargs) + + async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + return Query() + + async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: + context = await super().get_context(request, response) + + return get_context(context) + + async def process_result( + self, request: ChannelsConsumer, result: Any + ) -> GraphQLHTTPResponse: + if self.result_override: + return self.result_override(result) + + return await super().process_result(request, result) + + +class DebuggableSyncGraphQLHTTPConsumer(SyncGraphQLHTTPConsumer): + result_override: ResultOverrideFunction = None + + def __init__(self, *args: Any, **kwargs: Any): + self.result_override = kwargs.pop("result_override") + super().__init__(*args, **kwargs) + + def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + return Query() + + def get_context(self, request: ChannelsConsumer, response: Any) -> Context: + context = super().get_context(request, response) + + return get_context(context) + + def process_result( + self, request: ChannelsConsumer, result: Any + ) -> GraphQLHTTPResponse: + if self.result_override: + return self.result_override(result) + + return super().process_result(request, result) + + class ChannelsHttpClient(HttpClient): """ A client to test websockets over channels @@ -44,13 +134,22 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, ): - self.app = DebuggableGraphQLTransportWSConsumer.as_asgi( + self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( schema=schema, keep_alive=False, ) + self.http_app = DebuggableGraphQLHTTPConsumer.as_asgi( + schema=schema, + graphiql=graphiql, + allow_queries_via_get=allow_queries_via_get, + result_override=result_override, + ) + def create_app(self, **kwargs: Any) -> None: - self.app = DebuggableGraphQLTransportWSConsumer.as_asgi(schema=schema, **kwargs) + self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( + schema=schema, **kwargs + ) async def _graphql_request( self, @@ -61,22 +160,60 @@ async def _graphql_request( headers: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> Response: - raise NotImplementedError + body = self._build_body( + query=query, variables=variables, files=files, method=method + ) + + headers = self._get_headers(method=method, headers=headers, files=files) + + if method == "post": + if files: + new_headers, body = create_multipart_request_body(body, files) + for k, v in new_headers: + headers[k] = v + else: + body = json_module.dumps(body).encode() + endpoint_url = "/graphql" + else: + body = b"" + endpoint_url = generate_get_path("/graphql", query, variables) + + return await self.request( + url=endpoint_url, method=method, body=body, headers=headers + ) async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], + body: bytes = b"", headers: Optional[Dict[str, str]] = None, ) -> Response: - raise NotImplementedError + # HttpCommunicator expects tuples of bytestrings + if headers: + headers = [(k.encode(), v.encode()) for k, v in headers.items()] + + communicator = HttpCommunicator( + self.http_app, + method.upper(), + url, + body=body, + headers=headers, + ) + response = await communicator.get_response() + + return Response( + status_code=response["status"], + data=response["body"], + headers={k.decode(): v.decode() for k, v in response["headers"]}, + ) async def get( self, url: str, headers: Optional[Dict[str, str]] = None, ) -> Response: - raise NotImplementedError + return await self.request(url, "get", headers=headers) async def post( self, @@ -85,7 +222,12 @@ async def post( json: Optional[JSON] = None, headers: Optional[Dict[str, str]] = None, ) -> Response: - raise NotImplementedError + body = b"" + if data is not None: + body = data + elif json is not None: + body = json_module.dumps(json).encode() + return await self.request(url, "post", body=body, headers=headers) @contextlib.asynccontextmanager async def ws_connect( @@ -94,7 +236,7 @@ async def ws_connect( *, protocols: List[str], ) -> AsyncGenerator[WebSocketClient, None]: - client = WebsocketCommunicator(self.app, url, subprotocols=protocols) + client = WebsocketCommunicator(self.ws_app, url, subprotocols=protocols) res = await client.connect() assert res == (True, protocols[0]) @@ -104,6 +246,21 @@ async def ws_connect( await client.disconnect() +class SyncChannelsHttpClient(ChannelsHttpClient): + def __init__( + self, + graphiql: bool = True, + allow_queries_via_get: bool = True, + result_override: ResultOverrideFunction = None, + ): + self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( + schema=schema, + graphiql=graphiql, + allow_queries_via_get=allow_queries_via_get, + result_override=result_override, + ) + + class ChannelsWebSocketClient(WebSocketClient): def __init__(self, client: WebsocketCommunicator): self.ws = client @@ -135,7 +292,7 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: m = await self.ws.receive_output(timeout=timeout) assert m["type"] == "websocket.send" assert "text" in m - return json.loads(m["text"]) + return json_module.loads(m["text"]) async def close(self) -> None: await self.ws.disconnect() diff --git a/tests/http/conftest.py b/tests/http/conftest.py index a39232a6f6..ae98ad73ab 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -9,12 +9,14 @@ AsyncDjangoHttpClient, AsyncFlaskHttpClient, ChaliceHttpClient, + ChannelsHttpClient, DjangoHttpClient, FastAPIHttpClient, FlaskHttpClient, HttpClient, SanicHttpClient, StarliteHttpClient, + SyncChannelsHttpClient, ) @@ -29,6 +31,13 @@ pytest.param(FastAPIHttpClient, marks=pytest.mark.fastapi), pytest.param(FlaskHttpClient, marks=pytest.mark.flask), pytest.param(SanicHttpClient, marks=pytest.mark.sanic), + pytest.param(ChannelsHttpClient, marks=pytest.mark.channels), + pytest.param( + # SyncChannelsHttpClient uses @database_sync_to_async and therefore + # needs pytest.mark.django_db + SyncChannelsHttpClient, + marks=[pytest.mark.channels, pytest.mark.django_db], + ), pytest.param( StarliteHttpClient, marks=[ diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py index 7887999a80..8bc9532925 100644 --- a/tests/http/test_upload.py +++ b/tests/http/test_upload.py @@ -175,7 +175,10 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): response = await http_client.post( url="/graphql", data=buffer.value, - headers={"content-type": writer.content_type}, + headers={ + "content-type": writer.content_type, + "content-length": f"{len(buffer.value)}", + }, ) assert response.status_code == 200 @@ -213,7 +216,10 @@ async def test_sending_invalid_json_body(http_client: HttpClient): response = await http_client.post( "/graphql", data=buffer.value, - headers={"content-type": writer.content_type}, + headers={ + "content-type": writer.content_type, + "content-length": f"{len(buffer.value)}", + }, ) assert response.status_code == 400 diff --git a/tests/views/schema.py b/tests/views/schema.py index a9e44c25a0..9b1cefab87 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -6,7 +6,6 @@ from graphql import GraphQLError import strawberry -from strawberry.channels.context import StrawberryChannelsContext from strawberry.extensions import SchemaExtension from strawberry.file_uploads import Upload from strawberry.permission import BasePermission @@ -205,13 +204,13 @@ async def debug(self, info: Info[Any, Any]) -> AsyncGenerator[DebugInfo, None]: @strawberry.subscription async def listener( self, - info: Info[StrawberryChannelsContext, Any], + info: Info[Any, Any], timeout: Optional[float] = None, group: Optional[str] = None, ) -> AsyncGenerator[str, None]: - yield info.context.request.channel_name + yield info.context["request"].channel_name - async for message in info.context.request.channel_listen( + async for message in info.context["request"].channel_listen( type="test.message", timeout=timeout, groups=[group] if group is not None else [], From c9a1f20c8fd3b398f9bec970778a433ced1af810 Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 31 May 2023 21:20:48 +0000 Subject: [PATCH 018/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.180.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 33 +++++++++++++++++++++++++++++++++ RELEASE.md | 28 ---------------------------- pyproject.toml | 2 +- 3 files changed, 34 insertions(+), 29 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 507739ad14..e9436584e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,39 @@ CHANGELOG ========= +0.180.0 - 2023-05-31 +-------------------- + +This release updates the Django Channels integration so that it uses the same base +classes used by all other integrations. + +**New features:** + +The Django Channels integration supports two new features: + +* Setting headers in a response +* File uploads via `multipart/form-data` POST requests + +**Breaking changes:** + +This release contains a breaking change for the Channels integration. The context +object is now a `dict` and it contains different keys depending on the connection +protocol: + +1. HTTP: `request` and `response`. The `request` object contains the full + request (including the body). Previously, `request` was the `GraphQLHTTPConsumer` + instance of the current connection. The consumer is now available via + `request.consumer`. +2. WebSockets: `request`, `ws` and `response`. `request` and `ws` are the same + `GraphQLWSConsumer` instance of the current connection. + +If you want to use a dataclass for the context object (like in previous releases), +you can still use them by overriding the `get_context` methods. See the Channels +integration documentation for an example. + +Contributed by [Christian Dröge](https://github.com/cdroege) via [PR #2775](https://github.com/strawberry-graphql/strawberry/pull/2775/) + + 0.179.0 - 2023-05-31 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index fd679d7867..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,28 +0,0 @@ -Release type: minor - -This release updates the Django Channels integration so that it uses the same base -classes used by all other integrations. - -**New features:** - -The Django Channels integration supports two new features: - -* Setting headers in a response -* File uploads via `multipart/form-data` POST requests - -**Breaking changes:** - -This release contains a breaking change for the Channels integration. The context -object is now a `dict` and it contains different keys depending on the connection -protocol: - -1. HTTP: `request` and `response`. The `request` object contains the full - request (including the body). Previously, `request` was the `GraphQLHTTPConsumer` - instance of the current connection. The consumer is now available via - `request.consumer`. -2. WebSockets: `request`, `ws` and `response`. `request` and `ws` are the same - `GraphQLWSConsumer` instance of the current connection. - -If you want to use a dataclass for the context object (like in previous releases), -you can still use them by overriding the `get_context` methods. See the Channels -integration documentation for an example. diff --git a/pyproject.toml b/pyproject.toml index b2e8f16f2c..9a0dac76b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.179.0" +version = "0.180.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 99010b4da50b6899fa9a3c621dbc5af3e9317ed0 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 31 May 2023 23:24:23 +0200 Subject: [PATCH 019/119] Lint --- tests/schema/extensions/test_field_extensions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index e16ae28f68..1b494c4fff 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -298,7 +298,9 @@ def test_extension_access_argument_metadata(): argument_metadata = {} class CustomExtension(FieldExtension): - def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs): + def resolve( + self, next_: Callable[..., Any], source: Any, info: Info, **kwargs: Any + ): nonlocal field_kwargs field_kwargs = kwargs From 3ce7feaf9256c49bcf0cf3c23e8d0f2008aff623 Mon Sep 17 00:00:00 2001 From: Jaime Coello de Portugal Date: Thu, 1 Jun 2023 11:25:30 +0200 Subject: [PATCH 020/119] Make StrawberryAnnotation hashable. (#2790) * Made StrawberryAnnotation hashable. * Added RELEASE.md * Added test and fixed more hash functions. * Use pytest parametrization --------- Co-authored-by: Patrick Arminio --- RELEASE.md | 3 ++ strawberry/annotation.py | 3 ++ strawberry/type.py | 3 ++ strawberry/union.py | 3 +- tests/types/test_annotation.py | 50 ++++++++++++++++++++++++++++++++++ 5 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/types/test_annotation.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..30f0dd2f2d --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Make StrawberryAnnotation hashable, to make it compatible to newer versions of dacite. diff --git a/strawberry/annotation.py b/strawberry/annotation.py index 17fe0fed37..1298d8c29d 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -61,6 +61,9 @@ def __eq__(self, other: object) -> bool: return self.resolve() == other.resolve() + def __hash__(self) -> int: + return hash(self.resolve()) + @staticmethod def from_annotation( annotation: object, namespace: Optional[Dict] = None diff --git a/strawberry/type.py b/strawberry/type.py index e546aa5f7b..c2f50273de 100644 --- a/strawberry/type.py +++ b/strawberry/type.py @@ -148,3 +148,6 @@ def __eq__(self, other: object) -> bool: return self.type_var == other return super().__eq__(other) + + def __hash__(self): + return hash(self.type_var) diff --git a/strawberry/union.py b/strawberry/union.py index e339f9166f..5de5590657 100644 --- a/strawberry/union.py +++ b/strawberry/union.py @@ -69,8 +69,7 @@ def __eq__(self, other: object) -> bool: return super().__eq__(other) def __hash__(self) -> int: - # TODO: Is this a bad idea? __eq__ objects are supposed to have the same hash - return id(self) + return hash((self.graphql_name, self.type_annotations, self.description)) def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: if other is None: diff --git a/tests/types/test_annotation.py b/tests/types/test_annotation.py new file mode 100644 index 0000000000..1bdca8e319 --- /dev/null +++ b/tests/types/test_annotation.py @@ -0,0 +1,50 @@ +import itertools +from enum import Enum +from typing import Optional, TypeVar, Union + +import pytest + +import strawberry +from strawberry.annotation import StrawberryAnnotation +from strawberry.unset import UnsetType + + +class Bleh: + pass + + +@strawberry.enum +class NumaNuma(Enum): + MA = "ma" + I = "i" # noqa: E741 + A = "a" + HI = "hi" + + +T = TypeVar("T") + +types = [ + int, + str, + None, + Optional[str], + UnsetType, + Union[int, str], + "int", + T, + Bleh, + NumaNuma, +] + + +@pytest.mark.parametrize( + ("type1", "type2"), itertools.combinations_with_replacement(types, 2) +) +def test_annotation_hash(type1: Union[object, str], type2: Union[object, str]): + annotation1 = StrawberryAnnotation(type1) + annotation2 = StrawberryAnnotation(type2) + assert ( + hash(annotation1) == hash(annotation2) + if annotation1 == annotation2 + else hash(annotation1) != hash(annotation2) + ), "Equal type must imply equal hash" From 1151f25ba669f70365553c70038b62216ce8b99c Mon Sep 17 00:00:00 2001 From: Botberry Date: Thu, 1 Jun 2023 09:26:48 +0000 Subject: [PATCH 021/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.180.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 8 ++++++++ RELEASE.md | 3 --- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index e9436584e1..685f7ee033 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ CHANGELOG ========= +0.180.1 - 2023-06-01 +-------------------- + +Make StrawberryAnnotation hashable, to make it compatible to newer versions of dacite. + +Contributed by [Jaime Coello de Portugal](https://github.com/jaimecp89) via [PR #2790](https://github.com/strawberry-graphql/strawberry/pull/2790/) + + 0.180.0 - 2023-05-31 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 30f0dd2f2d..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,3 +0,0 @@ -Release type: patch - -Make StrawberryAnnotation hashable, to make it compatible to newer versions of dacite. diff --git a/pyproject.toml b/pyproject.toml index 9a0dac76b2..11fd827834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.180.0" +version = "0.180.1" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From ca481a7d9a87b0b3ad642cdc9a4c6921fb9bda71 Mon Sep 17 00:00:00 2001 From: Matt Gilson Date: Fri, 2 Jun 2023 02:52:20 -0400 Subject: [PATCH 022/119] Allow the usage of fragment directives in codegen. (#2802) * Allow the usage of fragment directives in codegen. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Matt Gilson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- RELEASE.md | 39 +++++++++++++ strawberry/codegen/plugins/print_operation.py | 39 +++++++++++++ strawberry/codegen/query_codegen.py | 58 ++++++++++++++++++- strawberry/codegen/types.py | 29 +++++++++- tests/codegen/queries/fragment.graphql | 18 ++++++ .../codegen/queries/mutation-fragment.graphql | 9 +++ tests/codegen/snapshots/python/fragment.py | 18 ++++++ .../snapshots/python/mutation-fragment.py | 8 +++ .../codegen/snapshots/typescript/fragment.ts | 16 +++++ .../snapshots/typescript/mutation-fragment.ts | 11 ++++ 10 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/codegen/queries/fragment.graphql create mode 100644 tests/codegen/queries/mutation-fragment.graphql create mode 100644 tests/codegen/snapshots/python/fragment.py create mode 100644 tests/codegen/snapshots/python/mutation-fragment.py create mode 100644 tests/codegen/snapshots/typescript/fragment.ts create mode 100644 tests/codegen/snapshots/typescript/mutation-fragment.ts diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..963dd5b8a4 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,39 @@ +Release type: patch + +In this release codegen no longer chokes on queries that use a fragment. + +There is one significant limitation at the present. When a fragment is included via the spread operator in an object, it must be the only field present. Attempts to include more fields will result in a ``ValueError``. + +However, there are some real benefits. When a fragment is included in multiple places in the query, only a single class will be made to represent that fragment: + +``` +fragment Point on Bar { + id + x + y +} + +query GetPoints { + circlePoints { + ...Point + } + squarePoints { + ...Point + } +} +``` + +Might generate the following types + +```py +class Point: + id: str + x: float + y: float + +class GetPointsResult: + circle_points: List[Point] + square_points: List[Point] +``` + +The previous behavior would generate duplicate classes for for the `GetPointsCirclePoints` and `GetPointsSquarePoints` even though they are really identical classes. diff --git a/strawberry/codegen/plugins/print_operation.py b/strawberry/codegen/plugins/print_operation.py index d4387ba171..fbe795f858 100644 --- a/strawberry/codegen/plugins/print_operation.py +++ b/strawberry/codegen/plugins/print_operation.py @@ -7,11 +7,15 @@ from strawberry.codegen.types import ( GraphQLBoolValue, GraphQLEnumValue, + GraphQLField, GraphQLFieldSelection, + GraphQLFragmentSpread, + GraphQLFragmentType, GraphQLInlineFragment, GraphQLIntValue, GraphQLList, GraphQLListValue, + GraphQLObjectType, GraphQLOptional, GraphQLStringValue, GraphQLVariableReference, @@ -32,8 +36,15 @@ class PrintOperationPlugin(QueryCodegenPlugin): def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation ) -> List[CodegenFile]: + code_lines = [] + for t in types: + if not isinstance(t, GraphQLFragmentType): + continue + code_lines.append(self._print_fragment(t)) + code = "\n".join( [ + *code_lines, ( f"{operation.kind} {operation.name}" f"{self._print_operation_variables(operation)}" @@ -45,6 +56,28 @@ def generate_code( ) return [CodegenFile("query.graphql", code)] + def _print_fragment_field(self, field: GraphQLField, indent: str = "") -> str: + code_lines = [] + if isinstance(field.type, GraphQLObjectType): + code_lines.append(f"{indent}{field.name} {{") + for subfield in field.type.fields: + code_lines.append( + self._print_fragment_field(subfield, indent=indent + " ") + ) + code_lines.append(f"{indent}}}") + else: + code_lines.append(f"{indent}{field.name}") + return "\n".join(code_lines) + + def _print_fragment(self, fragment: GraphQLFragmentType) -> str: + code_lines = [] + code_lines.append(f"fragment {fragment.name} on {fragment.on} {{") + for field in fragment.fields: + code_lines.append(self._print_fragment_field(field, indent=" ")) + code_lines.append("}") + code_lines.append("") + return "\n".join(code_lines) + def _print_operation_variables(self, operation: GraphQLOperation) -> str: if not operation.variables: return "" @@ -143,6 +176,9 @@ def _print_inline_fragment(self, fragment: GraphQLInlineFragment) -> str: ] ) + def _print_fragment_spread(self, fragment: GraphQLFragmentSpread) -> str: + return f"...{fragment.name}" + def _print_selection(self, selection: GraphQLSelection) -> str: if isinstance(selection, GraphQLFieldSelection): return self._print_field_selection(selection) @@ -150,6 +186,9 @@ def _print_selection(self, selection: GraphQLSelection) -> str: if isinstance(selection, GraphQLInlineFragment): return self._print_inline_fragment(selection) + if isinstance(selection, GraphQLFragmentSpread): + return self._print_fragment_spread(selection) + raise ValueError(f"Unsupported selection: {selection}") # pragma: no cover def _print_selections(self, selections: List[GraphQLSelection]) -> str: diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index f84207cd43..aa380b7db2 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -18,6 +18,8 @@ BooleanValueNode, EnumValueNode, FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, InlineFragmentNode, IntValueNode, ListTypeNode, @@ -51,6 +53,8 @@ GraphQLEnumValue, GraphQLField, GraphQLFieldSelection, + GraphQLFragmentSpread, + GraphQLFragmentType, GraphQLInlineFragment, GraphQLIntValue, GraphQLList, @@ -169,6 +173,10 @@ def run(self, query: str) -> CodegenResult: if operation.name is None: raise NoOperationNameProvidedError() + # Look for any free-floating fragments and create types out of them + # These types can then be referenced and included later via the + # fragment spread operator. + self._populate_fragment_types(ast) self.operation = self._convert_operation(operation) result = self.generate_code() @@ -182,6 +190,26 @@ def _collect_type(self, type_: GraphQLType) -> None: self.types.append(type_) + def _populate_fragment_types(self, ast: DocumentNode) -> None: + fragment_definitions = ( + definition + for definition in ast.definitions + if isinstance(definition, FragmentDefinitionNode) + ) + for fd in fragment_definitions: + query_type = self.schema.get_type_by_name("Query") + assert isinstance(query_type, TypeDefinition) + self._collect_types( + # The FragmentDefinitionNode has a non-Optional `SelectionSetNode` but the Protocol + # wants an `Optional[SelectionSetNode]` so this doesn't quite conform. + cast(HasSelectionSet, fd), + parent_type=query_type, + class_name=fd.name.value, + graph_ql_object_type_factory=lambda name: GraphQLFragmentType( + name, on=fd.type_condition.name.value + ), + ) + def _convert_selection(self, selection: SelectionNode) -> GraphQLSelection: if isinstance(selection, FieldNode): return GraphQLFieldSelection( @@ -198,6 +226,9 @@ def _convert_selection(self, selection: SelectionNode) -> GraphQLSelection: self._convert_selection_set(selection.selection_set), ) + if isinstance(selection, FragmentSpreadNode): + return GraphQLFragmentSpread(selection.name.value) + raise ValueError(f"Unsupported type: {type(selection)}") # pragma: no cover def _convert_selection_set( @@ -525,6 +556,9 @@ def _collect_types( selection: HasSelectionSet, parent_type: TypeDefinition, class_name: str, + graph_ql_object_type_factory: Callable[ + [str], GraphQLObjectType + ] = GraphQLObjectType, ) -> GraphQLType: assert selection.selection_set is not None selection_set = selection.selection_set @@ -537,14 +571,32 @@ def _collect_types( selection, parent_type, class_name ) - current_type = GraphQLObjectType(class_name, []) + current_type = graph_ql_object_type_factory(class_name) + fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = [] for sub_selection in selection_set.selections: + if isinstance(sub_selection, FragmentSpreadNode): + fields.append(GraphQLFragmentSpread(sub_selection.name.value)) + continue assert isinstance(sub_selection, FieldNode) - field = self._get_field(sub_selection, class_name, parent_type) - current_type.fields.append(field) + fields.append(field) + + if any(isinstance(f, GraphQLFragmentSpread) for f in fields): + if len(fields) > 1: + raise ValueError( + "Queries with Fragments cannot currently include separate fields." + ) + spread_field = fields[0] + assert isinstance(spread_field, GraphQLFragmentSpread) + return next( + t + for t in self.types + if isinstance(t, GraphQLObjectType) and t.name == spread_field.name + ) + + current_type.fields = cast(List[GraphQLField], fields) self._collect_type(current_type) diff --git a/strawberry/codegen/types.py b/strawberry/codegen/types.py index 5e28daf247..3c9ddeec06 100644 --- a/strawberry/codegen/types.py +++ b/strawberry/codegen/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, List, Optional, Type, Union if TYPE_CHECKING: @@ -31,10 +31,31 @@ class GraphQLField: type: GraphQLType +@dataclass +class GraphQLFragmentSpread: + name: str + + @dataclass class GraphQLObjectType: name: str - fields: List[GraphQLField] + fields: List[GraphQLField] = field(default_factory=list) + + +# Subtype of GraphQLObjectType. +# Because dataclass inheritance is a little odd, the fields are +# repeated here. +@dataclass +class GraphQLFragmentType(GraphQLObjectType): + name: str + fields: List[GraphQLField] = field(default_factory=list) + on: str = "" + + def __post_init__(self) -> None: + if not self.on: + raise ValueError( + "GraphQLFragmentType must be constructed with a valid 'on'" + ) @dataclass @@ -75,7 +96,9 @@ class GraphQLInlineFragment: selections: List[GraphQLSelection] -GraphQLSelection = Union[GraphQLFieldSelection, GraphQLInlineFragment] +GraphQLSelection = Union[ + GraphQLFieldSelection, GraphQLInlineFragment, GraphQLFragmentSpread +] @dataclass diff --git a/tests/codegen/queries/fragment.graphql b/tests/codegen/queries/fragment.graphql new file mode 100644 index 0000000000..daea9e48a5 --- /dev/null +++ b/tests/codegen/queries/fragment.graphql @@ -0,0 +1,18 @@ +fragment Fields on Query { + id + integer + float + boolean + uuid + date + datetime + time + decimal + lazy { + something + } +} + +query OperationName { + ...Fields +} diff --git a/tests/codegen/queries/mutation-fragment.graphql b/tests/codegen/queries/mutation-fragment.graphql new file mode 100644 index 0000000000..b39782c70e --- /dev/null +++ b/tests/codegen/queries/mutation-fragment.graphql @@ -0,0 +1,9 @@ +fragment IdFragment on BlogPost { + id +} + +mutation addBook($input: String!) { + addBook(input: $input) { + ...IdFragment + } +} diff --git a/tests/codegen/snapshots/python/fragment.py b/tests/codegen/snapshots/python/fragment.py new file mode 100644 index 0000000000..9b1487942a --- /dev/null +++ b/tests/codegen/snapshots/python/fragment.py @@ -0,0 +1,18 @@ +from uuid import UUID +from datetime import date, datetime, time +from decimal import Decimal + +class FieldsLazy: + something: bool + +class Fields: + id: str + integer: int + float: float + boolean: bool + uuid: UUID + date: date + datetime: datetime + time: time + decimal: Decimal + lazy: FieldsLazy diff --git a/tests/codegen/snapshots/python/mutation-fragment.py b/tests/codegen/snapshots/python/mutation-fragment.py new file mode 100644 index 0000000000..17ad048a19 --- /dev/null +++ b/tests/codegen/snapshots/python/mutation-fragment.py @@ -0,0 +1,8 @@ +class IdFragment: + id: str + +class addBookResult: + add_book: IdFragment + +class addBookVariables: + input: str diff --git a/tests/codegen/snapshots/typescript/fragment.ts b/tests/codegen/snapshots/typescript/fragment.ts new file mode 100644 index 0000000000..b8d4a67e03 --- /dev/null +++ b/tests/codegen/snapshots/typescript/fragment.ts @@ -0,0 +1,16 @@ +type FieldsLazy = { + something: boolean +} + +type Fields = { + id: string + integer: number + float: number + boolean: boolean + uuid: string + date: string + datetime: string + time: string + decimal: string + lazy: FieldsLazy +} diff --git a/tests/codegen/snapshots/typescript/mutation-fragment.ts b/tests/codegen/snapshots/typescript/mutation-fragment.ts new file mode 100644 index 0000000000..261d48108a --- /dev/null +++ b/tests/codegen/snapshots/typescript/mutation-fragment.ts @@ -0,0 +1,11 @@ +type IdFragment = { + id: string +} + +type addBookResult = { + add_book: IdFragment +} + +type addBookVariables = { + input: string +} From fa59442dc5ed01080378c6d3e7ce124dda1dc1ad Mon Sep 17 00:00:00 2001 From: Botberry Date: Fri, 2 Jun 2023 06:53:53 +0000 Subject: [PATCH 023/119] =?UTF-8?q?Release=20=F0=9F=8D=93=200.180.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ RELEASE.md | 39 --------------------------------------- pyproject.toml | 2 +- 3 files changed, 45 insertions(+), 40 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 685f7ee033..14bd9fc1db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,50 @@ CHANGELOG ========= +0.180.2 - 2023-06-02 +-------------------- + +In this release codegen no longer chokes on queries that use a fragment. + +There is one significant limitation at the present. When a fragment is included via the spread operator in an object, it must be the only field present. Attempts to include more fields will result in a ``ValueError``. + +However, there are some real benefits. When a fragment is included in multiple places in the query, only a single class will be made to represent that fragment: + +``` +fragment Point on Bar { + id + x + y +} + +query GetPoints { + circlePoints { + ...Point + } + squarePoints { + ...Point + } +} +``` + +Might generate the following types + +```py +class Point: + id: str + x: float + y: float + +class GetPointsResult: + circle_points: List[Point] + square_points: List[Point] +``` + +The previous behavior would generate duplicate classes for for the `GetPointsCirclePoints` and `GetPointsSquarePoints` even though they are really identical classes. + +Contributed by [Matt Gilson](https://github.com/mgilson) via [PR #2802](https://github.com/strawberry-graphql/strawberry/pull/2802/) + + 0.180.1 - 2023-06-01 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 963dd5b8a4..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,39 +0,0 @@ -Release type: patch - -In this release codegen no longer chokes on queries that use a fragment. - -There is one significant limitation at the present. When a fragment is included via the spread operator in an object, it must be the only field present. Attempts to include more fields will result in a ``ValueError``. - -However, there are some real benefits. When a fragment is included in multiple places in the query, only a single class will be made to represent that fragment: - -``` -fragment Point on Bar { - id - x - y -} - -query GetPoints { - circlePoints { - ...Point - } - squarePoints { - ...Point - } -} -``` - -Might generate the following types - -```py -class Point: - id: str - x: float - y: float - -class GetPointsResult: - circle_points: List[Point] - square_points: List[Point] -``` - -The previous behavior would generate duplicate classes for for the `GetPointsCirclePoints` and `GetPointsSquarePoints` even though they are really identical classes. diff --git a/pyproject.toml b/pyproject.toml index 11fd827834..d7ee1f1b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.180.1" +version = "0.180.2" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From b832a620534e51fe614be673951e648ad9878123 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 2 Jun 2023 16:54:51 +0800 Subject: [PATCH 024/119] Fix GraphiQL Explorer Plugin styling and update GraphiQL to the latest 2.4.7 version (#2804) * Update GraphiQL and co to latest versions * Add @graphiql/plugin-explorer css * Add RELEASE.md * Fix pre-commit errors --- RELEASE.md | 3 +++ strawberry/codegen/query_codegen.py | 18 +++++++++++++----- strawberry/static/graphiql.html | 23 +++++++++++++++-------- 3 files changed, 31 insertions(+), 13 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..4542a35805 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This release updates the built-in GraphiQL to the current latest version 2.4.7 and improves styling for the GraphiQL Explorer Plugin. diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index aa380b7db2..613e7a82d6 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from functools import partial from typing import ( TYPE_CHECKING, Callable, @@ -199,15 +200,22 @@ def _populate_fragment_types(self, ast: DocumentNode) -> None: for fd in fragment_definitions: query_type = self.schema.get_type_by_name("Query") assert isinstance(query_type, TypeDefinition) + + def graph_ql_object_type_factory(name: str, on: str): + return GraphQLFragmentType(name, on=on) + + graph_ql_object_type_factory = partial( + graph_ql_object_type_factory, on=fd.type_condition.name.value + ) + self._collect_types( - # The FragmentDefinitionNode has a non-Optional `SelectionSetNode` but the Protocol - # wants an `Optional[SelectionSetNode]` so this doesn't quite conform. + # The FragmentDefinitionNode has a non-Optional `SelectionSetNode` but + # the Protocol wants an `Optional[SelectionSetNode]` so this doesn't + # quite conform. cast(HasSelectionSet, fd), parent_type=query_type, class_name=fd.name.value, - graph_ql_object_type_factory=lambda name: GraphQLFragmentType( - name, on=fd.type_condition.name.value - ), + graph_ql_object_type_factory=graph_ql_object_type_factory, ) def _convert_selection(self, selection: SelectionNode) -> GraphQLSelection: diff --git a/strawberry/static/graphiql.html b/strawberry/static/graphiql.html index 4097e1d72a..d2df16df33 100644 --- a/strawberry/static/graphiql.html +++ b/strawberry/static/graphiql.html @@ -54,15 +54,22 @@ + + @@ -70,13 +77,13 @@
Loading...