From a98a0387d1183796905403ed4157698adef5099d Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 4 Jan 2025 12:50:34 +0100 Subject: [PATCH 1/2] feat(relay): Allow to customize max_results per connection in relay Fix #3734 --- RELEASE.md | 27 ++++++++++++++++ docs/guides/relay.md | 15 +++++++++ strawberry/relay/fields.py | 11 ++++++- strawberry/relay/types.py | 5 +++ strawberry/relay/utils.py | 7 +++- tests/relay/test_connection.py | 58 ++++++++++++++++++++++++++++++++++ 6 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..35bf8e8414 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,27 @@ +Release type: minor + +Add the ability to override the "max results" a relay's connection can return on +a per-field basis. + +The default value for this is defined in the schema's config, and set to `100` +unless modified by the user. Now, that per-field value will take precedence over +it. + +For example: + +```python +@strawerry.type +class Query: + # This will still use the default value in the schema's config + fruits: ListConnection[Fruit] = relay.connection() + + # This will reduce the maximum number of results to 10 + limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10) + + # This will increase the maximum number of results to 10 + higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000) +``` + +Note that this only affects `ListConnection` and subclasses. If you are +implementing your own connection resolver, there's an extra keyword named +`max_results: int | None` that will be passed to it. diff --git a/docs/guides/relay.md b/docs/guides/relay.md index 394583abcf..97cc3bc089 100644 --- a/docs/guides/relay.md +++ b/docs/guides/relay.md @@ -205,6 +205,21 @@ It can be defined in the `Query` objects in 4 ways: - `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list can contain `null` values if the given objects don't exist. +### Max results for connections + +The implementation of `relay.ListConnection` will limit the number of results to +the `relay_max_results` configuration in the +[schema's config](../types/schema-configurations.md) (which defaults to `100`). + +That can also be configured on a per-field basis by passing `max_results` to the +`@connection` decorator. For example: + +```python +@strawerry.type +class Query: + fruits: ListConnection[Fruit] = relay.connection(max_results=10_000) +``` + ### Custom connection pagination The default `relay.Connection` class don't implement any pagination logic, and diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 347fd22169..54011cdfda 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -182,6 +182,9 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]: class ConnectionExtension(FieldExtension): connection_type: type[Connection[Node]] + def __init__(self, max_results: Optional[int] = None) -> None: + self.max_results = max_results + def apply(self, field: StrawberryField) -> None: field.arguments = [ *field.arguments, @@ -288,6 +291,7 @@ def resolve( after=after, first=first, last=last, + max_results=self.max_results, ) async def resolve_async( @@ -316,6 +320,7 @@ async def resolve_async( after=after, first=first, last=last, + max_results=self.max_results, ) # If nodes was an AsyncIterable/AsyncIterator, resolve_connection @@ -357,6 +362,7 @@ def connection( metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), extensions: list[FieldExtension] = (), # type: ignore + max_results: Optional[int] = None, # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change # any behaviour at the moment. @@ -389,6 +395,9 @@ def connection( metadata: The metadata of the field. directives: The directives to apply to the field. extensions: The extensions to apply to the field. + max_results: The maximum number of results this connection can return. + Can be set to override the default value of 100 defined in the + schema configuration. init: Used only for type checking purposes. Examples: @@ -451,7 +460,7 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ... default_factory=default_factory, metadata=metadata, directives=directives or (), - extensions=[*extensions, ConnectionExtension()], + extensions=[*extensions, ConnectionExtension(max_results=max_results)], ) if resolver is not None: f = f(resolver) diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index d529cf63f0..a014817777 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -717,6 +717,7 @@ def resolve_connection( after: Optional[str] = None, first: Optional[int] = None, last: Optional[int] = None, + max_results: Optional[int] = None, **kwargs: Any, ) -> AwaitableOrValue[Self]: """Resolve a connection from nodes. @@ -731,6 +732,7 @@ def resolve_connection( after: Returns the items in the list that come after the specified cursor. first: Returns the first n items from the list. last: Returns the items in the list that come after the specified cursor. + max_results: The maximum number of results to resolve. kwargs: Additional arguments passed to the resolver. Returns: @@ -767,6 +769,7 @@ def resolve_connection( # noqa: PLR0915 after: Optional[str] = None, first: Optional[int] = None, last: Optional[int] = None, + max_results: Optional[int] = None, **kwargs: Any, ) -> AwaitableOrValue[Self]: """Resolve a connection from the list of nodes. @@ -780,6 +783,7 @@ def resolve_connection( # noqa: PLR0915 after: Returns the items in the list that come after the specified cursor. first: Returns the first n items from the list. last: Returns the items in the list that come after the specified cursor. + max_results: The maximum number of results to resolve. kwargs: Additional arguments passed to the resolver. Returns: @@ -794,6 +798,7 @@ def resolve_connection( # noqa: PLR0915 after=after, first=first, last=last, + max_results=max_results, ) type_def = get_object_definition(cls) diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py index d25eacb447..7098a9271e 100644 --- a/strawberry/relay/utils.py +++ b/strawberry/relay/utils.py @@ -131,11 +131,16 @@ def from_arguments( after: str | None = None, first: int | None = None, last: int | None = None, + max_results: int | None = None, ) -> Self: """Get the slice metadata to use on ListConnection.""" from strawberry.relay.types import PREFIX - max_results = info.schema.config.relay_max_results + max_results = ( + max_results + if max_results is not None + else info.schema.config.relay_max_results + ) start = 0 end: int | None = None diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index f97bb0d669..c85d148816 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -8,6 +8,8 @@ import strawberry from strawberry.permission import BasePermission from strawberry.relay import Connection, Node +from strawberry.relay.types import ListConnection +from strawberry.schema.config import StrawberryConfig @strawberry.type @@ -34,6 +36,8 @@ def resolve_connection( before: Optional[str] = None, first: Optional[int] = None, last: Optional[int] = None, + max_results: Optional[int] = None, + **kwargs: Any, ) -> Optional[Self]: return None @@ -124,3 +128,57 @@ def users(self) -> Optional[list[User]]: # pragma: no cover result = schema.execute_sync(query) assert result.data == {"users": None} assert result.errors[0].message == "Not allowed" + + +@pytest.mark.parametrize( + ("field_max_results", "schema_max_results", "results", "expected"), + [ + (5, 100, 5, 5), + (5, 2, 5, 5), + (5, 100, 10, 5), + (5, 2, 10, 5), + (5, 100, 0, 0), + (5, 2, 0, 0), + (None, 100, 5, 5), + (None, 2, 5, 2), + ], +) +def test_max_results( + field_max_results: Optional[int], + schema_max_results: int, + results: int, + expected: int, +): + @strawberry.type + class User(Node): + id: strawberry.relay.NodeID[str] + + @strawberry.type + class Query: + @strawberry.relay.connection( + ListConnection[User], + max_results=field_max_results, + ) + def users(self) -> list[User]: + return [User(id=str(i)) for i in range(results)] + + schema = strawberry.Schema( + query=Query, + config=StrawberryConfig(relay_max_results=schema_max_results), + ) + query = """ + query { + users { + edges { + node { + id + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data is not None + assert isinstance(result.data["users"]["edges"], list) + assert len(result.data["users"]["edges"]) == expected From ec77ac93ecf050fd53f788d0d02cf8545b84c85e Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 4 Jan 2025 12:56:19 +0100 Subject: [PATCH 2/2] fix typo in the docs --- docs/guides/relay.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/relay.md b/docs/guides/relay.md index 97cc3bc089..ba0ce18d18 100644 --- a/docs/guides/relay.md +++ b/docs/guides/relay.md @@ -222,7 +222,7 @@ class Query: ### Custom connection pagination -The default `relay.Connection` class don't implement any pagination logic, and +The default `relay.Connection` class doesn't implement any pagination logic, and should be used as a base class to implement your own pagination logic. All you need to do is implement the `resolve_connection` classmethod.