diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..b099974c24 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,51 @@ +Release type: minor + +This release adds support for making Relay connection optional, this is useful +when you want to add permission classes to the connection and not fail the whole +query if the user doesn't have permission to access the connection. + +Example: + +```python +import strawberry +from strawberry import relay +from strawberry.permission import BasePermission + + +class IsAuthenticated(BasePermission): + message = "User is not authenticated" + + # This method can also be async! + def has_permission( + self, source: typing.Any, info: strawberry.Info, **kwargs + ) -> bool: + return False + + +@strawberry.type +class Fruit(relay.Node): + code: relay.NodeID[int] + name: str + weight: float + + @classmethod + def resolve_nodes( + cls, + *, + info: strawberry.Info, + node_ids: Iterable[str], + ): + return [] + + +@strawberry.type +class Query: + node: relay.Node = relay.node() + + @relay.connection( + relay.ListConnection[Fruit] | None, permission_classes=[IsAuthenticated()] + ) + def fruits(self) -> Iterable[Fruit]: + # This can be a database query, a generator, an async generator, etc + return all_fruits.values() +``` diff --git a/strawberry/annotation.py b/strawberry/annotation.py index dff708a1a1..8934f5bad8 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -208,7 +208,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional: ) # Note that passing a single type to `Union` is equivalent to not using `Union` - # at all. This allows us to not di any checks for how many types have been + # at all. This allows us to not do any checks for how many types have been # passed as we can safely use `Union` for both optional types # (e.g. `Optional[str]`) and optional unions (e.g. # `Optional[Union[TypeA, TypeB]]`) diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 5af00700f8..32673cfc10 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -24,9 +24,8 @@ Type, Union, cast, - overload, ) -from typing_extensions import Annotated, get_origin +from typing_extensions import Annotated, get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.extensions.field_extension import ( @@ -44,9 +43,9 @@ from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.lazy_type import LazyType from strawberry.utils.aio import asyncgen_to_list -from strawberry.utils.typing import eval_type, is_generic_alias +from strawberry.utils.typing import eval_type, is_generic_alias, is_optional, is_union -from .types import Connection, GlobalID, Node, NodeIterableType, NodeType +from .types import Connection, GlobalID, Node if TYPE_CHECKING: from typing_extensions import Literal @@ -233,7 +232,11 @@ def apply(self, field: StrawberryField) -> None: f_type = f_type.resolve_type() field.type = f_type + if isinstance(f_type, StrawberryOptional): + f_type = f_type.of_type + type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type + if not isinstance(type_origin, type) or not issubclass(type_origin, Connection): raise RelayWrongAnnotationError(field.name, cast(type, field.origin)) @@ -253,13 +256,19 @@ def apply(self, field: StrawberryField) -> None: None, ) + if is_union(resolver_type): + assert is_optional(resolver_type) + + resolver_type = get_args(resolver_type)[0] + origin = get_origin(resolver_type) + if origin is None or not issubclass( origin, (Iterator, Iterable, AsyncIterator, AsyncIterable) ): raise RelayWrongResolverAnnotationError(field.name, field.base_resolver) - self.connection_type = cast(Type[Connection[Node]], field.type) + self.connection_type = cast(Type[Connection[Node]], f_type) def resolve( self, @@ -327,44 +336,17 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField: return field(*args, **kwargs) -@overload -def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, - *, - resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None, - name: Optional[str] = None, - is_subscription: bool = False, - description: Optional[str] = None, - init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, - default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore -) -> Any: ... - - -@overload -def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, - *, - name: Optional[str] = None, - is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, - default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore -) -> StrawberryField: ... +# we used to have `Type[Connection[NodeType]]` here, but that when we added +# support for making the Connection type optional, we had to change it to +# `Any` because otherwise it wouldn't be type check since `Optional[Connection[Something]]` +# is not a `Type`, but a special form, see https://discuss.python.org/t/is-annotated-compatible-with-type-t/43898/46 +# for more information, and also https://peps.python.org/pep-0747/, which is currently +# in draft status (and no type checker supports it yet) +ConnectionGraphQLType = Any def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, + graphql_type: Optional[ConnectionGraphQLType] = None, *, resolver: Optional[_RESOLVER_TYPE[Any]] = None, name: Optional[str] = None, @@ -379,7 +361,7 @@ def connection( extensions: List[FieldExtension] = (), # type: ignore # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change - # any behavior at the moment. + # any behaviour at the moment. init: Literal[True, False, None] = None, ) -> Any: """Annotate a property or a method to create a relay connection field. diff --git a/strawberry/schema/subscribe.py b/strawberry/schema/subscribe.py index be8b783cb2..22052e36cf 100644 --- a/strawberry/schema/subscribe.py +++ b/strawberry/schema/subscribe.py @@ -138,7 +138,7 @@ async def subscribe( middleware_manager, execution_context_class, ) - # GrapQL-core might return an initial error result instead of an async iterator. + # GraphQL-core might return an initial error result instead of an async iterator. # This happens when "there was an immediate error" i.e resolver is not an async iterator. # To overcome this while maintaining the extension contexts we do this trick. first = await asyncgen.__anext__() diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py new file mode 100644 index 0000000000..f84d398895 --- /dev/null +++ b/tests/relay/test_connection.py @@ -0,0 +1,125 @@ +import sys +from typing import Any, Iterable, List, Optional +from typing_extensions import Self + +import pytest + +import strawberry +from strawberry.permission import BasePermission +from strawberry.relay import Connection, Node + + +@strawberry.type +class User(Node): + id: strawberry.relay.NodeID + name: str = "John" + + @classmethod + def resolve_nodes( + cls, *, info: strawberry.Info, node_ids: List[Any], required: bool + ) -> List[Self]: + return [cls() for _ in node_ids] + + +@strawberry.type +class UserConnection(Connection[User]): + @classmethod + def resolve_connection( + cls, + nodes: Iterable[User], + *, + info: Any, + after: Optional[str] = None, + before: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + ) -> Optional[Self]: + return None + + +class TestPermission(BasePermission): + message = "Not allowed" + + def has_permission(self, source, info, **kwargs: Any): + return False + + +def test_nullable_connection_with_optional(): + @strawberry.type + class Query: + @strawberry.relay.connection(Optional[UserConnection]) + def users(self) -> Optional[List[User]]: + return None + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert not result.errors + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="pipe syntax for union is only available on python 3.10+", +) +def test_nullable_connection_with_pipe(): + @strawberry.type + class Query: + @strawberry.relay.connection(UserConnection | None) + def users(self) -> List[User] | None: + return None + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert not result.errors + + +def test_nullable_connection_with_permission(): + @strawberry.type + class Query: + @strawberry.relay.connection( + Optional[UserConnection], permission_classes=[TestPermission] + ) + def users(self) -> Optional[List[User]]: # pragma: no cover + pytest.fail("Should not have been called...") + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert result.errors[0].message == "Not allowed"