Skip to content

Commit

Permalink
fix: Fix lazy aliased connections type resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed May 28, 2024
1 parent 4ba2379 commit 3a1e0d5
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 10 deletions.
23 changes: 23 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Release type: patch

This release fixes an issue that would prevent using lazy aliased connections to
annotate a connection field.

For example, this should now work correctly:

```python
# types.py

@strawberry.type
class Fruit:
...

FruitConnection: TypeAlias = ListConnection[Fruit]


# schema.py

@strawberry.type
class Query:
fruits: Annotated["FruitConnection", strawberry.lazy("types")] = strawberry.connection()
```
13 changes: 10 additions & 3 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@
SyncExtensionResolver,
)
from strawberry.field import _RESOLVER_TYPE, StrawberryField, field
from strawberry.lazy_type import LazyType
from strawberry.relay.exceptions import (
RelayWrongAnnotationError,
RelayWrongResolverAnnotationError,
)
from strawberry.type import StrawberryList, StrawberryOptional
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.utils.aio import asyncgen_to_list
from strawberry.utils.typing import eval_type
from strawberry.utils.typing import eval_type, is_generic_alias

from .types import Connection, GlobalID, Node, NodeIterableType, NodeType

Expand Down Expand Up @@ -223,7 +224,13 @@ def apply(self, field: StrawberryField) -> None:
]

f_type = field.type
if not isinstance(f_type, type) or not issubclass(f_type, Connection):

if isinstance(f_type, LazyType):
f_type = f_type.resolve_type()
field.type = f_type

type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type
if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
raise RelayWrongAnnotationError(field.name, cast(type, field.origin))

assert field.base_resolver
Expand All @@ -248,7 +255,7 @@ def apply(self, field: StrawberryField) -> None:
):
raise RelayWrongResolverAnnotationError(field.name, field.base_resolver)

self.connection_type = cast(Type[Connection[Node]], field.type)
self.connection_type = cast(Type[Connection[Node]], f_type)

def resolve(
self,
Expand Down
15 changes: 10 additions & 5 deletions strawberry/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cast,
overload,
)
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, TypeGuard, get_args, get_origin

ast_unparse = getattr(ast, "unparse", None)
# ast.unparse is only available on python 3.9+. For older versions we will
Expand Down Expand Up @@ -63,15 +63,20 @@ def get_generic_alias(type_: Type) -> Type:
continue

attr = getattr(typing, attr_name)
# _GenericAlias overrides all the methods that we can use to know if
# this is a subclass of it. But if it has an "_inst" attribute
# then it for sure is a _GenericAlias
if hasattr(attr, "_inst") and attr.__origin__ is type_:
if is_generic_alias(attr) and attr.__origin__ is type_:
return attr

raise AssertionError(f"No GenericAlias available for {type_}") # pragma: no cover


def is_generic_alias(type_: Any) -> TypeGuard[_GenericAlias]:
"""Returns True if the type is a generic alias."""
# _GenericAlias overrides all the methods that we can use to know if
# this is a subclass of it. But if it has an "_inst" attribute
# then it for sure is a _GenericAlias
return hasattr(type_, "_inst")


def is_list(annotation: object) -> bool:
"""Returns True if annotation is a List"""

Expand Down
10 changes: 9 additions & 1 deletion tests/relay/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Optional,
cast,
)
from typing_extensions import Annotated, Self
from typing_extensions import Annotated, Self, TypeAlias

import strawberry
from strawberry import relay
Expand Down Expand Up @@ -191,6 +191,9 @@ async def has_permission(
return True


FruitsListConnectionAlias: TypeAlias = relay.ListConnection[Fruit]


@strawberry.type
class Query:
node: relay.Node = relay.node()
Expand All @@ -204,6 +207,11 @@ class Query:
fruits_lazy: relay.ListConnection[
Annotated["Fruit", strawberry.lazy("tests.relay.schema")]
] = relay.connection(resolver=fruits_resolver)
fruits_alias: FruitsListConnectionAlias = relay.connection(resolver=fruits_resolver)
fruits_alias_lazy: Annotated[
"FruitsListConnectionAlias",
strawberry.lazy("tests.relay.schema"),
] = relay.connection(resolver=fruits_resolver)
fruits_async: relay.ListConnection[FruitAsync] = relay.connection(
resolver=fruits_async_resolver
)
Expand Down
26 changes: 26 additions & 0 deletions tests/relay/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,32 @@ type Query {
"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): FruitConnection!
fruitsAlias(
"""Returns the items in the list that come before the specified cursor."""
before: String = null

"""Returns the items in the list that come after the specified cursor."""
after: String = null

"""Returns the first n items from the list."""
first: Int = null

"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): FruitConnection!
fruitsAliasLazy(
"""Returns the items in the list that come before the specified cursor."""
before: String = null

"""Returns the items in the list that come after the specified cursor."""
after: String = null

"""Returns the first n items from the list."""
first: Int = null

"""Returns the items in the list that come after the specified cursor."""
last: Int = null
): FruitConnection!
fruitsAsync(
"""Returns the items in the list that come before the specified cursor."""
before: String = null
Expand Down
2 changes: 2 additions & 0 deletions tests/relay/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ async def test_query_nodes_optional_async():
attrs = [
"fruits",
"fruitsLazy",
"fruitsAlias",
"fruitsAliasLazy",
"fruitsConcreteResolver",
"fruitsCustomResolver",
"fruitsCustomResolverLazy",
Expand Down
28 changes: 27 additions & 1 deletion tests/types/test_lazy_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# type: ignore
import enum
from typing import Generic, TypeVar
from typing_extensions import Annotated
from typing_extensions import Annotated, TypeAlias

import strawberry
from strawberry.annotation import StrawberryAnnotation
Expand All @@ -11,13 +11,23 @@
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.union import StrawberryUnion, union

T = TypeVar("T")


# This type is in the same file but should adequately test the logic.
@strawberry.type
class LaziestType:
something: bool


@strawberry.type
class LazyGenericType(Generic[T]):
something: T


LazyTypeAlias: TypeAlias = LazyGenericType[int]


@strawberry.enum
class LazyEnum(enum.Enum):
BREAD = "BREAD"
Expand All @@ -38,6 +48,22 @@ def test_lazy_type():
assert resolved.resolve_type() is LaziestType


def test_lazy_type_alias():
# Module path is short and relative because of the way pytest runs the file
LazierType = LazyType("LazyTypeAlias", "test_lazy_types")

annotation = StrawberryAnnotation(LazierType)
resolved = annotation.resolve()

# Currently StrawberryAnnotation(LazyType).resolve() returns the unresolved
# LazyType. We may want to find a way to directly return the referenced object
# without a second resolving step.
assert isinstance(resolved, LazyType)
resolved_type = resolved.resolve_type()
assert resolved_type.__origin__ is LazyGenericType
assert resolved_type.__args__ == (int,)


def test_lazy_type_function():
LethargicType = Annotated["LaziestType", strawberry.lazy("test_lazy_types")]

Expand Down

0 comments on commit 3a1e0d5

Please sign in to comment.