Skip to content

Commit

Permalink
Summary: Add the date opertaor to query on datetime field. (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
harshalizode authored Jan 21, 2025
1 parent 511bc66 commit f0b9307
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 26 deletions.
4 changes: 3 additions & 1 deletion docs/en/docs/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ The same special operators are also automatically added on every column.
* **neq** - Filter instances by not equal to condition.
* **startswith** - Filter instances that start with a specific value.
* **endswith** - Filter instances that end with a specific value.
* **istartswith** - Filter instances that start with a specific value, case-insensitive.
* **istartswith** - Filter instances that start with a specific value, case-insensitive.
* **iendswith** - Filter instances that end with a specific value, case-insensitive.
* **date** - Filter instances by date.

##### Example

Expand All @@ -153,6 +154,7 @@ users = await User.objects.filter(name__startswith="foo")
users = await User.objects.filter(name__istartswith="foo")
users = await User.objects.filter(name__endswith="foo")
users = await User.objects.filter(name__iendswith="foo")
users = await User.objects.filter(updated_at__date=date.today())
```

### Using
Expand Down
3 changes: 2 additions & 1 deletion mongoz/conf/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class MongozSettings(Settings):
"startswith": "startswith",
"istartswith": "istartswith",
"endswith": "endswith",
"iendswith": "iendswith"
"iendswith": "iendswith",
"date": "date",
}

def get_operator(self, name: str) -> "Expression":
Expand Down
98 changes: 78 additions & 20 deletions mongoz/core/db/querysets/core/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -26,9 +27,16 @@
ORDER_EQUALITY,
VALUE_EQUALITY,
)
from mongoz.core.db.querysets.core.protocols import AwaitableQuery, MongozDocument
from mongoz.core.db.querysets.core.protocols import (
AwaitableQuery,
MongozDocument,
)
from mongoz.core.db.querysets.expressions import Expression, SortExpression
from mongoz.exceptions import DocumentNotFound, FieldDefinitionError, MultipleDocumentsReturned
from mongoz.exceptions import (
DocumentNotFound,
FieldDefinitionError,
MultipleDocumentsReturned,
)
from mongoz.protocols.queryset import QuerySetProtocol
from mongoz.utils.enums import OrderEnum

Expand Down Expand Up @@ -88,8 +96,12 @@ class registry using the database_name that provided in \
- return the self instance.
"""
manager: "Manager" = self.clone()
database = manager.model_class.meta.registry.get_database(database_name)
manager._collection = database.get_collection(manager._collection.name)._collection
database = manager.model_class.meta.registry.get_database(
database_name
)
manager._collection = database.get_collection(
manager._collection.name
)._collection
return manager

def clone(self) -> Any:
Expand All @@ -107,7 +119,9 @@ def clone(self) -> Any:

def validate_only_and_defer(self) -> None:
if self._only_fields and self._defer_fields:
raise FieldDefinitionError("You cannot use .only() and .defer() at the same time.")
raise FieldDefinitionError(
"You cannot use .only() and .defer() at the same time."
)

def get_operator(self, name: str) -> Expression:
"""
Expand All @@ -123,7 +137,9 @@ def _find_and_replace_id(self, key: str) -> str:
return cast(str, self.model_class.id.pydantic_field.alias) # type: ignore
return key

def filter_only_and_defer(self, *fields: Sequence[str], is_only: bool = False) -> "Manager":
def filter_only_and_defer(
self, *fields: Sequence[str], is_only: bool = False
) -> "Manager":
"""
Validates if should be defer or only and checks it out
"""
Expand Down Expand Up @@ -190,9 +206,15 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager":
and value
):
asc_or_desc = lookup_operator
elif lookup_operator == OrderEnum.ASCENDING and value is False:
elif (
lookup_operator == OrderEnum.ASCENDING
and value is False
):
asc_or_desc = OrderEnum.DESCENDING
elif lookup_operator == OrderEnum.DESCENDING and value is False:
elif (
lookup_operator == OrderEnum.DESCENDING
and value is False
):
asc_or_desc = OrderEnum.ASCENDING
else:
asc_or_desc = OrderEnum.ASCENDING
Expand All @@ -207,6 +229,17 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager":
operator = self.get_operator(lookup_operator)
expression = operator(field_name, value) # type: ignore

# For "date"
elif lookup_operator == "date":
operator = self.get_operator("gte")
from_datetime = datetime.combine(
value, datetime.min.time()
)
expression1 = operator(field_name, from_datetime) # type: ignore
clauses.append(expression1)
operator = self.get_operator("lt")
expression = operator(field_name, from_datetime + timedelta(days=1)) # type: ignore

# Add expression to the clauses
clauses.append(expression)

Expand Down Expand Up @@ -249,7 +282,9 @@ def raw(self, *values: Union[bool, Dict, Expression]) -> "Manager":
"""
manager: "Manager" = self.clone()
for value in values:
assert isinstance(value, (dict, Expression)), "Invalid argument to Raw"
assert isinstance(
value, (dict, Expression)
), "Invalid argument to Raw"
if isinstance(value, dict):
query_expressions = Expression.unpack(value)
manager._filter.extend(query_expressions)
Expand Down Expand Up @@ -290,7 +325,10 @@ def skip(self, count: int = 0) -> "Manager[T]":
return manager

def sort(
self, key: Union[Any, None] = None, direction: Union[Order, None] = None, **kwargs: Any
self,
key: Union[Any, None] = None,
direction: Union[Order, None] = None,
**kwargs: Any,
) -> "Manager[T]":
"""Sort by (key, direction) or [(key, direction)]."""
manager: "Manager" = self.clone()
Expand Down Expand Up @@ -364,7 +402,7 @@ async def _all(self) -> List[T]:
only_fields=manager._only_fields,
is_defer_fields=is_defer_fields,
defer_fields=manager._defer_fields,
from_collection=manager._collection
from_collection=manager._collection,
)
async for document in cursor
]
Expand All @@ -378,14 +416,18 @@ async def count(self, **kwargs: Any) -> int:
manager: "Manager" = self.clone()

filter_query = Expression.compile_many(manager._filter)
return cast(int, await manager._collection.count_documents(filter_query))
return cast(
int, await manager._collection.count_documents(filter_query)
)

async def create(self, **kwargs: Any) -> "Document":
"""
Creates a mongo db document.
"""
manager: "Manager" = self.clone()
instance = await manager.model_class(**kwargs).create(manager._collection)
instance = await manager.model_class(**kwargs).create(
manager._collection
)
return cast("Document", instance)

async def delete(self) -> int:
Expand Down Expand Up @@ -448,14 +490,19 @@ async def get_or_none(self, **kwargs: Any) -> Union["T", "Document", None]:
raise MultipleDocumentsReturned()
return cast(T, objects[0])

async def get_or_create(self, defaults: Union[Dict[str, Any], None] = None) -> T:
async def get_or_create(
self, defaults: Union[Dict[str, Any], None] = None
) -> T:
manager: "Manager" = self.clone()
if not defaults:
defaults = {}

data = {expression.key: expression.value for expression in manager._filter}
data = {
expression.key: expression.value for expression in manager._filter
}
defaults = {
(key if isinstance(key, str) else key._name): value for key, value in defaults.items()
(key if isinstance(key, str) else key._name): value
for key, value in defaults.items()
}

try:
Expand Down Expand Up @@ -528,12 +575,21 @@ async def update_many(self, **kwargs: Any) -> List[T]:
values = model.model_dump()

filter_query = Expression.compile_many(manager._filter)
await manager._collection.update_many(filter_query, {"$set": values})
await manager._collection.update_many(
filter_query, {"$set": values}
)

_filter = [
expression for expression in manager._filter if expression.key not in values
expression
for expression in manager._filter
if expression.key not in values
]
_filter.extend([Expression(key, "$eq", value) for key, value in values.items()])
_filter.extend(
[
Expression(key, "$eq", value)
for key, value in values.items()
]
)

manager._filter = _filter
return await manager._all()
Expand All @@ -556,7 +612,9 @@ async def bulk_update(self, **kwargs: Any) -> List[T]:
manager: "Manager" = self.clone()
return await manager.update_many(**kwargs)

async def get_document_by_id(self, id: Union[str, bson.ObjectId]) -> "Document":
async def get_document_by_id(
self, id: Union[str, bson.ObjectId]
) -> "Document":
"""
Gets a document by the id
"""
Expand Down
24 changes: 20 additions & 4 deletions tests/models/manager/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import date, datetime
from typing import AsyncGenerator, List, Optional

import pydantic
Expand All @@ -21,6 +22,7 @@ class Movie(Document):
year: int = mongoz.Integer()
tags: Optional[List[str]] = mongoz.Array(str, null=True)
uuid: Optional[ObjectId] = mongoz.ObjectId(null=True)
released_at: datetime = mongoz.DateTime(null=True)

class Meta:
registry = client
Expand All @@ -40,7 +42,9 @@ async def prepare_database() -> AsyncGenerator:


async def test_model_query_builder() -> None:
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(
name="Downfall", year=2004, released_at=datetime.now()
)
await Movie.objects.create(name="The Two Towers", year=2002)
await Movie.objects.create(name="Casablanca", year=1942)
await Movie.objects.create(name="Gone with the wind", year=1939)
Expand Down Expand Up @@ -69,7 +73,9 @@ async def test_model_query_builder() -> None:
assert movie.name == "Downfall"
assert movie.year == 2004

movie = await Movie.objects.filter(name="Casablanca").filter(year=1942).get()
movie = (
await Movie.objects.filter(name="Casablanca").filter(year=1942).get()
)
assert movie.name == "Casablanca"
assert movie.year == 1942

Expand All @@ -81,7 +87,9 @@ async def test_model_query_builder() -> None:
assert movie.name == "Casablanca"
assert movie.year == 1942

movie = await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get()
movie = (
await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get()
)
assert movie.name == "The Two Towers"
assert movie.year == 2002

Expand Down Expand Up @@ -129,6 +137,12 @@ async def test_model_query_builder() -> None:
assert len(movies) == 1
assert movies[0].name.lower() == "gone with the Wind".lower()

movies = await Movie.objects.filter(released_at__date=date.today())
assert len(movies) == 1
assert movies[0].name == "Downfall"
assert movies[0].year == 2004


async def test_query_builder_in_list():
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(name="The Two Towers", year=2002)
Expand All @@ -142,7 +156,9 @@ async def test_query_builder_in_list():
assert len(movies) == 2


@pytest.mark.parametrize("values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"])
@pytest.mark.parametrize(
"values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"]
)
async def test_query_builder_in_list_raise_assertation_error(values):
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(name="The Two Towers", year=2002)
Expand Down

0 comments on commit f0b9307

Please sign in to comment.