From 42d37e639c696ef9314e54667c2a861c79e54077 Mon Sep 17 00:00:00 2001
From: Peter Volf
Date: Sat, 18 Jan 2025 23:09:33 +0100
Subject: [PATCH 1/3] feat: add aggregation utility, tests, and some nitpicking here and there, refs #19

---
 motorhead/__init__.py     |  11 ++-
 motorhead/aggregation.py  | 105 ++++++++++++++++++++++++
 motorhead/operator.py     |   6 +-
 motorhead/query.py        |  12 +--
 motorhead/service.py      |  23 +++---
 motorhead/typing.py       |   4 +
 motorhead/validator.py    |   8 +-
 pyproject.toml            |   1 +
 tests/test_aggregation.py | 166 ++++++++++++++++++++++++++++++++++++++
 9 files changed, 313 insertions(+), 23 deletions(-)
 create mode 100644 motorhead/aggregation.py
 create mode 100644 tests/test_aggregation.py

diff --git a/motorhead/__init__.py b/motorhead/__init__.py
index 3cb7af2..73ceab2 100644
--- a/motorhead/__init__.py
+++ b/motorhead/__init__.py
@@ -1,4 +1,7 @@
 from . import operator as operator
+from .aggregation import Aggregation as Aggregation
+from .aggregation import AggregationStage as AggregationStage
+from .aggregation import make_aggregation_stage as make_aggregation_stage
 from .bound_method_wrapper import BoundMethodWrapper as BoundMethodWrapper
 from .delete_rule import DeleteConfig as DeleteConfig
 from .delete_rule import DeleteError as DeleteError
@@ -13,13 +16,9 @@
 from .query import Q as Q
 from .query import Query as Query
 from .query import Queryable as Queryable
-from .service import DeleteResult as DeleteResult
-from .service import InsertManyResult as InsertManyResult
-from .service import InsertOneResult as InsertOneResult
 from .service import Service as Service
 from .service import ServiceConfig as ServiceConfig
 from .service import ServiceException as ServiceException
-from .service import UpdateResult as UpdateResult
 from .typing import AgnosticClient as AgnosticClient
 from .typing import AgnosticClientSession as AgnosticClientSession
 from .typing import AgnosticCollection as AgnosticCollection
@@ -35,15 +34,19 @@
 from .typing import CollectionOptions as CollectionOptions
 from .typing import DatabaseProvider as DatabaseProvider
 from .typing import DeleteOptions as DeleteOptions
+from .typing import DeleteResult as DeleteResult
 from .typing import FindOptions as FindOptions
 from .typing import IndexData as IndexData
 from .typing import InsertManyOptions as InsertManyOptions
+from .typing import InsertManyResult as InsertManyResult
 from .typing import InsertOneOptions as InsertOneOptions
+from .typing import InsertOneResult as InsertOneResult
 from .typing import MongoProjection as MongoProjection
 from .typing import MongoQuery as MongoQuery
 from .typing import UpdateManyOptions as UpdateManyOptions
 from .typing import UpdateObject as UpdateObject
 from .typing import UpdateOneOptions as UpdateOneOptions
+from .typing import UpdateResult as UpdateResult
 from .validator import ValidationError as ValidationError
 from .validator import Validator as Validator
 from .validator import validator as validator
diff --git a/motorhead/aggregation.py b/motorhead/aggregation.py
new file mode 100644
index 0000000..4594e3e
--- /dev/null
+++ b/motorhead/aggregation.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+    from typing import TypeAlias
+
+    from typing_extensions import Self
+
+    from .typing import Clause
+
+AggregationStage = Literal[
+    "$addFields",
+    "$bucket",
+    "$bucketAuto",
+    "$changeStream",
+    "$changeStreamSplitLargeEvent",
+    "$collStats",
+    "$count",
+    "$currentOp",
"$densify", + "$documents", + "$facet", + "$fill", + "$geoNear", + "$graphLookup", + "$group", + "$indexStats", + "$limit", + "$listLocalSessions", + "$listSampledQueries", + "$listSearchIndexes", + "$listSessions", + "$lookup", + "$match", + "$merge", + "$out", + "$planCacheStats", + "$project", + "$querySettings", + "$redact", + "$replaceRoot", + "$replaceWith", + "$sample", + "$search", + "$searchMeta", + "$set", + "$setWindowFields", + "$shardedDataDistribution", + "$skip", + "$sort", + "$sortByCount", + "$unionWith", + "$unset", + "$unwind", + "$vectorSearch", +] +"""Aggregation pipeline stage.""" + + +AggregationData: TypeAlias = Any + + +def make_aggregation_stage( + stage: AggregationStage, value: AggregationData | Clause +) -> dict[str, AggregationData]: + """ + Creates an aggregation pipeline stage. + + Arguments: + stage: The stage operator. + value: The stage operator's content. + + Returns: + The aggregation pipeline stage. + """ + return {stage: value.to_mongo() if hasattr(value, "to_mongo") else value} + + +class Aggregation(list[dict[str, AggregationData]]): + """Aggregation pipeline.""" + + def __init__(self, stages: Iterable[AggregationData] = ()) -> None: + """ + Initialization. + + Arguments: + stages: The aggregation pipeline stages. + """ + super().__init__(stages) + + def stage(self, stage: AggregationStage, value: AggregationData | Clause) -> Self: + """ + Adds the stage to the aggregation pipeline. + + Arguments: + stage: The stage operator. + value: The stage operator's content. + + Returns: + The aggregation pipeline. + """ + self.append(make_aggregation_stage(stage, value)) + return self diff --git a/motorhead/operator.py b/motorhead/operator.py index 6952a39..39b8441 100644 --- a/motorhead/operator.py +++ b/motorhead/operator.py @@ -1,9 +1,11 @@ from __future__ import annotations -from collections.abc import Generator -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any + from .query import Field from .typing import Clause diff --git a/motorhead/query.py b/motorhead/query.py index 635f2e3..262a464 100644 --- a/motorhead/query.py +++ b/motorhead/query.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, TypeVar from pydantic import BaseModel @@ -23,12 +23,14 @@ Size, Type, ) -from .typing import Clause - -_T = TypeVar("_T", bound=BaseModel) if TYPE_CHECKING: - from .typing import MongoQuery + from typing import Any + + from .typing import Clause, MongoQuery + + +_T = TypeVar("_T", bound=BaseModel) class Field: diff --git a/motorhead/service.py b/motorhead/service.py index a8e800a..ea5fff4 100644 --- a/motorhead/service.py +++ b/motorhead/service.py @@ -1,17 +1,19 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, Callable, Coroutine, Generator, Iterable, Mapping, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, nullcontext -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypedDict, TypeVar, get_args +from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar, get_args from bson import ObjectId from pydantic import BaseModel -from pymongo.results import DeleteResult, InsertManyResult, InsertOneResult, UpdateResult +from .delete_rule import DeleteRule from .operator import ensure_dict -from .typing import ClauseOrMongoQuery +from .validator import Validator if TYPE_CHECKING: + from collections.abc import 
+    from typing import Any, ClassVar
+
     from .typing import (
         AgnosticClient,
         AgnosticClientSession,
@@ -20,29 +22,30 @@
         AgnosticCursor,
         AgnosticDatabase,
         AgnosticLatentCommandCursor,
+        ClauseOrMongoQuery,
         Collation,
         CollectionOptions,
         DeleteOptions,
+        DeleteResult,
         FindOptions,
         IndexData,
         InsertManyOptions,
+        InsertManyResult,
         InsertOneOptions,
+        InsertOneResult,
         MongoProjection,
         UpdateManyOptions,
         UpdateObject,
         UpdateOneOptions,
+        UpdateResult,
     )
 
-from .delete_rule import DeleteRule
-from .validator import Validator
 
 __all__ = (
     "BaseService",
-    "DeleteResult",
-    "InsertManyResult",
-    "InsertOneResult",
     "Service",
-    "UpdateResult",
+    "ServiceConfig",
+    "ServiceException",
 )
 
 TInsert = TypeVar("TInsert", bound=BaseModel)
diff --git a/motorhead/typing.py b/motorhead/typing.py
index d01b41a..293bb96 100644
--- a/motorhead/typing.py
+++ b/motorhead/typing.py
@@ -12,6 +12,10 @@
 from motor.core import AgnosticDatabase as _AgnosticDatabase
 from motor.core import AgnosticLatentCommandCursor as _AgnosticLatentCommandCursor
 from pymongo.collation import Collation as PMCollation
+from pymongo.results import DeleteResult as DeleteResult
+from pymongo.results import InsertManyResult as InsertManyResult
+from pymongo.results import InsertOneResult as InsertOneResult
+from pymongo.results import UpdateResult as UpdateResult
 
 if TYPE_CHECKING:
     from bson.codec_options import CodecOptions
diff --git a/motorhead/validator.py b/motorhead/validator.py
index e859dc6..fcd7816 100644
--- a/motorhead/validator.py
+++ b/motorhead/validator.py
@@ -1,9 +1,13 @@
-from collections.abc import Callable, Coroutine
-from typing import Literal, TypeVar
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal, TypeVar
 
 from .bound_method_wrapper import BoundMethodWrapper
 from .typing import ClauseOrMongoQuery
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Coroutine
+
 __all__ = (
     "ValidationError",
     "Validator",
diff --git a/pyproject.toml b/pyproject.toml
index 3e54b66..01da03d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ pytest = "^8.3.3"
 pytest-asyncio = "^0.24.0"
 pytest-docker = "^3.1.1"
 pytest-random-order = "^1.1.1"
+typing-extensions = "^4.12.2"
 
 [tool.mypy]
 strict = true
diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py
new file mode 100644
index 0000000..5369aff
--- /dev/null
+++ b/tests/test_aggregation.py
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import Any, get_args
+
+import pytest
+
+from motorhead import (
+    Aggregation,
+    AggregationStage,
+    AgnosticDatabase,
+    BaseDocument,
+    Document,
+    Service,
+    make_aggregation_stage,
+)
+
+aggregation_stages = get_args(AggregationStage)
+
+
+@dataclass(frozen=True)
+class DummyClause:
+    data: dict[str, Any]
+
+    def to_mongo(self) -> dict[str, Any]:
+        return self.data
+
+
+class Person(Document):
+    name: str
+    lucky_number: int = -1
+    group_id: int = 1
+
+    @classmethod
+    def group_id_from_int(cls, value: int) -> int:
+        value = value % 10
+        if value % 5 == 0:
+            return 5
+        if value % 3 == 0:
+            return 3
+        return 1
+
+
+class PersonData(BaseDocument):
+    name: str
+    lucky_number: int = -1
+    group_id: int = 1
+
+
+class PersonService(Service[PersonData, PersonData]):
+    collection_name = "test_aggregation_person"
+
+
+@pytest.fixture(scope="session")
+def person_service(*, database: AgnosticDatabase) -> PersonService:
+    return PersonService(database)
+
+
+@pytest.mark.parametrize(
+    ("stage", "value"),
+    [(stage, {"first": 1, "second": 2}) for stage in aggregation_stages],
{"first": 1, "second": 2}) for stage in aggregation_stages], +) +def test_make_aggregation_stage(stage: AggregationStage, value: dict[str, Any]) -> None: + result = make_aggregation_stage(stage, value) + assert isinstance(result, dict) + assert len(result) == 1 + assert result[stage] is value + + result = make_aggregation_stage(stage, DummyClause(value)) + assert isinstance(result, dict) + assert len(result) == 1 + assert result[stage] is value + + +class TestAggregation: + @pytest.mark.parametrize( + ("stage", "value"), + [(stage, {"first": 1, "second": 2}) for stage in aggregation_stages], + ) + def test_stage(self, stage: AggregationStage, value: dict[str, Any]) -> None: + aggr = Aggregation() + result = aggr.stage(stage, value) + assert isinstance(result, list) + assert result is aggr + assert len(result) == 1 + assert isinstance(result[0], dict) + assert result[0][stage] is value + + aggr = Aggregation() + result = aggr.stage(stage, DummyClause(value)) + assert isinstance(result, list) + assert result is aggr + assert len(result) == 1 + assert isinstance(result[0], dict) + assert result[0][stage] is value + + def test_init(self) -> None: + dummy_stage_data = {"first": 1, "second": 2} + + aggr = Aggregation(make_aggregation_stage(stage, dummy_stage_data) for stage in aggregation_stages) + assert len(aggr) == len(aggregation_stages) + + for stage_id, stage_data in zip(aggregation_stages, aggr, strict=True): + assert stage_data[stage_id] is dummy_stage_data + + def test_stage_chaining(self) -> None: + dummy_stage_data = {"first": 1, "second": 2} + + original_aggr = aggr = Aggregation() + + for stage in aggregation_stages: + aggr = aggr.stage(stage, dummy_stage_data) + + for stage_id, stage_data in zip(aggregation_stages, aggr, strict=True): + assert stage_data[stage_id] is dummy_stage_data + + assert original_aggr == aggr + assert len(aggr) == len(aggregation_stages) + + for stage_id, stage_data in zip(aggregation_stages, aggr, strict=True): + assert stage_data[stage_id] is dummy_stage_data + + @pytest.mark.asyncio(loop_scope="session") + async def test_with_service(self, database: AgnosticDatabase, person_service: PersonService) -> None: + try: + insert_result = await person_service.insert_many( + PersonData( + name=f"Person {i}", + lucky_number=1000 + i, + group_id=Person.group_id_from_int(i), + ) + for i in range(100) + ) + assert len(insert_result.inserted_ids) == 100 + + # -- Count all documents + result = [doc async for doc in person_service.aggregate(Aggregation().stage("$count", "total"))] + assert len(result) == 1 + assert isinstance(result[0], dict) + assert len(result[0]) == 1 + assert result[0]["total"] == 100 + + # -- Filter and count documents + result = [ + doc + async for doc in person_service.aggregate( + Aggregation().stage("$match", {"lucky_number": {"$lt": 1024}}).stage("$count", "total") + ) + ] + assert len(result) == 1 + assert isinstance(result[0], dict) + assert len(result[0]) == 1 + assert result[0]["total"] == 24 + + # -- Sort and count documents in groups + result = [ + doc + async for doc in person_service.aggregate( + Aggregation() + .stage("$sort", {"lucky_number": 1}) + .stage("$group", {"_id": "$group_id", "size": {"$count": {}}}) + .stage("$sort", {"size": -1}) + ) + ] + assert result == [{"_id": 1, "size": 50}, {"_id": 3, "size": 30}, {"_id": 5, "size": 20}] + + finally: + await database.drop_collection(person_service.collection_name) From 45fa005d1c98621500b7d2e51529ae5346275903 Mon Sep 17 00:00:00 2001 From: Peter Volf Date: Sat, 18 Jan 2025 23:13:30 
Subject: [PATCH 2/3] docs: add docs for the aggregation module, refs #19

---
 docs/api/aggregation.md | 4 ++++
 mkdocs.yml              | 1 +
 2 files changed, 5 insertions(+)
 create mode 100644 docs/api/aggregation.md

diff --git a/docs/api/aggregation.md b/docs/api/aggregation.md
new file mode 100644
index 0000000..ea5aedf
--- /dev/null
+++ b/docs/api/aggregation.md
@@ -0,0 +1,4 @@
+# ::: motorhead.aggregation
+
+    options:
+        show_root_heading: true
diff --git a/mkdocs.yml b/mkdocs.yml
index c33cab6..b445e10 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -49,6 +49,7 @@ nav:
       - api/service.md
       - api/query.md
       - api/operator.md
+      - api/aggregation.md
       - Model:
           - api/model/document.md
           - api/model/objectid.md

From 6886788535d6cb4465a97ee5f03c3afb95ba48ef Mon Sep 17 00:00:00 2001
From: Peter Volf
Date: Sat, 18 Jan 2025 23:14:03 +0100
Subject: [PATCH 3/3] chore: bump patch version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 01da03d..749a09d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ tracker = "https://github.com/volfpeter/motorhead/issues"
 
 [tool.poetry]
 name = "motorhead"
-version = "0.2501.1"
+version = "0.2501.2"
 description = "Async MongoDB with vanilla Pydantic v2+ - made easy."
 authors = ["Peter Volf "]
 readme = "README.md"
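
Usage sketch (editor's note, not part of the patch series): the snippet below shows how the Aggregation builder added in PATCH 1/3 is meant to compose with an existing Service subclass, mirroring what tests/test_aggregation.py exercises. ItemData, ItemService, count_cheap_items, and the "items" collection name are illustrative placeholders rather than names from the patches; aggregate() is the Service method the new tests already rely on.

    from motorhead import Aggregation, AgnosticDatabase, BaseDocument, Service


    class ItemData(BaseDocument):
        name: str
        price: int


    class ItemService(Service[ItemData, ItemData]):
        collection_name = "items"


    async def count_cheap_items(database: AgnosticDatabase) -> int:
        # Each .stage() call appends one {"$operator": value} dict to the
        # pipeline (Aggregation is a plain list subclass) and returns the
        # pipeline itself, so stages can be chained.
        pipeline = (
            Aggregation()
            .stage("$match", {"price": {"$lt": 100}})
            .stage("$count", "total")
        )
        # The service's aggregate() call yields the raw result documents.
        results = [doc async for doc in ItemService(database).aggregate(pipeline)]
        return results[0]["total"] if results else 0

Because make_aggregation_stage() calls to_mongo() whenever the stage value provides it, clause objects built with motorhead's query operators can be passed to stage() directly instead of pre-built dicts.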