Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for optional connections #3707

Merged
merged 11 commits into from
Dec 20, 2024
51 changes: 51 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
Release type: minor

This release adds support for making Relay connection optional, this is useful
when you want to add permission classes to the connection and not fail the whole
query if the user doesn't have permission to access the connection.
patrick91 marked this conversation as resolved.
Show resolved Hide resolved

Example:

```python
import strawberry
from strawberry import relay
from strawberry.permission import BasePermission


class IsAuthenticated(BasePermission):
message = "User is not authenticated"

# This method can also be async!
def has_permission(
self, source: typing.Any, info: strawberry.Info, **kwargs
) -> bool:
return False


@strawberry.type
class Fruit(relay.Node):
code: relay.NodeID[int]
name: str
weight: float

@classmethod
def resolve_nodes(
cls,
*,
info: strawberry.Info,
node_ids: Iterable[str],
):
return []


@strawberry.type
class Query:
node: relay.Node = relay.node()

@relay.connection(
relay.ListConnection[Fruit] | None, permission_classes=[IsAuthenticated()]
)
def fruits(self) -> Iterable[Fruit]:
# This can be a database query, a generator, an async generator, etc
return all_fruits.values()
```
2 changes: 1 addition & 1 deletion strawberry/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional:
)

# Note that passing a single type to `Union` is equivalent to not using `Union`
# at all. This allows us to not di any checks for how many types have been
# at all. This allows us to not do any checks for how many types have been
# passed as we can safely use `Union` for both optional types
# (e.g. `Optional[str]`) and optional unions (e.g.
# `Optional[Union[TypeA, TypeB]]`)
Expand Down
63 changes: 22 additions & 41 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
Type,
Union,
cast,
overload,
)
from typing_extensions import Annotated, get_origin
from typing_extensions import Annotated, get_args, get_origin

from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
Expand All @@ -44,9 +43,9 @@
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.lazy_type import LazyType
from strawberry.utils.aio import asyncgen_to_list
from strawberry.utils.typing import eval_type, is_generic_alias
from strawberry.utils.typing import eval_type, is_generic_alias, is_union

from .types import Connection, GlobalID, Node, NodeIterableType, NodeType
from .types import Connection, GlobalID, Node

if TYPE_CHECKING:
from typing_extensions import Literal
Expand Down Expand Up @@ -233,7 +232,11 @@ def apply(self, field: StrawberryField) -> None:
f_type = f_type.resolve_type()
field.type = f_type

if isinstance(f_type, StrawberryOptional):
f_type = f_type.of_type

type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type

if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
raise RelayWrongAnnotationError(field.name, cast(type, field.origin))

Expand All @@ -253,13 +256,18 @@ def apply(self, field: StrawberryField) -> None:
None,
)

if is_union(resolver_type):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): The current union type handling is unsafe as it assumes the first type argument is the correct one

Consider implementing proper validation of the union type structure to ensure we're handling optional types correctly. The current approach could lead to runtime errors if the assumptions about the type structure don't hold.

# TODO: actually check if is optional and get correct type
resolver_type = get_args(resolver_type)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: we can probably use is_optional to check and get_optional_annotation in this line from strawberry.utils.typing


origin = get_origin(resolver_type)

if origin is None or not issubclass(
origin, (Iterator, Iterable, AsyncIterator, AsyncIterable)
):
raise RelayWrongResolverAnnotationError(field.name, field.base_resolver)

self.connection_type = cast(Type[Connection[Node]], field.type)
self.connection_type = cast(Type[Connection[Node]], f_type)

def resolve(
self,
Expand Down Expand Up @@ -327,44 +335,17 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField:
return field(*args, **kwargs)


@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
*,
resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None,
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
init: Literal[True] = True,
permission_classes: Optional[List[Type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: List[FieldExtension] = (), # type: ignore
) -> Any: ...


@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
*,
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
permission_classes: Optional[List[Type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: List[FieldExtension] = (), # type: ignore
) -> StrawberryField: ...
# we used to have `Type[Connection[NodeType]]` here, but that when we added
# support for making the Connection type optional, we had to change it to
# `Any` because otherwise it wouldn't be type check since `Optional[Connection[Something]]`
# is not a `Type`, but a special form, see https://discuss.python.org/t/is-annotated-compatible-with-type-t/43898/46
# for more information, and also https://peps.python.org/pep-0747/, which is currently
# in draft status (and no type checker supports it yet)
ConnectionGraphQLType = Any


def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
graphql_type: Optional[ConnectionGraphQLType] = None,
*,
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
name: Optional[str] = None,
Expand All @@ -379,7 +360,7 @@ def connection(
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
# any behavior at the moment.
# any behaviour at the moment.
init: Literal[True, False, None] = None,
) -> Any:
"""Annotate a property or a method to create a relay connection field.
Expand Down
2 changes: 1 addition & 1 deletion strawberry/schema/subscribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def subscribe(
middleware_manager,
execution_context_class,
)
# GrapQL-core might return an initial error result instead of an async iterator.
# GraphQL-core might return an initial error result instead of an async iterator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Fix typo in comment: 'GrapQL' should be 'GraphQL'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hrm, isn't this suggestion correct? 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the suggestion is correct. The original comment had a typo, "GrapQL," which should indeed be corrected to "GraphQL." This change ensures clarity and accuracy in the code comments. Let me know if you have any other questions or need further clarification!

# This happens when "there was an immediate error" i.e resolver is not an async iterator.
# To overcome this while maintaining the extension contexts we do this trick.
first = await asyncgen.__anext__()
Expand Down
125 changes: 125 additions & 0 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import sys
from typing import Any, Iterable, List, Optional
from typing_extensions import Self

import pytest

import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node


@strawberry.type
class User(Node):
id: strawberry.relay.NodeID
name: str = "John"

@classmethod
def resolve_nodes(
cls, *, info: strawberry.Info, node_ids: List[Any], required: bool
) -> List[Self]:
return [cls() for _ in node_ids]


@strawberry.type
class UserConnection(Connection[User]):
@classmethod
def resolve_connection(
cls,
nodes: Iterable[User],
*,
info: Any,
after: Optional[str] = None,
before: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
) -> Optional[Self]:
return None


class TestPermission(BasePermission):
message = "Not allowed"

def has_permission(self, source, info, **kwargs: Any):
return False


def test_nullable_connection_with_optional():
@strawberry.type
class Query:
@strawberry.relay.connection(Optional[UserConnection])
def users(self) -> Optional[List[User]]:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="pipe syntax for union is only available on python 3.10+",
)
def test_nullable_connection_with_pipe():
@strawberry.type
class Query:
@strawberry.relay.connection(UserConnection | None)
def users(self) -> List[User] | None:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_nullable_connection_with_permission():
@strawberry.type
class Query:
@strawberry.relay.connection(
Optional[UserConnection], permission_classes=[TestPermission]
)
def users(self) -> Optional[List[User]]:
return None

Check warning on line 108 in tests/relay/test_connection.py

View check run for this annotation

Codecov / codecov/patch

tests/relay/test_connection.py#L108

Added line #L108 was not covered by tests
patrick91 marked this conversation as resolved.
Show resolved Hide resolved

schema = strawberry.Schema(query=Query)
patrick91 marked this conversation as resolved.
Show resolved Hide resolved
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert result.errors[0].message == "Not allowed"
Loading