Skip to content

Commit

Permalink
feat(relay): Allow to customize max_results per connection in relay
Browse files Browse the repository at this point in the history
Fix #3734
  • Loading branch information
bellini666 committed Jan 4, 2025
1 parent 6bc7332 commit a98a038
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 2 deletions.
27 changes: 27 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Release type: minor

Add the ability to override the "max results" a relay's connection can return on
a per-field basis.

The default value for this is defined in the schema's config, and set to `100`
unless modified by the user. Now, that per-field value will take precedence over
it.

For example:

```python
@strawerry.type
class Query:
# This will still use the default value in the schema's config
fruits: ListConnection[Fruit] = relay.connection()

# This will reduce the maximum number of results to 10
limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10)

# This will increase the maximum number of results to 10
higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
```

Note that this only affects `ListConnection` and subclasses. If you are
implementing your own connection resolver, there's an extra keyword named
`max_results: int | None` that will be passed to it.
15 changes: 15 additions & 0 deletions docs/guides/relay.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,21 @@ It can be defined in the `Query` objects in 4 ways:
- `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list
can contain `null` values if the given objects don't exist.

### Max results for connections

The implementation of `relay.ListConnection` will limit the number of results to
the `relay_max_results` configuration in the
[schema's config](../types/schema-configurations.md) (which defaults to `100`).

That can also be configured on a per-field basis by passing `max_results` to the
`@connection` decorator. For example:

```python
@strawerry.type
class Query:
fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
```

### Custom connection pagination

The default `relay.Connection` class don't implement any pagination logic, and
Expand Down
11 changes: 10 additions & 1 deletion strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
class ConnectionExtension(FieldExtension):
connection_type: type[Connection[Node]]

def __init__(self, max_results: Optional[int] = None) -> None:
self.max_results = max_results

def apply(self, field: StrawberryField) -> None:
field.arguments = [
*field.arguments,
Expand Down Expand Up @@ -288,6 +291,7 @@ def resolve(
after=after,
first=first,
last=last,
max_results=self.max_results,
)

async def resolve_async(
Expand Down Expand Up @@ -316,6 +320,7 @@ async def resolve_async(
after=after,
first=first,
last=last,
max_results=self.max_results,
)

# If nodes was an AsyncIterable/AsyncIterator, resolve_connection
Expand Down Expand Up @@ -357,6 +362,7 @@ def connection(
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: list[FieldExtension] = (), # type: ignore
max_results: Optional[int] = None,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
# any behaviour at the moment.
Expand Down Expand Up @@ -389,6 +395,9 @@ def connection(
metadata: The metadata of the field.
directives: The directives to apply to the field.
extensions: The extensions to apply to the field.
max_results: The maximum number of results this connection can return.
Can be set to override the default value of 100 defined in the
schema configuration.
init: Used only for type checking purposes.
Examples:
Expand Down Expand Up @@ -451,7 +460,7 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
default_factory=default_factory,
metadata=metadata,
directives=directives or (),
extensions=[*extensions, ConnectionExtension()],
extensions=[*extensions, ConnectionExtension(max_results=max_results)],
)
if resolver is not None:
f = f(resolver)
Expand Down
5 changes: 5 additions & 0 deletions strawberry/relay/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ def resolve_connection(
after: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> AwaitableOrValue[Self]:
"""Resolve a connection from nodes.
Expand All @@ -731,6 +732,7 @@ def resolve_connection(
after: Returns the items in the list that come after the specified cursor.
first: Returns the first n items from the list.
last: Returns the items in the list that come after the specified cursor.
max_results: The maximum number of results to resolve.
kwargs: Additional arguments passed to the resolver.
Returns:
Expand Down Expand Up @@ -767,6 +769,7 @@ def resolve_connection( # noqa: PLR0915
after: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> AwaitableOrValue[Self]:
"""Resolve a connection from the list of nodes.
Expand All @@ -780,6 +783,7 @@ def resolve_connection( # noqa: PLR0915
after: Returns the items in the list that come after the specified cursor.
first: Returns the first n items from the list.
last: Returns the items in the list that come after the specified cursor.
max_results: The maximum number of results to resolve.
kwargs: Additional arguments passed to the resolver.
Returns:
Expand All @@ -794,6 +798,7 @@ def resolve_connection( # noqa: PLR0915
after=after,
first=first,
last=last,
max_results=max_results,
)

type_def = get_object_definition(cls)
Expand Down
7 changes: 6 additions & 1 deletion strawberry/relay/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,16 @@ def from_arguments(
after: str | None = None,
first: int | None = None,
last: int | None = None,
max_results: int | None = None,
) -> Self:
"""Get the slice metadata to use on ListConnection."""
from strawberry.relay.types import PREFIX

max_results = info.schema.config.relay_max_results
max_results = (
max_results
if max_results is not None
else info.schema.config.relay_max_results
)
start = 0
end: int | None = None

Expand Down
58 changes: 58 additions & 0 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node
from strawberry.relay.types import ListConnection
from strawberry.schema.config import StrawberryConfig


@strawberry.type
Expand All @@ -34,6 +36,8 @@ def resolve_connection(
before: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> Optional[Self]:
return None

Expand Down Expand Up @@ -124,3 +128,57 @@ def users(self) -> Optional[list[User]]: # pragma: no cover
result = schema.execute_sync(query)
assert result.data == {"users": None}
assert result.errors[0].message == "Not allowed"


@pytest.mark.parametrize(
("field_max_results", "schema_max_results", "results", "expected"),
[
(5, 100, 5, 5),
(5, 2, 5, 5),
(5, 100, 10, 5),
(5, 2, 10, 5),
(5, 100, 0, 0),
(5, 2, 0, 0),
(None, 100, 5, 5),
(None, 2, 5, 2),
],
)
def test_max_results(
field_max_results: Optional[int],
schema_max_results: int,
results: int,
expected: int,
):
@strawberry.type
class User(Node):
id: strawberry.relay.NodeID[str]

@strawberry.type
class Query:
@strawberry.relay.connection(
ListConnection[User],
max_results=field_max_results,
)
def users(self) -> list[User]:
return [User(id=str(i)) for i in range(results)]

schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(relay_max_results=schema_max_results),
)
query = """
query {
users {
edges {
node {
id
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data is not None
assert isinstance(result.data["users"]["edges"], list)
assert len(result.data["users"]["edges"]) == expected

0 comments on commit a98a038

Please sign in to comment.