diff --git a/docs/contributing.md b/docs/contributing.md index eb2d9cf5..19fb84c2 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -5,7 +5,7 @@ get started follow these steps: ```shell git clone https://github.com/strawberry-graphql/strawberry-django.git -cd strawberry-django +cd strawberry_django poetry install poetry run pytest ``` diff --git a/docs/guide/filters.md b/docs/guide/filters.md index 9a9654a4..bee2273c 100644 --- a/docs/guide/filters.md +++ b/docs/guide/filters.md @@ -17,6 +17,13 @@ class Fruit: ... ``` +!!! tip + + In most cases filter fields should have `Optional` annotations and default value `strawberry.UNSET` like so: + `foo: Optional[SomeType] = strawberry.UNSET` + Above `auto` annotation is wrapped in `Optional` automatically. + `UNSET` is automatically used for fields without `field` or with `strawberry_django.filter_field`. + The code above would generate following schema: ```{.graphql title=schema.graphql} @@ -26,6 +33,7 @@ input FruitFilter { AND: FruitFilter OR: FruitFilter NOT: FruitFilter + DISTINCT: Boolean } ``` @@ -36,6 +44,23 @@ input FruitFilter { `MAP_AUTO_ID_AS_GLOBAL_ID=True` in your [strawberry django settings](../settings) to make sure `auto` fields gets mapped to `GlobalID` on types and filters. +## AND, OR, NOT, DISTINCT ... + +To every filter `AND`, `OR`, `NOT` & `DISTINCT` fields are added to allow more complex filtering + +```graphql +{ + fruits( + filters: { + name: "kebab" + OR: { + name: "raspberry" + } + } + ) { ... } +} +``` + ## Lookups Lookups can be added to all fields with `lookups=True`, which will @@ -51,78 +76,25 @@ class FruitFilter: The code above would generate the following schema: ```{.graphql title=schema.graphql} -input StrFilterLookup { - exact: String - iExact: String - contains: String - iContains: String - inList: [String!] - gt: String - gte: String - lt: String - lte: String - startsWith: String - iStartsWith: String - endsWith: String - iEndsWith: String - range: [String!] +input IDBaseFilterLookup { + exact: ID isNull: Boolean - regex: String - iRegex: String - nExact: String - nIExact: String - nContains: String - nIContains: String - nInList: [String!] - nGt: String - nGte: String - nLt: String - nLte: String - nStartsWith: String - nIStartsWith: String - nEndsWith: String - nIEndsWith: String - nRange: [String!] - nIsNull: Boolean - nRegex: String - nIRegex: String + inList: [String!] } -input IDFilterLookup { - exact: String +input StrFilterLookup { + exact: ID + isNull: Boolean + inList: [String!] iExact: String contains: String iContains: String - inList: [String!] - gt: String - gte: String - lt: String - lte: String startsWith: String iStartsWith: String endsWith: String iEndsWith: String - range: [String!] - isNull: Boolean regex: String iRegex: String - nExact: String - nIExact: String - nContains: String - nIContains: String - nInList: [String!] - nGt: String - nGte: String - nLt: String - nLte: String - nStartsWith: String - nIStartsWith: String - nEndsWith: String - nIEndsWith: String - nRange: [String!] - nIsNull: Boolean - nRegex: String - nIRegex: String } input FruitFilter { @@ -131,13 +103,14 @@ input FruitFilter { AND: FruitFilter OR: FruitFilter NOT: FruitFilter + DISTINCT: Boolean } ``` Single-field lookup can be annotated with the `FilterLookup` generic type. ```{.python title=types.py} -from strawberry_django.filters import FilterLookup +from strawberry_django import FilterLookup @strawberry_django.filter(models.Fruit) class FruitFilter: @@ -147,16 +120,16 @@ class FruitFilter: ## Filtering over relationships ```{.python title=types.py} -@strawberry_django.filter(models.Fruit) -class FruitFilter: +@strawberry_django.filter(models.Color) +class ColorFilter: id: auto name: auto - color: "ColorFilter" -@strawberry_django.filter(models.Color) -class ColorFilter: +@strawberry_django.filter(models.Fruit) +class FruitFilter: id: auto name: auto + color: ColorFilter | None ``` The code above would generate following schema: @@ -180,44 +153,186 @@ input FruitFilter { } ``` -## Custom filters and overriding default filtering methods +## Custom filter methods -You can define custom filter methods and override default filter methods by defining your own resolver. +You can define custom filter method by defining your own resolver. ```{.python title=types.py} -@strawberry_django.filter(models.Fruit) +@strawberry_django.filter(models.User) class FruitFilter: - is_banana: bool | None + name: auto + last_name: auto + + @strawberry_django.filter_field + def simple(self, value: str, prefix) -> Q: + return Q(**{f"{prefix}name": value}) + + @strawberry_django.filter_field + def full_name( + self, + queryset: QuerySet, + value: str, + prefix: str + ) -> tuple[QuerySet, Q]: + queryset = queryset.alias( + _fullname=Concat( + f"{prefix}name", Value(" "), f"{prefix}last_name" + ) + ) + return queryset, Q(**{"_fullname": value}) + + @strawberry_django.filter_field + def full_name_lookups( + self, + info: Info, + queryset: QuerySet, + value: strawberry_django.FilterLookups[str], + prefix: str + ) -> tuple[QuerySet, Q]: + queryset = queryset.alias( + _fullname=Concat( + f"{prefix}name", Value(" "), f"{prefix}last_name" + ) + ) + return strawberry_django.process_filters( + filters=value, + queryset=queryset, + info=info, + prefix=f"{prefix}_fullname" + ) +``` + +!!! warning - def filter_is_banana(self, queryset): - if self.is_banana in (None, strawberry.UNSET): - return queryset + It is discouraged to use `queryset.filter()` directly. When using more + complex filtering via `NOT`, `OR` & `AND` this might lead to undesired behaviour. - if self.is_banana: - queryset = queryset.filter(name='banana') - else: - queryset = queryset.exclude(name='banana') +!!! tip + + As seen above `strawberry_django.process_filters` function is exposed and can be + reused in custom methods. Above it's used to resolve fields lookups + +!!! tip + + By default `null` value is ignored for all filters & lookups. This applies to custom + filter methods as well. Those won't even be called (you don't have to check for `None`). + This can be modified using + `strawberry_django.filter_field(filter_none=True)` + + This also means that build in `exact` & `iExact` lookups cannot be used to filter for `None` + and `isNull` have to be used explicitly. - return queryset +The code above generates the following schema: + +```{.graphql title=schema.graphql} +input FruitFilter { + name: String + lastName: String + simple: str + fullName: str + fullNameLookups: StrFilterLookups +} ``` +#### Resolver arguments + +- `prefix` - represents the current path or position + - **Required** + - Important for nested filtering + - In code bellow custom filter `name` ends up filtering `Fruit` instead of `Color` without applying `prefix` + +```{.python title="Why prefix?"} +@strawberry_django.filter(models.Fruit) +class FruitFilter: + name: auto + color: ColorFilter | None + +@strawberry_django.filter(models.Color) +class ColorFilter: + @strawberry_django.filter_field + def name(self, value: str, prefix: str): + # prefix is "fruit_set__" if unused root object is filtered instead + if value: + return Q(name=value) + return Q() +``` + +```graphql +{ + fruits( filters: {color: name: "blue"} ) { ... } +} +``` + +- `value` - represents graphql field type + - **Required**, but forbidden for default `filter` method + - _must_ be annotated + - used instead of field's return type +- `queryset` - can be used for more complex filtering + - Optional, but **Required** for default `filter` method + - usually used to `annotate` `QuerySet` + +#### Resolver return + +For custom field methods two return values are supported + +- django's `Q` object +- tuple with `QuerySet` and django's `Q` object -> `tuple[QuerySet, Q]` + +For default `filter` method only second variant is supported. + +### What about nulls? + +By default `null` values are ignored. This can be toggled as such `@strawberry_django.filter_field(filter_none=True)` + ## Overriding the default `filter` method -For overriding the default filter logic you can provide the filter method. -Note that this completely disables the default filtering, which means your custom -method is responsible for handling _all_ filter-related operations. +Works similar to field filter method, but: + +- is responsible for resolution of filtering for entire object +- _must_ be named `filter` +- argument `queryset` is **Required** +- argument `value` is **Forbidden** ```{.python title=types.py} @strawberry_django.filter(models.Fruit) class FruitFilter: - is_apple: bool | None - - def filter(self, queryset): - if self.is_apple: - return queryset.filter(name='apple') - return queryset.exclude(name='apple') + def ordered( + self, + value: int, + prefix: str, + queryset: QuerySet, + ): + queryset = queryset.alias( + _ordered_num=Count(f"{prefix}orders__id") + ) + return queryset, Q(**{f"{prefix}_ordered_num": value}) + + @strawberry_django.order_field + def filter( + self, + info: Info, + queryset: QuerySet, + prefix: str, + ) -> tuple[QuerySet, list[Q]]: + queryset = queryset.filter( + ... # Do some query modification + ) + + return strawberry_django.proces_filters( + self, + info=info, + queryset=queryset, + prefix=prefix, + skip_object_order_method=True + ) ``` +!!! tip + + As seen above `strawberry_django.proces_filters` function is exposed and can be + reused in custom methods. + For filter method `filter` `skip_object_order_method` was used to avoid endless recursion. + ## Adding filters to types All fields and CUD mutations inherit filters from the underlying type by default. @@ -245,3 +360,70 @@ Filters added into a field override the default filters of this type. class Query: fruits: list[Fruit] = strawberry_django.field(filters=FruitFilter) ``` + +## Generic Lookup reference + +There is 7 already defined Generic Lookup `strawberry.input` classes importable from `strawberry_django` + +#### `BaseFilterLookup` + +- contains `exact`, `isNull` & `inList` +- used for `ID` & `bool` fields + +#### `RangeLookup` + +- used for `range` or `BETWEEN` filtering + +#### `ComparisonFilterLookup` + +- inherits `BaseFilterLookup` +- additionaly contains `gt`, `gte`, `lt`, `lte`, & `range` +- used for Numberical fields + +#### `FilterLookup` + +- inherits `BaseFilterLookup` +- additionally contains `iExact`, `contains`, `iContains`, `startsWith`, `iStartsWith`, `endsWith`, `iEndsWith`, `regex` & `iRegex` +- used for string based fields and as default + +#### `DateFilterLookup` + +- inherits `ComparisonFilterLookup` +- additionally contains `year`,`month`,`day`,`weekDay`,`isoWeekDay`,`week`,`isoYear` & `quarter` +- used for date based fields + +#### `TimeFilterLookup` + +- inherits `ComparisonFilterLookup` +- additionally contains `hour`,`minute`,`second`,`date` & `time` +- used for time based fields + +#### `DatetimeFilterLookup` + +- inherits `DateFilterLookup` & `TimeFilterLookup` +- used for timedate based fields + +## Legacy filtering + +The previous version of filters can be enabled via [**USE_DEPRECATED_FILTERS**](settings.md#strawberry_django) + +!!! warning + + If **USE_DEPRECATED_FILTERS** is not set to `True` legacy custom filtering + methods will be _not_ be called. + +When using legacy filters it is important to use legacy +`strawberry_django.filters.FilterLookup` lookups as well. +The correct version is applied for `auto` +annotated filter field (given `lookups=True` being set). Mixing old and new lookups +might lead to error `DuplicatedTypeName: Type StrFilterLookup is defined multiple times in the schema`. + +While legacy filtering is enabled new filtering custom methods are +fully functional including default `filter` method. + +Migration process could be composed of these steps: + +- enable **USE_DEPRECATED_FILTERS** +- gradually transform custom filter field methods to new version (do not forget to use old FilterLookup if applicable) +- gradually transform default `filter` methods +- disable **USE_DEPRECATED_FILTERS** - **_This is breaking change_** diff --git a/docs/guide/ordering.md b/docs/guide/ordering.md index 24a0789a..8e7c134f 100644 --- a/docs/guide/ordering.md +++ b/docs/guide/ordering.md @@ -1,9 +1,5 @@ # Ordering -!!! note - - This API may change in the future. - ```{.python title=types.py} @strawberry_django.order(models.Color) class ColorOrder: @@ -12,15 +8,25 @@ class ColorOrder: @strawberry_django.order(models.Fruit) class FruitOrder: name: auto - color: ColorOrder + color: ColorOrder | None ``` +!!! tip + + In most cases order fields should have `Optional` annotations and default value `strawberry.UNSET`. + Above `auto` annotation is wrapped in `Optional` automatically. + `UNSET` is automatically used for fields without `field` or with `strawberry_django.order_field`. + The code above generates the following schema: ```{.graphql title=schema.graphql} enum Ordering { ASC + ASC_NULLS_FIRST + ASC_NULLS_LAST DESC + DESC_NULLS_FIRST + DESC_NULLS_LAST } input ColorOrder { @@ -33,6 +39,181 @@ input FruitOrder { } ``` +## Custom order methods + +You can define custom order method by defining your own resolver. + +```{.python title=types.py} +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + + @strawberry_django.order_field + def discovered_by(self, value: bool, prefix: str) -> list[str]: + if not value: + return [] + return [f"{prefix}discover_by__name", f"{prefix}name"] + + @strawberry_django.order_field + def order_number( + self, + info: Info, + queryset: QuerySet, + value: strawberry_django.Ordering, # `auto` can be used instead + prefix: str, + sequence: dict[str, strawberry_django.Ordering] | None + ) -> tuple[QuerySet, list[str]] | list[str]: + queryset = queryset.alias( + _ordered_num=Count(f"{prefix}orders__id") + ) + ordering = value.resolve(f"{prefix}_ordered_num") + return queryset, [ordering] +``` + +!!! warning + + Do not use `queryset.order_by()` directly. Due to `order_by` not being chainable + operation, changes applied this way would be overriden later. + +!!! tip + + `strawberry_django.Ordering` has convinient method `resolve` that can be used to + convert field's name to appropriate `F` object with correctly applied `asc()`, `desc()` method + with `nulls_first` and `nulls_last` arguments. + +The code above generates the following schema: + +```{.graphql title=schema.graphql} +enum Ordering { + ASC + ASC_NULLS_FIRST + ASC_NULLS_LAST + DESC + DESC_NULLS_FIRST + DESC_NULLS_LAST +} + +input FruitOrder { + name: Ordering + discoveredBy: bool + orderNumber: Ordering +} +``` + +#### Resolver arguments + +- `prefix` - represents the current path or position + - **Required** + - Important for nested ordering + - In code bellow custom order `name` ends up ordering `Fruit` instead of `Color` without applying `prefix` + +```{.python title="Why prefix?"} +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + color: ColorOrder | None + +@strawberry_django.order(models.Color) +class ColorOrder: + @strawberry_django.order_field + def name(self, value: bool, prefix: str): + # prefix is "fruit_set__" if unused root object is ordered instead + if value: + return ["name"] + return [] +``` + +```graphql +{ + fruits( order: {color: name: ASC} ) { ... } +} +``` + +- `value` - represents graphql field type + - **Required**, but forbidden for default `order` method + - _must_ be annotated + - used instead of field's return type + - Using `auto` is the same as `strawberry_django.Ordering`. +- `queryset` - can be used for more complex ordering + - Optional, but **Required** for default `order` method + - usually used to `annotate` `QuerySet` +- `sequence` - used to order values on the same level + - elements in graphql object are not quaranteed to keep their order as defined by user thus + this argument should be used in those cases + [GraphQL Spec](https://spec.graphql.org/October2021/#sec-Language.Arguments) + - usually for custom order field methods does not have to be used + - for advanced usage, look at `strawberry_django.process_order` function + +#### Resolver return + +For custom field methods two return values are supported + +- iterable of values acceptable by `QuerySet.order_by` -> `Collection[F | str]` +- tuple with `QuerySet` and iterable of values acceptable by `QuerySet.order_by` -> `tuple[QuerySet, Collection[F | str]]` + +For default `order` method only second variant is supported. + +### What about nulls? + +By default `null` values are ignored. This can be toggled as such `@strawberry_django.order_field(order_none=True)` + +## Overriding the default `order` method + +Works similar to field order method, but: + +- is responsible for resolution of ordering for entire object +- _must_ be named `order` +- argument `queryset` is **Required** +- argument `value` is **Forbidden** +- should probaly use `sequence` + +```{.python title=types.py} +@strawberry_django.order(models.Fruit) +class FruitOrder: + name: auto + + @strawberry_django.order_field + def ordered( + self, + info: Info, + queryset: QuerySet, + value: strawberry_django.Ordering, + prefix: str + ) -> tuple[QuerySet, list[str]] | list[str]: + queryset = queryset.alias( + _ordered_num=Count(f"{prefix}orders__id") + ) + return queryset, [value.resolve(f"{prefix}_ordered_num") ] + + @strawberry_django.order_field + def order( + self, + info: Info, + queryset: QuerySet, + prefix: str, + sequence: dict[str, strawberry_django.Ordering] | None + ) -> tuple[QuerySet, list[str]]: + queryset = queryset.filter( + ... # Do some query modification + ) + + return strawberry_django.process_order( + self, + info=info, + queryset=queryset, + sequence=sequence, + prefix=prefix, + skip_object_order_method=True + ) + +``` + +!!! tip + + As seen above `strawberry_django.process_order` function is exposed and can be + reused in custom methods. + For order method `order` `skip_object_order_method` was used to avoid endless recursion. + ## Adding orderings to types All fields and mutations inherit orderings from the underlying type by default. @@ -44,12 +225,12 @@ class Fruit: ... ``` -The `fruits` field will inherit the `filters` of the type same same way as +The `fruits` field will inherit the `order` of the type same same way as if it was passed to the field. ## Adding orderings directly into a field -Orderings added into a field override the default filters of this type. +Orderings added into a field override the default order of this type. ```{.python title=schema.py} @strawberry.type diff --git a/docs/guide/settings.md b/docs/guide/settings.md index 46db2c8b..dc0d5b8b 100644 --- a/docs/guide/settings.md +++ b/docs/guide/settings.md @@ -49,6 +49,10 @@ A dictionary with the following optional keys: instead of `strawberry.ID`. This is mostly useful if all your model types inherit from `relay.Node` and you want to work only with `GlobalID`. +- **`USE_DEPRECATED_FILTERS`** (default: `False`) + + If True, [legacy filters](filters.md#legacy-filtering) are enabled. This is usefull for migrating from previous version. + These features can be enabled by adding this code to your `settings.py` file. ```{.python title=settings.py} diff --git a/examples/django/app/types.py b/examples/django/app/types.py index 2e6d0b36..dc2c2c3b 100644 --- a/examples/django/app/types.py +++ b/examples/django/app/types.py @@ -1,9 +1,10 @@ -from typing import List +from typing import List, Optional from strawberry import auto import strawberry_django from django.contrib.auth import get_user_model +from django.db.models import Q from . import models @@ -14,14 +15,22 @@ class FruitFilter: id: auto name: auto - color: "ColorFilter" + color: Optional["ColorFilter"] + + @strawberry_django.filter_field + def special_filter(self, prefix: str, value: str): + return Q(**{f"{prefix}name": value}) @strawberry_django.filters.filter(models.Color, lookups=True) class ColorFilter: id: auto name: auto - fruits: FruitFilter + fruits: Optional[FruitFilter] + + @strawberry_django.filter_field + def filter(self, prefix, queryset): + return queryset, Q() # order @@ -30,13 +39,17 @@ class ColorFilter: @strawberry_django.ordering.order(models.Fruit) class FruitOrder: name: auto - color: "ColorOrder" + color: Optional["ColorOrder"] @strawberry_django.ordering.order(models.Color) class ColorOrder: name: auto - fruit: FruitOrder + fruits: FruitOrder + + @strawberry_django.order_field + def special_order(self, prefix: str, value: auto): + return [value.resolve(f"{prefix}fruits__name")] # types @@ -51,7 +64,7 @@ class ColorOrder: class Fruit: id: auto name: auto - color: "Color" + color: Optional["Color"] @strawberry_django.type( diff --git a/examples/django/poetry.lock b/examples/django/poetry.lock index aa41d52a..c3c40333 100644 --- a/examples/django/poetry.lock +++ b/examples/django/poetry.lock @@ -1,51 +1,66 @@ -# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "asgiref" -version = "3.5.2" +version = "3.7.2" description = "ASGI specs, helper code, and adapters" -category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "asgiref-3.5.2-py3-none-any.whl", hash = "sha256:1d2880b792ae8757289136f1db2b7b99100ce959b2aa57fd69dab783d05afac4"}, - {file = "asgiref-3.5.2.tar.gz", hash = "sha256:4a29362a6acebe09bf1d6640db38c1dc3d9217c68e6f9f6204d72667fc19a424"}, + {file = "asgiref-3.7.2-py3-none-any.whl", hash = "sha256:89b2ef2247e3b562a16eef663bc0e2e703ec6468e2fa8a5cd61cd449786d4f6e"}, + {file = "asgiref-3.7.2.tar.gz", hash = "sha256:9e0ce3aa93a819ba5b45120216b23878cf6e8525eb3848653452b4192b92afed"}, ] [package.dependencies] -typing-extensions = {version = "*", markers = "python_version < \"3.8\""} +typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] [[package]] -name = "backports.cached-property" -version = "1.0.2" -description = "cached_property() - computed once per instance, cached as attribute" -category = "main" +name = "backports-zoneinfo" +version = "0.2.1" +description = "Backport of the standard library zoneinfo module" optional = false -python-versions = ">=3.6.0" +python-versions = ">=3.6" files = [ - {file = "backports.cached-property-1.0.2.tar.gz", hash = "sha256:9306f9eed6ec55fd156ace6bc1094e2c86fae5fb2bf07b6a9c00745c656e75dd"}, - {file = "backports.cached_property-1.0.2-py3-none-any.whl", hash = "sha256:baeb28e1cd619a3c9ab8941431fe34e8490861fb998c6c4590693d50171db0cc"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, + {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, + {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, + {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, ] +[package.extras] +tzdata = ["tzdata"] + [[package]] -name = "Django" -version = "3.2.16" -description = "A high-level Python Web framework that encourages rapid development and clean, pragmatic design." -category = "main" +name = "django" +version = "4.2.10" +description = "A high-level Python web framework that encourages rapid development and clean, pragmatic design." optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "Django-3.2.16-py3-none-any.whl", hash = "sha256:18ba8efa36b69cfcd4b670d0fa187c6fe7506596f0ababe580e16909bcdec121"}, - {file = "Django-3.2.16.tar.gz", hash = "sha256:3adc285124244724a394fa9b9839cc8cd116faf7d159554c43ecdaa8cdf0b94d"}, + {file = "Django-4.2.10-py3-none-any.whl", hash = "sha256:a2d4c4d4ea0b6f0895acde632071aff6400bfc331228fc978b05452a0ff3e9f1"}, + {file = "Django-4.2.10.tar.gz", hash = "sha256:b1260ed381b10a11753c73444408e19869f3241fc45c985cd55a30177c789d13"}, ] [package.dependencies] -asgiref = ">=3.3.2,<4" -pytz = "*" -sqlparse = ">=0.2.2" +asgiref = ">=3.6.0,<4" +"backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""} +sqlparse = ">=0.3.1" +tzdata = {version = "*", markers = "sys_platform == \"win32\""} [package.extras] argon2 = ["argon2-cffi (>=19.1.0)"] @@ -55,7 +70,6 @@ bcrypt = ["bcrypt"] name = "django-debug-toolbar" version = "3.8.1" description = "A configurable set of panels that display various debug information about the current request/response." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -71,7 +85,6 @@ sqlparse = ">=0.2" name = "graphql-core" version = "3.2.3" description = "GraphQL implementation for Python, a port of GraphQL.js, the JavaScript reference implementation for GraphQL." -category = "main" optional = false python-versions = ">=3.6,<4" files = [ @@ -79,14 +92,10 @@ files = [ {file = "graphql_core-3.2.3-py3-none-any.whl", hash = "sha256:5766780452bd5ec8ba133f8bf287dc92713e3868ddd83aee4faab9fc3e303dc3"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.2,<5", markers = "python_version < \"3.8\""} - [[package]] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -97,23 +106,10 @@ files = [ [package.dependencies] six = ">=1.5" -[[package]] -name = "pytz" -version = "2022.4" -description = "World timezone definitions, modern and historical" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, - {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, -] - [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -125,7 +121,6 @@ files = [ name = "sqlparse" version = "0.4.3" description = "A non-validating SQL parser." -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -137,7 +132,6 @@ files = [ name = "strawberry-graphql" version = "0.159.0" description = "A library for creating GraphQL APIs" -category = "main" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -146,7 +140,6 @@ files = [ ] [package.dependencies] -"backports.cached-property" = {version = ">=1.0.2,<2.0.0", markers = "python_version < \"3.8\""} graphql-core = ">=3.2.0,<3.3.0" python-dateutil = ">=2.7.0,<3.0.0" typing_extensions = ">=3.7.4,<5.0.0" @@ -170,7 +163,6 @@ sanic = ["sanic (>=20.12.2)"] name = "strawberry-graphql-django" version = "0.9.2" description = "Strawberry GraphQL Django extension" -category = "main" optional = false python-versions = ">=3.7,<4.0" files = [] @@ -191,7 +183,6 @@ url = "../.." name = "typing-extensions" version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -199,7 +190,18 @@ files = [ {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] +[[package]] +name = "tzdata" +version = "2023.4" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2023.4-py2.py3-none-any.whl", hash = "sha256:aa3ace4329eeacda5b7beb7ea08ece826c28d761cda36e747cfbf97996d39bf3"}, + {file = "tzdata-2023.4.tar.gz", hash = "sha256:dd54c94f294765522c77399649b4fefd95522479a664a0cec87f41bebc6148c9"}, +] + [metadata] lock-version = "2.0" -python-versions = "^3.7" -content-hash = "95b1066c119117ce1c76b21fbd621172088621e591e7534e8fa8ef1878f34c67" +python-versions = "^3.8" +content-hash = "a6f08fff8b54f70dd16666b7f6d37308716493945af4b89c59c03e2fff49dc47" diff --git a/strawberry_django/__init__.py b/strawberry_django/__init__.py index 899d0a70..55100be4 100644 --- a/strawberry_django/__init__.py +++ b/strawberry_django/__init__.py @@ -1,5 +1,15 @@ from . import auth, filters, mutations, ordering, pagination, relay from .fields.field import connection, field, node +from .fields.filter_order import filter_field, order_field +from .fields.filter_types import ( + BaseFilterLookup, + ComparisonFilterLookup, + DateFilterLookup, + DatetimeFilterLookup, + FilterLookup, + RangeLookup, + TimeFilterLookup, +) from .fields.types import ( DjangoFileType, DjangoImageType, @@ -12,16 +22,21 @@ OneToManyInput, OneToOneInput, ) -from .filters import filter +from .filters import filter, process_filters from .mutations.mutations import input_mutation, mutation -from .ordering import order +from .ordering import Ordering, order, process_order from .resolvers import django_resolver from .type import input, interface, partial, type __all__ = [ + "BaseFilterLookup", + "ComparisonFilterLookup", + "DateFilterLookup", + "DatetimeFilterLookup", "DjangoFileType", "DjangoImageType", "DjangoModelType", + "FilterLookup", "ListInput", "ManyToManyInput", "ManyToOneInput", @@ -29,11 +44,15 @@ "NodeInputPartial", "OneToManyInput", "OneToOneInput", + "Ordering", + "RangeLookup", + "TimeFilterLookup", "auth", "connection", "django_resolver", "field", "filter", + "filter_field", "filters", "input", "input_mutation", @@ -42,9 +61,12 @@ "mutations", "node", "order", + "order_field", "ordering", "pagination", "partial", + "process_filters", + "process_order", "relay", "type", ] diff --git a/strawberry_django/exceptions.py b/strawberry_django/exceptions.py new file mode 100644 index 00000000..5ef7a7cf --- /dev/null +++ b/strawberry_django/exceptions.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING + +from strawberry.exceptions.exception import StrawberryException +from strawberry.exceptions.utils.source_finder import SourceFinder + +if TYPE_CHECKING: + from strawberry.exceptions.exception_source import ExceptionSource + + from strawberry_django.fields.filter_order import FilterOrderFieldResolver + + +class MissingFieldArgumentError(StrawberryException): + def __init__(self, field_name: str, resolver: FilterOrderFieldResolver): + self.function = resolver.wrapped_func + + self.message = ( + f'Missing required argument "{field_name}" ' f'in "{resolver.name}"' + ) + self.rich_message = ( + f'[bold red]Missing argument [underline]"{field_name}" for field ' + f"`[underline]{resolver.name}[/]`" + ) + self.annotation_message = "field missing argument" + + super().__init__(self.message) + + @cached_property + def exception_source(self) -> ExceptionSource | None: # pragma: no cover + source_finder = SourceFinder() + + return source_finder.find_function_from_object(self.function) # type: ignore + + +class ForbiddenFieldArgumentError(StrawberryException): + def __init__(self, resolver: FilterOrderFieldResolver, arguments: list[str]): + self.extra_arguments = arguments + self.function = resolver.wrapped_func + self.argument_name = arguments[0] + + self.message = ( + f'Found disallowed {self.extra_arguments_str} in field "{resolver.name}"' + ) + self.rich_message = ( + f"Found disallowed {self.extra_arguments_str} in " + f"`[underline]{resolver.name}[/]`" + ) + self.suggestion = "To fix this error, remove offending argument(s)" + + self.annotation_message = "forbidden field argument" + + super().__init__(self.message) + + @property + def extra_arguments_str(self) -> str: + arguments = self.extra_arguments + + if len(arguments) == 1: + return f'argument "{arguments[0]}"' + + head = ", ".join(arguments[:-1]) + return f'arguments "{head}" and "{arguments[-1]}"' + + @cached_property + def exception_source(self) -> ExceptionSource | None: # pragma: no cover + source_finder = SourceFinder() + + return source_finder.find_argument_from_object( + self.function, # type: ignore + self.argument_name, + ) diff --git a/strawberry_django/fields/filter_order.py b/strawberry_django/fields/filter_order.py new file mode 100644 index 00000000..ab003fe4 --- /dev/null +++ b/strawberry_django/fields/filter_order.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +import dataclasses +import inspect +from functools import cached_property +from typing import TYPE_CHECKING, Any, Final, Literal, Optional, Sequence, overload + +from strawberry import UNSET +from strawberry.annotation import StrawberryAnnotation +from strawberry.exceptions import MissingArgumentsAnnotationsError +from strawberry.field import StrawberryField +from strawberry.types.fields.resolver import ReservedName, StrawberryResolver +from typing_extensions import Self + +from strawberry_django.exceptions import ( + ForbiddenFieldArgumentError, + MissingFieldArgumentError, +) +from strawberry_django.utils.typing import is_auto + +if TYPE_CHECKING: + from collections.abc import Callable, MutableMapping + + from strawberry.extensions.field_extension import FieldExtension + from strawberry.field import _RESOLVER_TYPE, T + from strawberry.types import Info + + +QUERYSET_PARAMSPEC = ReservedName("queryset") +PREFIX_PARAMSPEC = ReservedName("prefix") +SEQUENCE_PARAMSPEC = ReservedName("sequence") +VALUE_PARAM = ReservedName("value") + +OBJECT_FILTER_NAME: Final[str] = "filter" +OBJECT_ORDER_NAME: Final[str] = "order" +WITH_NONE_META: Final[str] = "WITH_NONE_META" + + +class FilterOrderFieldResolver(StrawberryResolver): + RESERVED_PARAMSPEC = ( + *StrawberryResolver.RESERVED_PARAMSPEC, + QUERYSET_PARAMSPEC, + PREFIX_PARAMSPEC, + SEQUENCE_PARAMSPEC, + VALUE_PARAM, + ) + + def __init__(self, *args, resolver_type: Literal["filter", "order"], **kwargs): + super().__init__(*args, **kwargs) + self._resolver_type = resolver_type + + def validate_filter_arguments(self): + is_object_filter = self.name == OBJECT_FILTER_NAME + is_object_order = self.name == OBJECT_ORDER_NAME + + if not self.reserved_parameters[PREFIX_PARAMSPEC]: + raise MissingFieldArgumentError(PREFIX_PARAMSPEC.name, self) + + if (is_object_filter or is_object_order) and not self.reserved_parameters[ + QUERYSET_PARAMSPEC + ]: + raise MissingFieldArgumentError(QUERYSET_PARAMSPEC.name, self) + + if ( + self._resolver_type != OBJECT_ORDER_NAME + and self.reserved_parameters[SEQUENCE_PARAMSPEC] + ): + raise ForbiddenFieldArgumentError(self, [SEQUENCE_PARAMSPEC.name]) + + value_param = self.reserved_parameters[VALUE_PARAM] + if value_param: + if is_object_filter or is_object_order: + raise ForbiddenFieldArgumentError(self, [VALUE_PARAM.name]) + + annotation = self.strawberry_annotations[value_param] + if annotation is None: + raise MissingArgumentsAnnotationsError(self, [VALUE_PARAM.name]) + elif not is_object_filter and not is_object_order: + raise MissingFieldArgumentError(VALUE_PARAM.name, self) + + parameters = self.signature.parameters.values() + reserved_parameters = set(self.reserved_parameters.values()) + exta_params = [p for p in parameters if p not in reserved_parameters] + if exta_params: + raise ForbiddenFieldArgumentError(self, [p.name for p in exta_params]) + + @cached_property + def type_annotation(self) -> StrawberryAnnotation | None: + param = self.reserved_parameters[VALUE_PARAM] + if param and param is not inspect.Signature.empty: + annotation = param.annotation + if is_auto(annotation) and self._resolver_type == OBJECT_ORDER_NAME: + from strawberry_django import ordering + + annotation = ordering.Ordering + + return StrawberryAnnotation(Optional[annotation]) + + return None + + def __call__( # type: ignore [reportIncompatibleMethodOverride] + self, + source: Any, + info: Info | None, + queryset=None, + sequence=None, + **kwargs: Any, + ) -> Any: + args = [] + + if self.self_parameter: + args.append(source) + + if parent_parameter := self.parent_parameter: + kwargs[parent_parameter.name] = source + + if root_parameter := self.root_parameter: + kwargs[root_parameter.name] = source + + if info_parameter := self.info_parameter: + assert info is not None + kwargs[info_parameter.name] = info + + if info_parameter := self.reserved_parameters.get(QUERYSET_PARAMSPEC): + assert queryset + kwargs[info_parameter.name] = queryset + + if info_parameter := self.reserved_parameters.get(SEQUENCE_PARAMSPEC): + assert sequence is not None + kwargs[info_parameter.name] = sequence + + return super().__call__(*args, **kwargs) + + +class FilterOrderField(StrawberryField): + base_resolver: FilterOrderFieldResolver | None # type: ignore [reportIncompatibleMethodOverride] + + def __call__(self, resolver: _RESOLVER_TYPE) -> Self | FilterOrderFieldResolver: # type: ignore [reportIncompatibleMethodOverride] + if not isinstance(resolver, StrawberryResolver): + resolver = FilterOrderFieldResolver( + resolver, resolver_type=self.metadata["_FIELD_TYPE"] + ) + elif not isinstance(resolver, FilterOrderFieldResolver): + raise TypeError( + 'Expected resolver to be instance of "FilterOrderFieldResolver", ' + f'found "{type(resolver)}"' + ) + + super().__call__(resolver) + self._arguments = [] + resolver.validate_filter_arguments() + + if resolver.name in {OBJECT_FILTER_NAME, OBJECT_ORDER_NAME}: + # For object filter we return resolver + return resolver + + self.init = self.compare = self.repr = True + return self + + +@overload +def filter_field( + *, + resolver: _RESOLVER_TYPE[T], + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[False] = False, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + filter_none: bool = False, +) -> T: ... + + +@overload +def filter_field( + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[True] = True, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + filter_none: bool = False, +) -> Any: ... + + +@overload +def filter_field( + resolver: _RESOLVER_TYPE[T], + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + filter_none: bool = False, +) -> StrawberryField: ... + + +def filter_field( + resolver: _RESOLVER_TYPE[Any] | None = None, + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + filter_none: bool = False, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotates a method or property as a Django filter field. + + If using with method, these parameters are required: queryset, value, prefix + Additionaly value has to be annotated with type of filter + + This is normally used inside a type declaration: + + >>> @strawberry_django.filter(SomeModel) + >>> class X: + >>> field_abc: strawberry.auto = strawberry_django.filter_field() + + >>> @strawberry.filter_field(description="ABC") + >>> def field_with_resolver(self, queryset, info, value: str, prefix): + >>> return + + it can be used both as decorator and as a normal function. + """ + metadata = metadata or {} + metadata["_FIELD_TYPE"] = OBJECT_FILTER_NAME + if filter_none: + metadata[WITH_NONE_META] = True + + field_ = FilterOrderField( + python_name=None, + graphql_name=name, + is_subscription=is_subscription, + description=description, + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives, + extensions=extensions or [], + ) + + if resolver: + return field_(resolver) + + return field_ + + +@overload +def order_field( + *, + resolver: _RESOLVER_TYPE[T], + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[False] = False, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + order_none: bool = False, +) -> T: ... + + +@overload +def order_field( + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + init: Literal[True] = True, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + order_none: bool = False, +) -> Any: ... + + +@overload +def order_field( + resolver: _RESOLVER_TYPE[T], + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + order_none: bool = False, +) -> StrawberryField: ... + + +def order_field( + resolver: _RESOLVER_TYPE[Any] | None = None, + *, + name: str | None = None, + is_subscription: bool = False, + description: str | None = None, + deprecation_reason: str | None = None, + default: Any = UNSET, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: MutableMapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + order_none: bool = False, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotates a method or property as a Django filter field. + + If using with method, these parameters are required: queryset, value, prefix + Additionaly value has to be annotated with type of filter + + This is normally used inside a type declaration: + + >>> @strawberry_django.order(SomeModel) + >>> class X: + >>> field_abc: strawberry.auto = strawberry_django.order_field() + + >>> @strawberry.order_field(description="ABC") + >>> def field_with_resolver(self, queryset, info, value: str, prefix): + >>> return + + it can be used both as decorator and as a normal function. + """ + metadata = metadata or {} + metadata["_FIELD_TYPE"] = OBJECT_ORDER_NAME + if order_none: + metadata[WITH_NONE_META] = True + + field_ = FilterOrderField( + python_name=None, + graphql_name=name, + is_subscription=is_subscription, + description=description, + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives, + extensions=extensions or [], + ) + + if resolver: + return field_(resolver) + + return field_ diff --git a/strawberry_django/fields/filter_types.py b/strawberry_django/fields/filter_types.py new file mode 100644 index 00000000..f7d41ca4 --- /dev/null +++ b/strawberry_django/fields/filter_types.py @@ -0,0 +1,122 @@ +import datetime +import decimal +import uuid +from typing import ( + Generic, + List, + Optional, + TypeVar, +) + +import strawberry +from django.db.models import Q +from strawberry import UNSET + +from .filter_order import filter_field + +T = TypeVar("T") + +_SKIP_MSG = "Filter will be skipped on `null` value" + + +@strawberry.input +class BaseFilterLookup(Generic[T]): + exact: Optional[T] = filter_field(description=f"Exact match. {_SKIP_MSG}") + is_null: Optional[bool] = filter_field(description=f"Assignment test. {_SKIP_MSG}") + in_list: Optional[List[T]] = filter_field( + description=f"Exact match of items in a given list. {_SKIP_MSG}" + ) + + +@strawberry.input +class RangeLookup(Generic[T]): + start: Optional[T] = None + end: Optional[T] = None + + @filter_field + def filter(self, queryset, prefix: str): + return queryset, Q(**{f"{prefix}range": [self.start, self.end]}) + + +@strawberry.input +class ComparisonFilterLookup(BaseFilterLookup[T]): + gt: Optional[T] = filter_field(description=f"Greater than. {_SKIP_MSG}") + gte: Optional[T] = filter_field( + description=f"Greater than or equal to. {_SKIP_MSG}" + ) + lt: Optional[T] = filter_field(description=f"Less than. {_SKIP_MSG}") + lte: Optional[T] = filter_field(description=f"Less than or equal to. {_SKIP_MSG}") + range: Optional[RangeLookup[T]] = filter_field( + description="Inclusive range test (between)" + ) + + +@strawberry.input +class FilterLookup(BaseFilterLookup[T]): + i_exact: Optional[T] = filter_field( + description=f"Case-insensitive exact match. {_SKIP_MSG}" + ) + contains: Optional[T] = filter_field( + description=f"Case-sensitive containment test. {_SKIP_MSG}" + ) + i_contains: Optional[T] = filter_field( + description=f"Case-insensitive containment test. {_SKIP_MSG}" + ) + starts_with: Optional[T] = filter_field( + description=f"Case-sensitive starts-with. {_SKIP_MSG}" + ) + i_starts_with: Optional[T] = filter_field( + description=f"Case-insensitive starts-with. {_SKIP_MSG}" + ) + ends_with: Optional[T] = filter_field( + description=f"Case-sensitive ends-with. {_SKIP_MSG}" + ) + i_ends_with: Optional[T] = filter_field( + description=f"Case-insensitive ends-with. {_SKIP_MSG}" + ) + regex: Optional[T] = filter_field( + description=f"Case-sensitive regular expression match. {_SKIP_MSG}" + ) + i_regex: Optional[T] = filter_field( + description=f"Case-insensitive regular expression match. {_SKIP_MSG}" + ) + + +@strawberry.input +class DateFilterLookup(ComparisonFilterLookup[T]): + year: Optional[ComparisonFilterLookup[int]] = UNSET + month: Optional[ComparisonFilterLookup[int]] = UNSET + day: Optional[ComparisonFilterLookup[int]] = UNSET + week_day: Optional[ComparisonFilterLookup[int]] = UNSET + iso_week_day: Optional[ComparisonFilterLookup[int]] = UNSET + week: Optional[ComparisonFilterLookup[int]] = UNSET + iso_year: Optional[ComparisonFilterLookup[int]] = UNSET + quarter: Optional[ComparisonFilterLookup[int]] = UNSET + + +@strawberry.input +class TimeFilterLookup(ComparisonFilterLookup[T]): + hour: Optional[ComparisonFilterLookup[int]] = UNSET + minute: Optional[ComparisonFilterLookup[int]] = UNSET + second: Optional[ComparisonFilterLookup[int]] = UNSET + date: Optional[ComparisonFilterLookup[int]] = UNSET + time: Optional[ComparisonFilterLookup[int]] = UNSET + + +@strawberry.input +class DatetimeFilterLookup(DateFilterLookup[T], TimeFilterLookup[T]): + pass + + +type_filter_map = { + strawberry.ID: BaseFilterLookup, + bool: BaseFilterLookup, + datetime.date: DateFilterLookup, + datetime.datetime: DatetimeFilterLookup, + datetime.time: TimeFilterLookup, + decimal.Decimal: ComparisonFilterLookup, + float: ComparisonFilterLookup, + int: ComparisonFilterLookup, + str: FilterLookup, + uuid.UUID: FilterLookup, +} diff --git a/strawberry_django/fields/types.py b/strawberry_django/fields/types.py index 12c387d3..d77178da 100644 --- a/strawberry_django/fields/types.py +++ b/strawberry_django/fields/types.py @@ -30,6 +30,7 @@ from strawberry.utils.str_converters import capitalize_first, to_camel_case from strawberry_django import filters +from strawberry_django.fields import filter_types from strawberry_django.settings import strawberry_django_settings as django_settings try: @@ -509,12 +510,18 @@ def resolve_model_field_type( ) # TODO: could this be moved into filters.py + using_old_filters = settings["USE_DEPRECATED_FILTERS"] if ( django_type.is_filter == "lookups" and not model_field.is_relation - and field_type is not bool + and (field_type is not bool or not using_old_filters) ): - field_type = filters.FilterLookup[field_type] + if using_old_filters: + field_type = filters.FilterLookup[field_type] + else: + field_type = filter_types.type_filter_map.get( # pyright: ignore [reportInvalidTypeArguments] + field_type, filter_types.FilterLookup + )[field_type] return field_type diff --git a/strawberry_django/filters.py b/strawberry_django/filters.py index 9ec806eb..0df75139 100644 --- a/strawberry_django/filters.py +++ b/strawberry_django/filters.py @@ -1,3 +1,6 @@ +# ruff: noqa: UP007, UP006 +from __future__ import annotations + import functools import inspect from enum import Enum @@ -11,24 +14,24 @@ Optional, Sequence, Tuple, - Type, TypeVar, Union, cast, ) import strawberry -from django.db import models -from django.db.models import Q -from django.db.models.sql.query import get_field_names_from_opts # type: ignore +from django.db.models import Q, QuerySet from strawberry import UNSET, relay -from strawberry.arguments import StrawberryArgument from strawberry.field import StrawberryField, field from strawberry.type import WithStrawberryObjectDefinition, has_object_definition -from strawberry.types import Info from strawberry.unset import UnsetType from typing_extensions import Self, assert_never, dataclass_transform +from strawberry_django.fields.filter_order import ( + WITH_NONE_META, + FilterOrderField, + FilterOrderFieldResolver, +) from strawberry_django.utils.typing import ( WithStrawberryDjangoObjectDefinition, has_django_definition, @@ -36,9 +39,15 @@ from .arguments import argument from .fields.base import StrawberryDjangoFieldBase +from .settings import strawberry_django_settings if TYPE_CHECKING: - from django.db.models import QuerySet + from types import FunctionType + + from django.db.models import Model + from strawberry.arguments import StrawberryArgument + from strawberry.types import Info + T = TypeVar("T") _T = TypeVar("_T", bound=type) @@ -52,11 +61,6 @@ class DjangoModelFilterInput: pk: strawberry.ID -_n_deprecation_reason = """\ -The "n" prefix is deprecated and will be removed in the future, use `NOT` instead. -""" - - @strawberry.input class FilterLookup(Generic[T]): exact: Optional[T] = UNSET @@ -76,74 +80,6 @@ class FilterLookup(Generic[T]): is_null: Optional[bool] = UNSET regex: Optional[str] = UNSET i_regex: Optional[str] = UNSET - n_exact: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_i_exact: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_contains: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_i_contains: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_in_list: Optional[List[T]] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_gt: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_gte: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_lt: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_lte: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_starts_with: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_i_starts_with: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_ends_with: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_i_ends_with: Optional[T] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_range: Optional[List[T]] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_is_null: Optional[bool] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_regex: Optional[str] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) - n_i_regex: Optional[str] = strawberry.field( - default=UNSET, - deprecation_reason=_n_deprecation_reason, - ) lookup_name_conversion_map = { @@ -159,123 +95,133 @@ class FilterLookup(Generic[T]): } -def _resolve_global_id(value: Any): +def _resolve_value(value: Any) -> Any: if isinstance(value, list): - return [_resolve_global_id(v) for v in value] + return [_resolve_value(v) for v in value] + if isinstance(value, relay.GlobalID): return value.node_id + if isinstance(value, Enum): + return value.value + return value -def build_filter_kwargs( - filters: WithStrawberryObjectDefinition, - path="", -) -> Tuple[Q, List[Callable]]: - filter_kwargs = Q() - filter_methods = [] - django_model = ( - filters.__strawberry_django_definition__.model - if has_django_definition(filters) - else None +@functools.lru_cache(maxsize=256) +def _function_allow_passing_info(filter_method: FunctionType) -> bool: + argspec = inspect.getfullargspec(filter_method) + + return "info" in getattr(argspec, "args", []) or "info" in getattr( + argspec, + "kwargs", + [], ) - # This loop relies on the filter field order: AND, OR, and NOT fields are expected to be last. Since this is not - # true in case of filter inheritance, we need to explicitely sort them. - for f in sorted( - filters.__strawberry_definition__.fields, - key=lambda f: f.name in {"AND", "OR", "NOT"}, + +def _process_deprecated_filter( + filter_method: FunctionType, info: Info | None, queryset: _QS +) -> _QS: + kwargs = {} + if _function_allow_passing_info( + # Pass the original __func__ which is always the same + getattr(filter_method, "__func__", filter_method), ): - field_name = f.name - field_value = _resolve_global_id(getattr(filters, field_name)) + kwargs["info"] = info - # Unset means we are not filtering this. None is still acceptable - if field_value is UNSET: - continue + return filter_method(queryset=queryset, **kwargs) - if isinstance(field_value, Enum): - field_value = field_value.value - elif ( - isinstance(field_value, list) - and len(field_value) > 0 - and isinstance(field_value[0], Enum) + +def process_filters( + filters: WithStrawberryObjectDefinition, + queryset: _QS, + info: Info | None, + prefix: str = "", + skip_object_filter_method: bool = False, +) -> Tuple[_QS, Q]: + using_old_filters = strawberry_django_settings()["USE_DEPRECATED_FILTERS"] + + q = Q() + + if not skip_object_filter_method and ( + filter_method := getattr(filters, "filter", None) + ): + # Dedicated function for object + if isinstance(filter_method, FilterOrderFieldResolver): + return filter_method(filters, info, queryset=queryset, prefix=prefix) + if using_old_filters: + return _process_deprecated_filter(filter_method, info, queryset), q + + # This loop relies on the filter field order that is not quaranteed for GQL input objects: + # "filter" has to be first since it overrides filtering for entire object + # DISTINCT has to be last and OR has to be after because it must be + # applied agains all other since default connector is AND + for f in sorted( + filters.__strawberry_definition__.fields, + key=lambda x: len(x.name) if x.name in {"OR", "DISTINCT"} else 0, + ): + field_value = _resolve_value(getattr(filters, f.name)) + # None is still acceptable for v1 (backwards compatibility) and filters that support it via metadata + if field_value is UNSET or ( + field_value is None + and not f.metadata.get(WITH_NONE_META, using_old_filters) ): - field_value = [el.value for el in field_value] - - negated = False - if field_name.startswith("n_"): - field_name = field_name[2:] - negated = True - - field_name = lookup_name_conversion_map.get(field_name, field_name) - filter_method = getattr( - filters, - f"filter_{'n_' if negated else ''}{field_name}", - None, - ) - if filter_method: - filter_methods.append(filter_method) continue - if django_model: - if field_name in ("AND", "OR", "NOT"): # noqa: PLR6201 - if has_object_definition(field_value): - ( - subfield_filter_kwargs, - subfield_filter_methods, - ) = build_filter_kwargs( - cast(WithStrawberryObjectDefinition, field_value), - path, - ) - if field_name == "AND": - filter_kwargs &= subfield_filter_kwargs - elif field_name == "OR": - filter_kwargs |= subfield_filter_kwargs - elif field_name == "NOT": - filter_kwargs &= ~subfield_filter_kwargs - else: - assert_never(field_name) - - filter_methods.extend(subfield_filter_methods) - continue - - if field_name not in get_field_names_from_opts( - django_model._meta, - ): - continue + field_name = lookup_name_conversion_map.get(f.name, f.name) + if field_name == "DISTINCT": + if field_value: + queryset = queryset.distinct() + elif field_name in ("AND", "OR", "NOT"): # noqa: PLR6201 + assert has_object_definition(field_value) - if has_object_definition(field_value): - subfield_filter_kwargs, subfield_filter_methods = build_filter_kwargs( + queryset, sub_q = process_filters( + cast(WithStrawberryObjectDefinition, field_value), + queryset, + info, + prefix, + ) + if field_name == "AND": + q &= sub_q + elif field_name == "OR": + q |= sub_q + elif field_name == "NOT": + q &= ~sub_q + else: + assert_never(field_name) + elif isinstance(f, FilterOrderField) and f.base_resolver: + res = f.base_resolver( + filters, info, value=field_value, queryset=queryset, prefix=prefix + ) + if isinstance(res, tuple): + queryset, sub_q = res + else: + sub_q = res + + q &= sub_q + elif using_old_filters and ( + filter_method := getattr(filters, f"filter_{field_name}", None) + ): + queryset = _process_deprecated_filter(filter_method, info, queryset) + elif has_object_definition(field_value): + queryset, sub_q = process_filters( cast(WithStrawberryObjectDefinition, field_value), - f"{path}{field_name}__", + queryset, + info, + f"{prefix}{field_name}__", ) - filter_kwargs &= subfield_filter_kwargs - filter_methods.extend(subfield_filter_methods) + q &= sub_q else: - filter_kwarg = Q(**{f"{path}{field_name}": field_value}) - if negated: - filter_kwarg = ~filter_kwarg - filter_kwargs &= filter_kwarg - - return filter_kwargs, filter_methods - - -@functools.lru_cache(maxsize=256) -def function_allow_passing_info(filter_method: FunctionType) -> bool: - argspec = inspect.getfullargspec(filter_method) + q &= Q(**{f"{prefix}{field_name}": field_value}) - return "info" in getattr(argspec, "args", []) or "info" in getattr( - argspec, - "kwargs", - [], - ) + return queryset, q def apply( - filters: Optional[object], + filters: object | None, queryset: _QS, - info: Optional[Info] = None, - pk: Optional[Any] = None, + info: Info | None = None, + pk: Any | None = None, ) -> _QS: if pk not in (None, strawberry.UNSET): # noqa: PLR6201 queryset = queryset.filter(pk=pk) @@ -283,32 +229,11 @@ def apply( if filters in (None, strawberry.UNSET) or not has_django_definition(filters): # noqa: PLR6201 return queryset - # Custom filter function in the filters object - filter_method = getattr(filters, "filter", None) - if filter_method: - kwargs = {} - if function_allow_passing_info( - # Pass the original __func__ which is always the same - getattr(filter_method, "__func__", filter_method), - ): - kwargs["info"] = info - - return filter_method(queryset=queryset, **kwargs) - - filter_kwargs, filter_methods = build_filter_kwargs( - cast(WithStrawberryObjectDefinition, filters) + queryset, q = process_filters( + cast(WithStrawberryObjectDefinition, filters), queryset, info ) - queryset = queryset.filter(filter_kwargs) - for filter_method in filter_methods: - kwargs = {} - if function_allow_passing_info( - # Pass the original __func__ which is always the same - getattr(filter_method, "__func__", filter_method), - ): - kwargs["info"] = info - - queryset = filter_method(queryset=queryset, **kwargs) - + if q: + queryset = queryset.filter(q) return queryset @@ -358,11 +283,11 @@ def arguments(self) -> List[StrawberryArgument]: return super().arguments + arguments @arguments.setter - def arguments(self, value: List[StrawberryArgument]): + def arguments(self, value: list[StrawberryArgument]): args_prop = super(StrawberryDjangoFieldFilters, self.__class__).arguments return args_prop.fset(self, value) # type: ignore - def get_filters(self) -> Optional[Type[WithStrawberryObjectDefinition]]: + def get_filters(self) -> type[WithStrawberryObjectDefinition] | None: filters = self.filters if filters is None: return None @@ -377,15 +302,6 @@ def get_filters(self) -> Optional[Type[WithStrawberryObjectDefinition]]: return filters if filters is not UNSET else None - def apply_filters( - self, - queryset: _QS, - filters: Optional[WithStrawberryDjangoObjectDefinition], - pk: Optional[Any], - info: Info, - ) -> _QS: - return apply(filters, queryset, info, pk) - def get_queryset( self, queryset: _QS, @@ -396,7 +312,7 @@ def get_queryset( **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) - return self.apply_filters(queryset, filters, pk, info) + return apply(filters, queryset, info, pk) @dataclass_transform( @@ -407,11 +323,11 @@ def get_queryset( ), ) def filter( # noqa: A001 - model: Type[models.Model], + model: type[Model], *, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), lookups: bool = False, ) -> Callable[[_T], _T]: from .type import input diff --git a/strawberry_django/ordering.py b/strawberry_django/ordering.py index 7ae3fe15..9d69517d 100644 --- a/strawberry_django/ordering.py +++ b/strawberry_django/ordering.py @@ -1,99 +1,170 @@ +from __future__ import annotations + import dataclasses import enum from typing import ( TYPE_CHECKING, Callable, - Dict, - List, + Collection, Optional, Sequence, - Type, TypeVar, - Union, + cast, ) import strawberry -from django.db import models +from django.db.models import F, OrderBy, QuerySet from graphql.language.ast import ObjectValueNode from strawberry import UNSET -from strawberry.arguments import StrawberryArgument from strawberry.field import StrawberryField, field from strawberry.type import WithStrawberryObjectDefinition, has_object_definition -from strawberry.types import Info from strawberry.unset import UnsetType from strawberry.utils.str_converters import to_camel_case from typing_extensions import Self, dataclass_transform from strawberry_django.fields.base import StrawberryDjangoFieldBase +from strawberry_django.fields.filter_order import ( + WITH_NONE_META, + FilterOrderField, + FilterOrderFieldResolver, +) from strawberry_django.utils.typing import is_auto from .arguments import argument if TYPE_CHECKING: - from django.db.models import QuerySet - + from django.db.models import Model + from strawberry.arguments import StrawberryArgument + from strawberry.types import Info _T = TypeVar("_T") _QS = TypeVar("_QS", bound="QuerySet") +_SFT = TypeVar("_SFT", bound=StrawberryField) ORDER_ARG = "order" @dataclasses.dataclass -class _OrderSequence: +class OrderSequence: seq: int = 0 - children: Optional[Dict[str, "_OrderSequence"]] = None + children: dict[str, OrderSequence] | None = None + + @classmethod + def get_graphql_name(cls, info: Info | None, field: StrawberryField) -> str: + if info is None: + if field.graphql_name: + return field.graphql_name + + return to_camel_case(field.python_name) + + return info.schema.config.name_converter.get_graphql_name(field) + + @classmethod + def sorted( + cls, + info: Info | None, + sequence: dict[str, OrderSequence] | None, + fields: list[_SFT], + ) -> list[_SFT]: + if info is None: + return fields + + sequence = sequence or {} + + def sort_key(f: _SFT) -> int: + if not (seq := sequence.get(cls.get_graphql_name(info, f))): + return 0 + return seq.seq + + return sorted(fields, key=sort_key) @strawberry.enum class Ordering(enum.Enum): ASC = "ASC" + ASC_NULLS_FIRST = "ASC_NULLS_FIRST" + ASC_NULLS_LAST = "ASC_NULLS_LAST" DESC = "DESC" + DESC_NULLS_FIRST = "DESC_NULLS_FIRST" + DESC_NULLS_LAST = "DESC_NULLS_LAST" + + def resolve(self, value: str) -> OrderBy: + nulls_first = True if "NULLS_FIRST" in self.name else None + nulls_last = True if "NULLS_LAST" in self.name else None + if "ASC" in self.name: + return F(value).asc(nulls_first=nulls_first, nulls_last=nulls_last) + return F(value).desc(nulls_first=nulls_first, nulls_last=nulls_last) -def generate_order_args( +def process_order( order: WithStrawberryObjectDefinition, + info: Info | None, + queryset: _QS, *, - sequence: Optional[Dict[str, _OrderSequence]] = None, + sequence: dict[str, OrderSequence] | None = None, prefix: str = "", -): + skip_object_order_method: bool = False, +) -> tuple[_QS, Collection[F | OrderBy | str]]: sequence = sequence or {} args = [] - def sort_key(f: StrawberryField) -> int: - if not (seq := sequence.get(to_camel_case(f.name))): - return 0 - return seq.seq + if not skip_object_order_method and (order_method := getattr(order, "order", None)): + assert isinstance(order_method, FilterOrderFieldResolver) + return order_method( + order, info, queryset=queryset, prefix=prefix, sequence=sequence + ) - for f in sorted(order.__strawberry_definition__.fields, key=sort_key): - ordering = getattr(order, f.name, UNSET) - if ordering is UNSET: + for f in OrderSequence.sorted( + info, sequence, order.__strawberry_definition__.fields + ): + f_value = getattr(order, f.name, UNSET) + if f_value is UNSET or (f_value is None and not f.metadata.get(WITH_NONE_META)): continue - if ordering == Ordering.ASC: - args.append(f"{prefix}{f.name}") - elif ordering == Ordering.DESC: - args.append(f"-{prefix}{f.name}") + if isinstance(f, FilterOrderField) and f.base_resolver: + res = f.base_resolver( + order, + info, + value=f_value, + queryset=queryset, + prefix=prefix, + sequence=( + (seq := sequence.get(OrderSequence.get_graphql_name(info, f))) + and seq.children + ), + ) + if isinstance(res, tuple): + queryset, subargs = res + else: + subargs = res + args.extend(subargs) + elif isinstance(f_value, Ordering): + args.append(f_value.resolve(f"{prefix}{f.name}")) else: - subargs = generate_order_args( - ordering, + queryset, subargs = process_order( + f_value, + info, + queryset, prefix=f"{prefix}{f.name}__", - sequence=(seq := sequence.get(to_camel_case(f.name))) and seq.children, + sequence=( + (seq := sequence.get(OrderSequence.get_graphql_name(info, f))) + and seq.children + ), ) args.extend(subargs) - return args + return queryset, args def apply( - order: Optional[WithStrawberryObjectDefinition], + order: object | None, queryset: _QS, - info: Optional[Info] = None, + info: Info | None = None, ) -> _QS: - if order in (None, strawberry.UNSET): # noqa: PLR6201 + if order in (None, strawberry.UNSET) or not has_object_definition(order): # noqa: PLR6201 return queryset - sequence: Dict[str, _OrderSequence] = {} + sequence: dict[str, OrderSequence] = {} if info is not None and info._raw_info.field_nodes: field_node = info._raw_info.field_nodes[0] for arg in field_node.arguments: @@ -102,24 +173,26 @@ def apply( ): continue - def parse_and_fill(field: ObjectValueNode, seq: Dict[str, _OrderSequence]): + def parse_and_fill(field: ObjectValueNode, seq: dict[str, OrderSequence]): for i, f in enumerate(field.fields): - f_sequence: Dict[str, _OrderSequence] = {} + f_sequence: dict[str, OrderSequence] = {} if isinstance(f.value, ObjectValueNode): parse_and_fill(f.value, f_sequence) - seq[f.name.value] = _OrderSequence(seq=i, children=f_sequence) + seq[f.name.value] = OrderSequence(seq=i, children=f_sequence) parse_and_fill(arg.value, sequence) - args = generate_order_args(order, sequence=sequence) + queryset, args = process_order( + cast(WithStrawberryObjectDefinition, order), info, queryset, sequence=sequence + ) if not args: return queryset return queryset.order_by(*args) class StrawberryDjangoFieldOrdering(StrawberryDjangoFieldBase): - def __init__(self, order: Union[type, UnsetType, None] = UNSET, **kwargs): + def __init__(self, order: type | UnsetType | None = UNSET, **kwargs): if order and not has_object_definition(order): raise TypeError("order needs to be a strawberry type") @@ -132,7 +205,7 @@ def __copy__(self) -> Self: return new_field @property - def arguments(self) -> List[StrawberryArgument]: + def arguments(self) -> list[StrawberryArgument]: arguments = [] if self.base_resolver is None and self.is_list: order = self.get_order() @@ -141,11 +214,11 @@ def arguments(self) -> List[StrawberryArgument]: return super().arguments + arguments @arguments.setter - def arguments(self, value: List[StrawberryArgument]): + def arguments(self, value: list[StrawberryArgument]): args_prop = super(StrawberryDjangoFieldOrdering, self.__class__).arguments return args_prop.fset(self, value) # type: ignore - def get_order(self) -> Optional[Type[WithStrawberryObjectDefinition]]: + def get_order(self) -> type[WithStrawberryObjectDefinition] | None: order = self.order if order is None: return None @@ -160,24 +233,16 @@ def get_order(self) -> Optional[Type[WithStrawberryObjectDefinition]]: return order if order is not UNSET else None - def apply_order( - self, - queryset: _QS, - order: Optional[WithStrawberryObjectDefinition] = None, - info: Optional[Info] = None, - ) -> _QS: - return apply(order, queryset, info=info) - def get_queryset( self, queryset: _QS, info: Info, *, - order: Optional[WithStrawberryObjectDefinition] = None, + order: WithStrawberryObjectDefinition | None = None, **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) - return self.apply_order(queryset, order, info=info) + return apply(order, queryset, info=info) @dataclass_transform( @@ -188,19 +253,28 @@ def get_queryset( ), ) def order( - model: Type[models.Model], + model: type[Model], *, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ) -> Callable[[_T], _T]: def wrapper(cls): + try: + cls.__annotations__ # noqa: B018 + except AttributeError: + # Manual creation for python 3.8 / 3.9 + cls.__annotations__ = {} + for fname, type_ in cls.__annotations__.items(): if is_auto(type_): type_ = Ordering # noqa: PLW2901 cls.__annotations__[fname] = Optional[type_] - setattr(cls, fname, UNSET) + + field_ = cls.__dict__.get(fname) + if not isinstance(field_, StrawberryField): + setattr(cls, fname, UNSET) return strawberry.input( cls, diff --git a/strawberry_django/settings.py b/strawberry_django/settings.py index 7fe6344f..4f5dacf2 100644 --- a/strawberry_django/settings.py +++ b/strawberry_django/settings.py @@ -36,6 +36,9 @@ class StrawberryDjangoSettings(TypedDict): #: `relay.GlobalID` instead of `strawberry.ID` for types and filters. MAP_AUTO_ID_AS_GLOBAL_ID: bool + #: If True, deprecated way of using filters will be working + USE_DEPRECATED_FILTERS: bool + DEFAULT_DJANGO_SETTINGS = StrawberryDjangoSettings( FIELD_DESCRIPTION_FROM_HELP_TEXT=False, @@ -44,6 +47,7 @@ class StrawberryDjangoSettings(TypedDict): MUTATIONS_DEFAULT_ARGUMENT_NAME="data", MUTATIONS_DEFAULT_HANDLE_ERRORS=False, MAP_AUTO_ID_AS_GLOBAL_ID=False, + USE_DEPRECATED_FILTERS=False, ) diff --git a/strawberry_django/type.py b/strawberry_django/type.py index e5a43124..251e43db 100644 --- a/strawberry_django/type.py +++ b/strawberry_django/type.py @@ -127,6 +127,7 @@ def _process_type( "AND": Optional[Self], # type: ignore "OR": Optional[Self], # type: ignore "NOT": Optional[Self], # type: ignore + "DISTINCT": Optional[bool], }, ) diff --git a/tests/exceptions.py b/tests/exceptions.py new file mode 100644 index 00000000..1181e3a3 --- /dev/null +++ b/tests/exceptions.py @@ -0,0 +1,18 @@ +from strawberry_django.exceptions import ( + ForbiddenFieldArgumentError, +) +from strawberry_django.fields.filter_order import FilterOrderFieldResolver + + +def test_forbidden_field_argument_extra_one(): + resolver = FilterOrderFieldResolver(resolver_type="filter", func=lambda x: x) + + exc = ForbiddenFieldArgumentError(resolver, ["one"]) + assert exc.extra_arguments_str == 'argument "one"' + + +def test_forbidden_field_argument_extra_many(): + resolver = FilterOrderFieldResolver(resolver_type="filter", func=lambda x: x) + + exc = ForbiddenFieldArgumentError(resolver, ["extra", "forbidden", "fields"]) + assert exc.extra_arguments_str == 'arguments "extra, forbidden" and "fields"' diff --git a/tests/fields/test_attributes.py b/tests/fields/test_attributes.py index b921371c..abfb8be1 100644 --- a/tests/fields/test_attributes.py +++ b/tests/fields/test_attributes.py @@ -96,6 +96,7 @@ class Query: AND: MyTypeFilter OR: MyTypeFilter NOT: MyTypeFilter + DISTINCT: Boolean } type Query { @@ -140,6 +141,7 @@ class Query: AND: MyTypeFilter OR: MyTypeFilter NOT: MyTypeFilter + DISTINCT: Boolean }} """An object with a Globally Unique ID""" @@ -196,6 +198,7 @@ class Query: AND: MyTypeFilter OR: MyTypeFilter NOT: MyTypeFilter + DISTINCT: Boolean }} """An object with a Globally Unique ID""" diff --git a/tests/filters/test_filters.py b/tests/filters/test_filters.py index 694ee378..24ab9690 100644 --- a/tests/filters/test_filters.py +++ b/tests/filters/test_filters.py @@ -4,6 +4,7 @@ import pytest import strawberry +from django.test import override_settings from strawberry import auto from strawberry.annotation import StrawberryAnnotation from strawberry.types import ExecutionResult @@ -11,105 +12,101 @@ import strawberry_django from tests import models, utils +with override_settings(STRAWBERRY_DJANGO={"USE_DEPRECATED_FILTERS": True}): -@strawberry_django.filter(models.NameDescriptionMixin) -class NameDescriptionFilter: - name: auto - description: auto - - -@strawberry_django.filter(models.Vegetable, lookups=True) -class VegetableFilter(NameDescriptionFilter): - id: auto - world_production: auto - - -@strawberry_django.filters.filter(models.Color, lookups=True) -class ColorFilter: - id: auto - name: auto - - -@strawberry_django.filters.filter(models.Fruit, lookups=True) -class FruitFilter: - id: auto - name: auto - color: Optional[ColorFilter] - - -@strawberry.enum -class FruitEnum(Enum): - strawberry = "strawberry" - banana = "banana" - - -@strawberry_django.filters.filter(models.Fruit) -class EnumFilter: - name: Optional[FruitEnum] = strawberry.UNSET - + @strawberry_django.filter(models.NameDescriptionMixin) + class NameDescriptionFilter: + name: auto + description: auto -_T = TypeVar("_T") + @strawberry_django.filter(models.Vegetable, lookups=True) + class VegetableFilter(NameDescriptionFilter): + id: auto + world_production: auto + @strawberry_django.filters.filter(models.Color, lookups=True) + class ColorFilter: + id: auto + name: auto -@strawberry.input -class FilterInLookup(Generic[_T]): - exact: Optional[_T] = strawberry.UNSET - in_list: Optional[List[_T]] = strawberry.UNSET + @strawberry_django.filters.filter(models.Fruit, lookups=True) + class FruitFilter: + id: auto + name: auto + color: Optional[ColorFilter] + @strawberry.enum + class FruitEnum(Enum): + strawberry = "strawberry" + banana = "banana" -@strawberry_django.filters.filter(models.Fruit) -class EnumLookupFilter: - name: Optional[FilterInLookup[FruitEnum]] = strawberry.UNSET + @strawberry_django.filters.filter(models.Fruit) + class EnumFilter: + name: Optional[FruitEnum] = strawberry.UNSET + _T = TypeVar("_T") -@strawberry.input -class NonFilter: - name: FruitEnum + @strawberry.input + class FilterInLookup(Generic[_T]): + exact: Optional[_T] = strawberry.UNSET + in_list: Optional[List[_T]] = strawberry.UNSET - def filter(self, queryset): - raise NotImplementedError + @strawberry_django.filters.filter(models.Fruit) + class EnumLookupFilter: + name: Optional[FilterInLookup[FruitEnum]] = strawberry.UNSET + @strawberry.input + class NonFilter: + name: FruitEnum -@strawberry_django.filters.filter(models.Fruit) -class FieldFilter: - search: str + def filter(self, queryset): + raise NotImplementedError - def filter_search(self, queryset): - return queryset.filter(name__icontains=self.search) + @strawberry_django.filters.filter(models.Fruit) + class FieldFilter: + search: str + def filter_search(self, queryset): + return queryset.filter(name__icontains=self.search) -@strawberry_django.filters.filter(models.Fruit) -class TypeFilter: - name: auto + @strawberry_django.filters.filter(models.Fruit) + class TypeFilter: + name: auto - def filter(self, queryset): - if not self.name: - return queryset + def filter(self, queryset): + if not self.name: + return queryset - return queryset.filter(name__icontains=self.name) + return queryset.filter(name__icontains=self.name) + @strawberry_django.type(models.Vegetable, filters=VegetableFilter) + class Vegetable: + id: auto + name: auto + description: auto + world_production: auto -@strawberry_django.type(models.Vegetable, filters=VegetableFilter) -class Vegetable: - id: auto - name: auto - description: auto - world_production: auto + @strawberry_django.type(models.Fruit, filters=FruitFilter) + class Fruit: + id: auto + name: auto + @strawberry.type + class Query: + fruits: List[Fruit] = strawberry_django.field() + field_filter: List[Fruit] = strawberry_django.field(filters=FieldFilter) + type_filter: List[Fruit] = strawberry_django.field(filters=TypeFilter) + enum_filter: List[Fruit] = strawberry_django.field(filters=EnumFilter) + enum_lookup_filter: List[Fruit] = strawberry_django.field( + filters=EnumLookupFilter + ) -@strawberry_django.type(models.Fruit, filters=FruitFilter) -class Fruit: - id: auto - name: auto + _ = strawberry.Schema(query=Query) -@strawberry.type -class Query: - fruits: List[Fruit] = strawberry_django.field() - field_filter: List[Fruit] = strawberry_django.field(filters=FieldFilter) - type_filter: List[Fruit] = strawberry_django.field(filters=TypeFilter) - enum_filter: List[Fruit] = strawberry_django.field(filters=EnumFilter) - enum_lookup_filter: List[Fruit] = strawberry_django.field(filters=EnumLookupFilter) +@pytest.fixture(autouse=True) +def _autouse_old_filters(settings): + settings.STRAWBERRY_DJANGO = {"USE_DEPRECATED_FILTERS": True} @pytest.fixture() @@ -164,30 +161,18 @@ def test_in_list(query, fruits): ] -def test_deprecated_not(query, fruits): - result = query( - """{ fruits(filters: { - name: { nEndsWith: "berry" } - }) { id name } }""", - ) - assert not result.errors - assert result.data["fruits"] == [ - {"id": "3", "name": "banana"}, - ] - - def test_not(query, fruits): result = query("""{ - fruits( + fruits( filters: { - NOT: { + NOT: { name: { endsWith: "berry" } - } } - ) { + } + ) { id name - } + } }""") assert not result.errors assert result.data["fruits"] == [ @@ -283,24 +268,24 @@ def vegetables(self, filters: VegetableFilter) -> List[Vegetable]: query = utils.generate_query(Query) result = query(""" - { + { vegetables( - filters: { + filters: { worldProduction: { - gt: 100e6 + gt: 100e6 } OR: { - name: { + name: { exact: "cucumber" - } } - } + } + } ) { - id - name + id + name } - } + } """) assert isinstance(result, ExecutionResult) assert not result.errors diff --git a/tests/filters/test_filters_v2.py b/tests/filters/test_filters_v2.py new file mode 100644 index 00000000..3ac9f745 --- /dev/null +++ b/tests/filters/test_filters_v2.py @@ -0,0 +1,382 @@ +# ruff: noqa: TRY002, B904, BLE001, F811, PT012, A001, PLC2701 +from enum import Enum +from typing import Any, List, Optional, cast + +import pytest +import strawberry +from django.db.models import Case, Count, Q, QuerySet, Value, When +from strawberry import auto +from strawberry.exceptions import MissingArgumentsAnnotationsError +from strawberry.relay import GlobalID +from strawberry.type import WithStrawberryObjectDefinition, get_object_definition + +import strawberry_django +from strawberry_django.exceptions import ( + ForbiddenFieldArgumentError, + MissingFieldArgumentError, +) +from strawberry_django.fields import filter_types +from strawberry_django.fields.field import StrawberryDjangoField +from strawberry_django.fields.filter_order import ( + FilterOrderField, + FilterOrderFieldResolver, +) +from strawberry_django.filters import _resolve_value, process_filters +from tests import models, utils +from tests.types import Fruit, FruitType + + +@strawberry.enum +class Version(Enum): + ONE = "first" + TWO = "second" + THREE = "third" + + +@strawberry_django.filter(models.Color, lookups=True) +class ColorFilter: + id: auto + name: auto + + @strawberry_django.filter_field + def name_simple(self, prefix: str, value: str): + return Q(**{f"{prefix}name": value}) + + +@strawberry_django.filter(models.FruitType, lookups=True) +class FruitTypeFilter: + name: auto + fruits: Optional["FruitFilter"] + + +@strawberry_django.filter(models.Fruit, lookups=True) +class FruitFilter: + color_id: auto + name: auto + sweetness: auto + color: Optional[ColorFilter] + types: Optional[FruitTypeFilter] + + @strawberry_django.filter_field + def types_number( + self, + info, + queryset: QuerySet, + prefix, + value: filter_types.ComparisonFilterLookup[int], + ): + return process_filters( + cast(WithStrawberryObjectDefinition, value), + queryset.annotate( + count=Count(f"{prefix}types__id"), + count_nulls=Case( + When(count=0, then=Value(None)), + default="count", + ), + ), + info, + "count_nulls__", + ) + + @strawberry_django.filter_field + def double( + self, + queryset: QuerySet, + prefix, + value: bool, + ): + return queryset.union(queryset, all=True), Q() + + @strawberry_django.filter_field + def filter(self, info, queryset: QuerySet, prefix): + return process_filters( + cast(WithStrawberryObjectDefinition, self), + queryset.filter(~Q(**{f"{prefix}name": "DARK_BERRY"})), + info, + prefix, + skip_object_filter_method=True, + ) + + +@strawberry.type +class Query: + types: List[FruitType] = strawberry_django.field(filters=FruitTypeFilter) + fruits: List[Fruit] = strawberry_django.field(filters=FruitFilter) + + +@pytest.fixture() +def query(): + return utils.generate_query(Query) + + +@pytest.mark.parametrize( + ("value", "resolved"), + [ + (2, 2), + ("something", "something"), + (GlobalID("", "24"), "24"), + (Version.ONE, Version.ONE.value), + ( + [1, "str", GlobalID("", "24"), Version.THREE], + [1, "str", "24", Version.THREE.value], + ), + ], +) +def test_resolve_value(value, resolved): + assert _resolve_value(value) == resolved + + +def test_filter_field_missing_prefix(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"prefix\".*\"field_method\".*" + ): + + @strawberry_django.filter_field + def field_method(): + pass + + +def test_filter_field_missing_value(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"value\".*\"field_method\".*" + ): + + @strawberry_django.filter_field + def field_method(prefix): + pass + + +def test_filter_field_missing_value_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r"Missing annotation.*\"value\".*\"field_method\".*", + ): + + @strawberry_django.filter_field + def field_method(prefix, value): + pass + + +def test_filter_field(): + try: + + @strawberry_django.filter_field + def field_method(self, root, info, prefix, value: str, queryset): + pass + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") # type: ignore + + +def test_filter_field_sequence(): + with pytest.raises( + ForbiddenFieldArgumentError, + match=r".*\"sequence\".*\"field_method\".*", + ): + + @strawberry_django.filter_field + def field_method(prefix, value: auto, sequence, queryset): + pass + + +def test_filter_field_forbidden_param_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.filter_field + def field_method(prefix, value: auto, queryset, forbidden_param): + pass + + +def test_filter_field_forbidden_param(): + with pytest.raises( + ForbiddenFieldArgumentError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.filter_field + def field_method(prefix, value: auto, queryset, forbidden_param: str): + pass + + +def test_filter_field_missing_queryset(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"queryset\".*\"filter\".*" + ): + + @strawberry_django.filter_field + def filter(prefix): + pass + + +def test_filter_field_value_forbidden_on_object(): + with pytest.raises(ForbiddenFieldArgumentError, match=r".*\"value\".*\"filter\".*"): + + @strawberry_django.filter_field + def field_method(prefix, queryset, value: auto): + pass + + @strawberry_django.filter_field + def filter(prefix, queryset, value: auto): + pass + + +def test_filter_field_on_object(): + try: + + @strawberry_django.filter_field + def filter(self, root, info, prefix, queryset): + pass + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") # type: ignore + + +def test_filter_field_method(): + @strawberry_django.filter(models.Fruit) + class Filter: + @strawberry_django.order_field + def custom_filter(self, root, info, prefix, value: auto, queryset): + assert self == _filter, "Unexpected self passed" + assert root == _filter, "Unexpected root passed" + assert info == _info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert value == "SOMETHING", "Unexpected value passed" + assert queryset == _queryset, "Unexpected queryset passed" + return Q(name=1) + + _filter: Any = Filter(custom_filter="SOMETHING") # type: ignore + _info: Any = object() + _queryset: Any = object() + + q_object = process_filters(_filter, _queryset, _info, prefix="ROOT")[1] + assert q_object, "Filter was not called" + + +def test_filter_object_method(): + @strawberry_django.ordering.order(models.Fruit) + class Filter: + @strawberry_django.order_field + def field_filter(self, value: str, prefix): + raise AssertionError("Never called due to object filter override") + + @strawberry_django.order_field + def filter(self, root, info, prefix, queryset): + assert self == _filter, "Unexpected self passed" + assert root == _filter, "Unexpected root passed" + assert info == _info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert queryset == _queryset, "Unexpected queryset passed" + return queryset, Q(name=1) + + _filter: Any = Filter() + _info: Any = object() + _queryset: Any = object() + + q_object = process_filters(_filter, _queryset, _info, prefix="ROOT")[1] + assert q_object, "Filter was not called" + + +def test_filter_type(): + @strawberry_django.filter(models.Fruit, lookups=True) + class FruitOrder: + id: auto + name: auto + sweetness: auto + + @strawberry_django.filter_field + def custom_filter(self, value: str, prefix: str): + pass + + @strawberry_django.filter_field + def custom_filter2( + self, value: filter_types.BaseFilterLookup[str], prefix: str + ): + pass + + assert [ + ( + f.name, + f.__class__, + f.type.of_type.__name__, # type: ignore + f.base_resolver.__class__ if f.base_resolver else None, # type: ignore + ) + for f in get_object_definition(FruitOrder, strict=True).fields + if f.name not in {"NOT", "AND", "OR", "DISTINCT"} + ] == [ + ("id", StrawberryDjangoField, "BaseFilterLookup", None), + ("name", StrawberryDjangoField, "FilterLookup", None), + ("sweetness", StrawberryDjangoField, "ComparisonFilterLookup", None), + ( + "custom_filter", + FilterOrderField, + "str", + FilterOrderFieldResolver, + ), + ( + "custom_filter2", + FilterOrderField, + "BaseFilterLookup", + FilterOrderFieldResolver, + ), + ] + + +def test_filter_methods(query, db, fruits): + t1 = models.FruitType.objects.create(name="Type1") + t2 = models.FruitType.objects.create(name="Type2") + + f1, f2, f3 = models.Fruit.objects.all() + _ = models.Fruit.objects.create(name="DARK_BERRY") + + f2.types.add(t1) + f3.types.add(t1, t2) + + result = query(""" + { + fruits(filters: { + typesNumber: { gt: 1 } + NOT: { color: { nameSimple: "sample" } } + OR: { + typesNumber: { isNull: true } + } + }) { id } } + """) + + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f3.id)}, + ] + + +def test_filter_distinct(query, db, fruits): + t1 = models.FruitType.objects.create(name="type_1") + t2 = models.FruitType.objects.create(name="type_2") + + f1 = models.Fruit.objects.all()[0] + + f1.types.add(t1, t2) + + result = query(""" + { + fruits( + filters: {types: { name: { iContains: "type" } } } + ) { id name } + } + """) + assert not result.errors + assert len(result.data["fruits"]) == 2 + + result = query(""" + { + fruits( + filters: { + DISTINCT: true, + types: { name: { iContains: "type" } } + } + ) { id name } + } + """) + assert not result.errors + assert len(result.data["fruits"]) == 1 diff --git a/tests/filters/test_types.py b/tests/filters/test_types.py index 9bd06cfe..a70cb1c2 100644 --- a/tests/filters/test_types.py +++ b/tests/filters/test_types.py @@ -1,14 +1,24 @@ from typing import Optional +import pytest import strawberry from strawberry import auto from strawberry.type import StrawberryOptional, get_object_definition import strawberry_django from strawberry_django.filters import DjangoModelFilterInput +from strawberry_django.settings import strawberry_django_settings from tests import models +@pytest.fixture(autouse=True) +def _filter_order_settings(settings): + settings.STRAWBERRY_DJANGO = { + **strawberry_django_settings(), + "USE_DEPRECATED_FILTERS": True, + } + + def test_filter(): @strawberry_django.filters.filter(models.Fruit) class Filter: @@ -26,6 +36,7 @@ class Filter: ("AND", StrawberryOptional(Filter)), ("OR", StrawberryOptional(Filter)), ("NOT", StrawberryOptional(Filter)), + ("DISTINCT", StrawberryOptional(bool)), ] @@ -49,6 +60,7 @@ class Filter: ("AND", "Filter"), ("OR", "Filter"), ("NOT", "Filter"), + ("DISTINCT", "bool"), ] @@ -73,6 +85,7 @@ class Filter(Base): ("AND", StrawberryOptional(Filter)), ("OR", StrawberryOptional(Filter)), ("NOT", StrawberryOptional(Filter)), + ("DISTINCT", StrawberryOptional(bool)), ] @@ -91,6 +104,7 @@ class Filter: ("AND", StrawberryOptional(Filter)), ("OR", StrawberryOptional(Filter)), ("NOT", StrawberryOptional(Filter)), + ("DISTINCT", StrawberryOptional(bool)), ] @@ -113,4 +127,5 @@ class Filter(Base): ("AND", StrawberryOptional(Filter)), ("OR", StrawberryOptional(Filter)), ("NOT", StrawberryOptional(Filter)), + ("DISTINCT", StrawberryOptional(bool)), ] diff --git a/tests/models.py b/tests/models.py index 4e3a24c3..3f3c174b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -27,6 +27,7 @@ class Vegetable(NameDescriptionMixin): class Fruit(models.Model): + id: Optional[int] name = models.CharField(max_length=20) color_id: Optional[int] color = models.ForeignKey( @@ -75,6 +76,7 @@ class Color(models.Model): class FruitType(models.Model): + id: Optional[int] name = models.CharField(max_length=20, validators=[validate_fruit_type]) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index abc62786..014d8206 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -105,41 +105,45 @@ union CreateQuizPayload = QuizType | OperationInfo """Date (isoformat)""" scalar Date -input DateFilterLookup { +input DateDateFilterLookup { + """Exact match. Filter will be skipped on `null` value""" exact: Date - iExact: Date - contains: Date - iContains: Date + + """Assignment test. Filter will be skipped on `null` value""" + isNull: Boolean + + """ + Exact match of items in a given list. Filter will be skipped on `null` value + """ inList: [Date!] + + """Greater than. Filter will be skipped on `null` value""" gt: Date + + """Greater than or equal to. Filter will be skipped on `null` value""" gte: Date + + """Less than. Filter will be skipped on `null` value""" lt: Date + + """Less than or equal to. Filter will be skipped on `null` value""" lte: Date - startsWith: Date - iStartsWith: Date - endsWith: Date - iEndsWith: Date - range: [Date!] - isNull: Boolean - regex: String - iRegex: String - nExact: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIExact: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nContains: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIContains: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nInList: [Date!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGt: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGte: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLt: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLte: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nStartsWith: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIStartsWith: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nEndsWith: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIEndsWith: Date @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRange: [Date!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIsNull: Boolean @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") + + """Inclusive range test (between)""" + range: DateRangeLookup + year: IntComparisonFilterLookup + month: IntComparisonFilterLookup + day: IntComparisonFilterLookup + weekDay: IntComparisonFilterLookup + isoWeekDay: IntComparisonFilterLookup + week: IntComparisonFilterLookup + isoYear: IntComparisonFilterLookup + quarter: IntComparisonFilterLookup +} + +input DateRangeLookup { + start: Date = null + end: Date = null } """Date with time (isoformat)""" @@ -190,6 +194,39 @@ The `ID` scalar type represents a unique identifier, often used to refetch an ob """ scalar GlobalID @specifiedBy(url: "https://relay.dev/graphql/objectidentification.htm") +input IntComparisonFilterLookup { + """Exact match. Filter will be skipped on `null` value""" + exact: Int + + """Assignment test. Filter will be skipped on `null` value""" + isNull: Boolean + + """ + Exact match of items in a given list. Filter will be skipped on `null` value + """ + inList: [Int!] + + """Greater than. Filter will be skipped on `null` value""" + gt: Int + + """Greater than or equal to. Filter will be skipped on `null` value""" + gte: Int + + """Less than. Filter will be skipped on `null` value""" + lt: Int + + """Less than or equal to. Filter will be skipped on `null` value""" + lte: Int + + """Inclusive range test (between)""" + range: IntRangeLookup +} + +input IntRangeLookup { + start: Int = null + end: Int = null +} + input IssueAssigneeInputPartial { id: GlobalID user: NodeInputPartial @@ -289,6 +326,7 @@ input MilestoneFilter { AND: MilestoneFilter OR: MilestoneFilter NOT: MilestoneFilter + DISTINCT: Boolean } input MilestoneInput { @@ -414,7 +452,11 @@ enum OperationMessageKind { enum Ordering { ASC + ASC_NULLS_FIRST + ASC_NULLS_LAST DESC + DESC_NULLS_FIRST + DESC_NULLS_LAST } """Information to aid in pagination.""" @@ -445,10 +487,11 @@ type ProjectConnection { input ProjectFilter { name: StrFilterLookup - dueDate: DateFilterLookup + dueDate: DateDateFilterLookup AND: ProjectFilter OR: ProjectFilter NOT: ProjectFilter + DISTINCT: Boolean } input ProjectInputPartial { @@ -744,40 +787,51 @@ type StaffTypeEdge { } input StrFilterLookup { + """Exact match. Filter will be skipped on `null` value""" exact: String + + """Assignment test. Filter will be skipped on `null` value""" + isNull: Boolean + + """ + Exact match of items in a given list. Filter will be skipped on `null` value + """ + inList: [String!] + + """Case-insensitive exact match. Filter will be skipped on `null` value""" iExact: String + + """ + Case-sensitive containment test. Filter will be skipped on `null` value + """ contains: String + + """ + Case-insensitive containment test. Filter will be skipped on `null` value + """ iContains: String - inList: [String!] - gt: String - gte: String - lt: String - lte: String + + """Case-sensitive starts-with. Filter will be skipped on `null` value""" startsWith: String + + """Case-insensitive starts-with. Filter will be skipped on `null` value""" iStartsWith: String + + """Case-sensitive ends-with. Filter will be skipped on `null` value""" endsWith: String + + """Case-insensitive ends-with. Filter will be skipped on `null` value""" iEndsWith: String - range: [String!] - isNull: Boolean + + """ + Case-sensitive regular expression match. Filter will be skipped on `null` value + """ regex: String + + """ + Case-insensitive regular expression match. Filter will be skipped on `null` value + """ iRegex: String - nExact: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIExact: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nContains: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIContains: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nInList: [String!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGt: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGte: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLt: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLte: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nStartsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIStartsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nEndsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIEndsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRange: [String!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIsNull: Boolean @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") } input TagInputPartial { diff --git a/tests/relay/snapshots/schema.gql b/tests/relay/snapshots/schema.gql index 45f6951a..c5d8013a 100644 --- a/tests/relay/snapshots/schema.gql +++ b/tests/relay/snapshots/schema.gql @@ -32,6 +32,7 @@ input FruitFilter { AND: FruitFilter OR: FruitFilter NOT: FruitFilter + DISTINCT: Boolean } input FruitOrder { @@ -52,7 +53,11 @@ interface Node { enum Ordering { ASC + ASC_NULLS_FIRST + ASC_NULLS_LAST DESC + DESC_NULLS_FIRST + DESC_NULLS_LAST } """Information to aid in pagination.""" @@ -165,38 +170,49 @@ type Query { } input StrFilterLookup { + """Exact match. Filter will be skipped on `null` value""" exact: String + + """Assignment test. Filter will be skipped on `null` value""" + isNull: Boolean + + """ + Exact match of items in a given list. Filter will be skipped on `null` value + """ + inList: [String!] + + """Case-insensitive exact match. Filter will be skipped on `null` value""" iExact: String + + """ + Case-sensitive containment test. Filter will be skipped on `null` value + """ contains: String + + """ + Case-insensitive containment test. Filter will be skipped on `null` value + """ iContains: String - inList: [String!] - gt: String - gte: String - lt: String - lte: String + + """Case-sensitive starts-with. Filter will be skipped on `null` value""" startsWith: String + + """Case-insensitive starts-with. Filter will be skipped on `null` value""" iStartsWith: String + + """Case-sensitive ends-with. Filter will be skipped on `null` value""" endsWith: String + + """Case-insensitive ends-with. Filter will be skipped on `null` value""" iEndsWith: String - range: [String!] - isNull: Boolean + + """ + Case-sensitive regular expression match. Filter will be skipped on `null` value + """ regex: String + + """ + Case-insensitive regular expression match. Filter will be skipped on `null` value + """ iRegex: String - nExact: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIExact: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nContains: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIContains: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nInList: [String!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGt: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nGte: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLt: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nLte: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nStartsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIStartsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nEndsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIEndsWith: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRange: [String!] @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIsNull: Boolean @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") - nIRegex: String @deprecated(reason: "The \"n\" prefix is deprecated and will be removed in the future, use `NOT` instead.\n") } \ No newline at end of file diff --git a/tests/test_ordering.py b/tests/test_ordering.py index bed568e2..4174a9f5 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -1,11 +1,30 @@ -from typing import List, Optional +# ruff: noqa: TRY002, B904, BLE001, F811, PT012 +from typing import Any, List, Optional, cast import pytest import strawberry +from django.db.models import Case, Count, Value, When from strawberry import auto from strawberry.annotation import StrawberryAnnotation +from strawberry.exceptions import MissingArgumentsAnnotationsError +from strawberry.field import StrawberryField +from strawberry.type import ( + StrawberryOptional, + WithStrawberryObjectDefinition, + get_object_definition, +) import strawberry_django +from strawberry_django.exceptions import ( + ForbiddenFieldArgumentError, + MissingFieldArgumentError, +) +from strawberry_django.fields.field import StrawberryDjangoField +from strawberry_django.fields.filter_order import ( + FilterOrderField, + FilterOrderFieldResolver, +) +from strawberry_django.ordering import Ordering, OrderSequence, process_order from tests import models, utils from tests.types import Fruit @@ -13,7 +32,10 @@ @strawberry_django.ordering.order(models.Color) class ColorOrder: pk: auto - name: auto + + @strawberry_django.order_field + def name(self, prefix, value: auto): + return [value.resolve(f"{prefix}name")] @strawberry_django.ordering.order(models.Fruit) @@ -23,6 +45,16 @@ class FruitOrder: sweetness: auto color: Optional[ColorOrder] + @strawberry_django.order_field + def types_number(self, queryset, prefix, value: auto): + return queryset.annotate( + count=Count(f"{prefix}types__id"), + count_nulls=Case( + When(count=0, then=Value(None)), + default="count", + ), + ), [value.resolve("count_nulls")] + @strawberry_django.type(models.Fruit, order=FruitOrder) class FruitWithOrder: @@ -41,8 +73,6 @@ def query(): def test_field_order_definition(): - from strawberry_django.fields.field import StrawberryDjangoField - field = StrawberryDjangoField(type_annotation=StrawberryAnnotation(FruitWithOrder)) assert field.get_order() == FruitOrder field = StrawberryDjangoField( @@ -134,3 +164,251 @@ def test_arguments_order_respected(query, db): result = query("{ fruits(order: { name: ASC, colorId: ASC }) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f2, f1]] + + +def test_order_sequence(): + f1 = StrawberryField(graphql_name="sOmEnAmE", python_name="some_name") + f2 = StrawberryField(python_name="some_name") + + assert OrderSequence.get_graphql_name(None, f1) == "sOmEnAmE" + assert OrderSequence.get_graphql_name(None, f2) == "someName" + + assert OrderSequence.sorted(None, None, fields=[f1, f2]) == [f1, f2] + + sequence = {"someName": OrderSequence(0, None), "sOmEnAmE": OrderSequence(1, None)} + assert OrderSequence.sorted(None, sequence, fields=[f1, f2]) == [f1, f2] + + +def test_order_type(): + @strawberry_django.ordering.order(models.Fruit) + class FruitOrder: + color_id: auto + name: auto + sweetness: auto + + @strawberry_django.order_field + def custom_order(self, value: auto, prefix: str): + pass + + annotated_type = StrawberryOptional(Ordering._enum_definition) # type: ignore + + assert [ + ( + f.name, + f.__class__, + f.type, + f.base_resolver.__class__ if f.base_resolver else None, + ) + for f in get_object_definition(FruitOrder, strict=True).fields + ] == [ + ("color_id", StrawberryField, annotated_type, None), + ("name", StrawberryField, annotated_type, None), + ("sweetness", StrawberryField, annotated_type, None), + ( + "custom_order", + FilterOrderField, + annotated_type, + FilterOrderFieldResolver, + ), + ] + + +def test_order_field_missing_prefix(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"prefix\".*\"field_method\".*" + ): + + @strawberry_django.order_field + def field_method(): + pass + + +def test_order_field_missing_value(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"value\".*\"field_method\".*" + ): + + @strawberry_django.order_field + def field_method(prefix): + pass + + +def test_order_field_missing_value_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r"Missing annotation.*\"value\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value): + pass + + +def test_order_field(): + try: + + @strawberry_django.order_field + def field_method(self, root, info, prefix, value: auto, sequence, queryset): + pass + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") # type: ignore + + +def test_order_field_forbidden_param_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value: auto, sequence, queryset, forbidden_param): + pass + + +def test_order_field_forbidden_param(): + with pytest.raises( + ForbiddenFieldArgumentError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value: auto, sequence, queryset, forbidden_param: str): + pass + + +def test_order_field_missing_queryset(): + with pytest.raises(MissingFieldArgumentError, match=r".*\"queryset\".*\"order\".*"): + + @strawberry_django.order_field + def order(prefix): + pass + + +def test_order_field_value_forbidden_on_object(): + with pytest.raises(ForbiddenFieldArgumentError, match=r".*\"value\".*\"order\".*"): + + @strawberry_django.order_field + def field_method(prefix, queryset, value: auto): + pass + + @strawberry_django.order_field + def order(prefix, queryset, value: auto): + pass + + +def test_order_field_on_object(): + try: + + @strawberry_django.order_field + def order(self, root, info, prefix, sequence, queryset): + pass + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") # type: ignore + + +def test_order_field_method(): + @strawberry_django.ordering.order(models.Fruit) + class Order: + @strawberry_django.order_field + def custom_order(self, root, info, prefix, value: auto, sequence, queryset): + assert self == _order, "Unexpected self passed" + assert root == _order, "Unexpected root passed" + assert info == _info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert value == Ordering.ASC, "Unexpected value passed" + assert sequence == _sequence_inner, "Unexpected sequence passed" + assert queryset == _queryset, "Unexpected queryset passed" + raise Exception("WAS CALLED") + + _order = cast(WithStrawberryObjectDefinition, Order(custom_order=Ordering.ASC)) # type: ignore + schema = strawberry.Schema(query=Query) + _info: Any = type("FakeInfo", (), {"schema": schema}) + _queryset: Any = object() + _sequence_inner: Any = object() + _sequence = {"customOrder": OrderSequence(0, children=_sequence_inner)} + + with pytest.raises(Exception, match="WAS CALLED"): + process_order(_order, _info, _queryset, prefix="ROOT", sequence=_sequence) + + +def test_order_object_method(): + @strawberry_django.ordering.order(models.Fruit) + class Order: + @strawberry_django.order_field + def order(self, root, info, prefix, sequence, queryset): + assert self == _order, "Unexpected self passed" + assert root == _order, "Unexpected root passed" + assert info == _info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert sequence == _sequence, "Unexpected sequence passed" + assert queryset == _queryset, "Unexpected queryset passed" + return queryset, ["name"] + + _order = cast(WithStrawberryObjectDefinition, Order()) + schema = strawberry.Schema(query=Query) + _info: Any = type("FakeInfo", (), {"schema": schema}) + _queryset: Any = object() + _sequence: Any = {"customOrder": OrderSequence(0)} + + order = process_order(_order, _info, _queryset, prefix="ROOT", sequence=_sequence)[ + 1 + ] + assert "name" in order, "order was not called" + + +def test_order_nulls(query, db, fruits): + t1 = models.FruitType.objects.create(name="Type1") + t2 = models.FruitType.objects.create(name="Type2") + + f1, f2, f3 = models.Fruit.objects.all() + + f2.types.add(t1) + f3.types.add(t1, t2) + + result = query("{ fruits(order: { typesNumber: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f2.id)}, + {"id": str(f3.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f3.id)}, + {"id": str(f2.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: ASC_NULLS_FIRST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f2.id)}, + {"id": str(f3.id)}, + ] + + result = query("{ fruits(order: { typesNumber: ASC_NULLS_LAST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f2.id)}, + {"id": str(f3.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC_NULLS_LAST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f3.id)}, + {"id": str(f2.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC_NULLS_FIRST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f3.id)}, + {"id": str(f2.id)}, + ] diff --git a/tests/test_settings.py b/tests/test_settings.py index ee69d7d3..ab88998b 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -28,6 +28,7 @@ def test_non_defaults(): MUTATIONS_DEFAULT_ARGUMENT_NAME="id", MUTATIONS_DEFAULT_HANDLE_ERRORS=True, MAP_AUTO_ID_AS_GLOBAL_ID=True, + USE_DEPRECATED_FILTERS=True, ), ): assert ( @@ -39,5 +40,6 @@ def test_non_defaults(): MUTATIONS_DEFAULT_ARGUMENT_NAME="id", MUTATIONS_DEFAULT_HANDLE_ERRORS=True, MAP_AUTO_ID_AS_GLOBAL_ID=True, + USE_DEPRECATED_FILTERS=True, ) )