diff --git a/docs/en/docs/queries.md b/docs/en/docs/queries.md index 072f619..fea8ae1 100644 --- a/docs/en/docs/queries.md +++ b/docs/en/docs/queries.md @@ -131,8 +131,9 @@ The same special operators are also automatically added on every column. * **neq** - Filter instances by not equal to condition. * **startswith** - Filter instances that start with a specific value. * **endswith** - Filter instances that end with a specific value. -* **istartswith** - Filter instances that start with a specific value, case-insensitive. +* **istartswith** - Filter instances that start with a specific value, case-insensitive. * **iendswith** - Filter instances that end with a specific value, case-insensitive. +* **date** - Filter instances by date. ##### Example @@ -153,6 +154,7 @@ users = await User.objects.filter(name__startswith="foo") users = await User.objects.filter(name__istartswith="foo") users = await User.objects.filter(name__endswith="foo") users = await User.objects.filter(name__iendswith="foo") +users = await User.objects.filter(updated_at__date=date.today()) ``` ### Using diff --git a/mongoz/conf/global_settings.py b/mongoz/conf/global_settings.py index d12a759..b9460ad 100644 --- a/mongoz/conf/global_settings.py +++ b/mongoz/conf/global_settings.py @@ -36,7 +36,8 @@ class MongozSettings(Settings): "startswith": "startswith", "istartswith": "istartswith", "endswith": "endswith", - "iendswith": "iendswith" + "iendswith": "iendswith", + "date": "date", } def get_operator(self, name: str) -> "Expression": diff --git a/mongoz/core/db/querysets/core/manager.py b/mongoz/core/db/querysets/core/manager.py index 8bb51e9..9db54d4 100644 --- a/mongoz/core/db/querysets/core/manager.py +++ b/mongoz/core/db/querysets/core/manager.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, Any, @@ -26,9 +27,16 @@ ORDER_EQUALITY, VALUE_EQUALITY, ) -from mongoz.core.db.querysets.core.protocols import AwaitableQuery, MongozDocument +from mongoz.core.db.querysets.core.protocols import ( + AwaitableQuery, + MongozDocument, +) from mongoz.core.db.querysets.expressions import Expression, SortExpression -from mongoz.exceptions import DocumentNotFound, FieldDefinitionError, MultipleDocumentsReturned +from mongoz.exceptions import ( + DocumentNotFound, + FieldDefinitionError, + MultipleDocumentsReturned, +) from mongoz.protocols.queryset import QuerySetProtocol from mongoz.utils.enums import OrderEnum @@ -88,8 +96,12 @@ class registry using the database_name that provided in \ - return the self instance. """ manager: "Manager" = self.clone() - database = manager.model_class.meta.registry.get_database(database_name) - manager._collection = database.get_collection(manager._collection.name)._collection + database = manager.model_class.meta.registry.get_database( + database_name + ) + manager._collection = database.get_collection( + manager._collection.name + )._collection return manager def clone(self) -> Any: @@ -107,7 +119,9 @@ def clone(self) -> Any: def validate_only_and_defer(self) -> None: if self._only_fields and self._defer_fields: - raise FieldDefinitionError("You cannot use .only() and .defer() at the same time.") + raise FieldDefinitionError( + "You cannot use .only() and .defer() at the same time." + ) def get_operator(self, name: str) -> Expression: """ @@ -123,7 +137,9 @@ def _find_and_replace_id(self, key: str) -> str: return cast(str, self.model_class.id.pydantic_field.alias) # type: ignore return key - def filter_only_and_defer(self, *fields: Sequence[str], is_only: bool = False) -> "Manager": + def filter_only_and_defer( + self, *fields: Sequence[str], is_only: bool = False + ) -> "Manager": """ Validates if should be defer or only and checks it out """ @@ -190,9 +206,15 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager": and value ): asc_or_desc = lookup_operator - elif lookup_operator == OrderEnum.ASCENDING and value is False: + elif ( + lookup_operator == OrderEnum.ASCENDING + and value is False + ): asc_or_desc = OrderEnum.DESCENDING - elif lookup_operator == OrderEnum.DESCENDING and value is False: + elif ( + lookup_operator == OrderEnum.DESCENDING + and value is False + ): asc_or_desc = OrderEnum.ASCENDING else: asc_or_desc = OrderEnum.ASCENDING @@ -207,6 +229,17 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager": operator = self.get_operator(lookup_operator) expression = operator(field_name, value) # type: ignore + # For "date" + elif lookup_operator == "date": + operator = self.get_operator("gte") + from_datetime = datetime.combine( + value, datetime.min.time() + ) + expression1 = operator(field_name, from_datetime) # type: ignore + clauses.append(expression1) + operator = self.get_operator("lt") + expression = operator(field_name, from_datetime + timedelta(days=1)) # type: ignore + # Add expression to the clauses clauses.append(expression) @@ -249,7 +282,9 @@ def raw(self, *values: Union[bool, Dict, Expression]) -> "Manager": """ manager: "Manager" = self.clone() for value in values: - assert isinstance(value, (dict, Expression)), "Invalid argument to Raw" + assert isinstance( + value, (dict, Expression) + ), "Invalid argument to Raw" if isinstance(value, dict): query_expressions = Expression.unpack(value) manager._filter.extend(query_expressions) @@ -290,7 +325,10 @@ def skip(self, count: int = 0) -> "Manager[T]": return manager def sort( - self, key: Union[Any, None] = None, direction: Union[Order, None] = None, **kwargs: Any + self, + key: Union[Any, None] = None, + direction: Union[Order, None] = None, + **kwargs: Any, ) -> "Manager[T]": """Sort by (key, direction) or [(key, direction)].""" manager: "Manager" = self.clone() @@ -364,7 +402,7 @@ async def _all(self) -> List[T]: only_fields=manager._only_fields, is_defer_fields=is_defer_fields, defer_fields=manager._defer_fields, - from_collection=manager._collection + from_collection=manager._collection, ) async for document in cursor ] @@ -378,14 +416,18 @@ async def count(self, **kwargs: Any) -> int: manager: "Manager" = self.clone() filter_query = Expression.compile_many(manager._filter) - return cast(int, await manager._collection.count_documents(filter_query)) + return cast( + int, await manager._collection.count_documents(filter_query) + ) async def create(self, **kwargs: Any) -> "Document": """ Creates a mongo db document. """ manager: "Manager" = self.clone() - instance = await manager.model_class(**kwargs).create(manager._collection) + instance = await manager.model_class(**kwargs).create( + manager._collection + ) return cast("Document", instance) async def delete(self) -> int: @@ -448,14 +490,19 @@ async def get_or_none(self, **kwargs: Any) -> Union["T", "Document", None]: raise MultipleDocumentsReturned() return cast(T, objects[0]) - async def get_or_create(self, defaults: Union[Dict[str, Any], None] = None) -> T: + async def get_or_create( + self, defaults: Union[Dict[str, Any], None] = None + ) -> T: manager: "Manager" = self.clone() if not defaults: defaults = {} - data = {expression.key: expression.value for expression in manager._filter} + data = { + expression.key: expression.value for expression in manager._filter + } defaults = { - (key if isinstance(key, str) else key._name): value for key, value in defaults.items() + (key if isinstance(key, str) else key._name): value + for key, value in defaults.items() } try: @@ -528,12 +575,21 @@ async def update_many(self, **kwargs: Any) -> List[T]: values = model.model_dump() filter_query = Expression.compile_many(manager._filter) - await manager._collection.update_many(filter_query, {"$set": values}) + await manager._collection.update_many( + filter_query, {"$set": values} + ) _filter = [ - expression for expression in manager._filter if expression.key not in values + expression + for expression in manager._filter + if expression.key not in values ] - _filter.extend([Expression(key, "$eq", value) for key, value in values.items()]) + _filter.extend( + [ + Expression(key, "$eq", value) + for key, value in values.items() + ] + ) manager._filter = _filter return await manager._all() @@ -556,7 +612,9 @@ async def bulk_update(self, **kwargs: Any) -> List[T]: manager: "Manager" = self.clone() return await manager.update_many(**kwargs) - async def get_document_by_id(self, id: Union[str, bson.ObjectId]) -> "Document": + async def get_document_by_id( + self, id: Union[str, bson.ObjectId] + ) -> "Document": """ Gets a document by the id """ diff --git a/tests/models/manager/test_query_builder.py b/tests/models/manager/test_query_builder.py index abe7406..8ebf972 100644 --- a/tests/models/manager/test_query_builder.py +++ b/tests/models/manager/test_query_builder.py @@ -1,3 +1,4 @@ +from datetime import date, datetime from typing import AsyncGenerator, List, Optional import pydantic @@ -21,6 +22,7 @@ class Movie(Document): year: int = mongoz.Integer() tags: Optional[List[str]] = mongoz.Array(str, null=True) uuid: Optional[ObjectId] = mongoz.ObjectId(null=True) + released_at: datetime = mongoz.DateTime(null=True) class Meta: registry = client @@ -40,7 +42,9 @@ async def prepare_database() -> AsyncGenerator: async def test_model_query_builder() -> None: - await Movie.objects.create(name="Downfall", year=2004) + await Movie.objects.create( + name="Downfall", year=2004, released_at=datetime.now() + ) await Movie.objects.create(name="The Two Towers", year=2002) await Movie.objects.create(name="Casablanca", year=1942) await Movie.objects.create(name="Gone with the wind", year=1939) @@ -69,7 +73,9 @@ async def test_model_query_builder() -> None: assert movie.name == "Downfall" assert movie.year == 2004 - movie = await Movie.objects.filter(name="Casablanca").filter(year=1942).get() + movie = ( + await Movie.objects.filter(name="Casablanca").filter(year=1942).get() + ) assert movie.name == "Casablanca" assert movie.year == 1942 @@ -81,7 +87,9 @@ async def test_model_query_builder() -> None: assert movie.name == "Casablanca" assert movie.year == 1942 - movie = await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get() + movie = ( + await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get() + ) assert movie.name == "The Two Towers" assert movie.year == 2002 @@ -129,6 +137,12 @@ async def test_model_query_builder() -> None: assert len(movies) == 1 assert movies[0].name.lower() == "gone with the Wind".lower() + movies = await Movie.objects.filter(released_at__date=date.today()) + assert len(movies) == 1 + assert movies[0].name == "Downfall" + assert movies[0].year == 2004 + + async def test_query_builder_in_list(): await Movie.objects.create(name="Downfall", year=2004) await Movie.objects.create(name="The Two Towers", year=2002) @@ -142,7 +156,9 @@ async def test_query_builder_in_list(): assert len(movies) == 2 -@pytest.mark.parametrize("values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"]) +@pytest.mark.parametrize( + "values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"] +) async def test_query_builder_in_list_raise_assertation_error(values): await Movie.objects.create(name="Downfall", year=2004) await Movie.objects.create(name="The Two Towers", year=2002)