From de78d890e81c0e308b14c9c23bdd7ebf7d5626dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 17:15:20 +0200 Subject: [PATCH 01/13] =?UTF-8?q?=E2=9C=A8=20Use=20bulk=5Fwrite=20for=20sa?= =?UTF-8?q?ve=20operations,=20including=20referenced=20documents?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 208 +++++++++++++++++++++++++++++++---------- odmantic/exceptions.py | 7 +- 2 files changed, 164 insertions(+), 51 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index dce4e1fc..b8948a02 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -16,10 +16,12 @@ TypeVar, Union, cast, + overload, ) import pymongo -from pymongo import MongoClient +import pymongo.errors +from pymongo import MongoClient, UpdateOne from pymongo.client_session import ClientSession from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor @@ -59,6 +61,8 @@ AIOSessionType = Union[AsyncIOMotorClientSession, AIOSession, AIOTransaction, None] SyncSessionType = Union[ClientSession, SyncSession, SyncTransaction, None] +AIOCollectionUpdatesType = Dict["AsyncIOMotorCollection", List[Tuple[UpdateOne, Model]]] +SyncCollectionUpdatesType = Dict[Collection, List[Tuple[UpdateOne, Model]]] class BaseCursor(Generic[ModelType]): @@ -291,6 +295,77 @@ def _prepare_find_pipeline( pipeline.extend(BaseEngine._cascade_find_pipeline(model)) return pipeline + def get_collection( + self, model: Type[ModelType] + ) -> Union["Collection", "AsyncIOMotorCollection"]: + raise NotImplementedError() + + @overload + def _prepare_document_updates( # type: ignore + self: "AIOEngine", + instance: ModelType, + *, + collection_updates: Union[AIOCollectionUpdatesType, None] = None, + ) -> AIOCollectionUpdatesType: + ... + + @overload + def _prepare_document_updates( # type: ignore + self: "SyncEngine", + instance: ModelType, + *, + collection_updates: Union[SyncCollectionUpdatesType, None] = None, + ) -> SyncCollectionUpdatesType: + ... + + @overload + def _prepare_document_updates( + self: "BaseEngine", + instance: ModelType, + *, + collection_updates: Union[ + AIOCollectionUpdatesType, SyncCollectionUpdatesType, None + ] = None, + ) -> Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType]: + ... + + def _prepare_document_updates( + self, + instance: ModelType, + *, + collection_updates: Union[ + AIOCollectionUpdatesType, SyncCollectionUpdatesType, None + ] = None, + ) -> Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType]: + """Perform an atomic save operation in the specified session""" + if collection_updates is None: + current_collection_updates = cast( + Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType], {} + ) + else: + current_collection_updates = collection_updates + for ref_field_name in instance.__references__: + sub_instance = cast(Model, getattr(instance, ref_field_name)) + self._prepare_document_updates( + sub_instance, collection_updates=current_collection_updates + ) + + fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__ + if len(fields_to_update) > 0: + doc = instance.doc(include=fields_to_update) + collection = self.get_collection(type(instance)) + current_collection_updates.setdefault(collection, []).append( + ( + UpdateOne( + filter=instance.doc(include={instance.__primary_field__}), + update={"$set": doc}, + upsert=True, + ), + instance, + ) + ) + return current_collection_updates + class AIOEngine(BaseEngine): """The AIOEngine object is responsible for handling database operations with MongoDB @@ -511,28 +586,38 @@ async def find_one( return None return results[0] + async def _save_collection_updates( + self, + collection_updates: AIOCollectionUpdatesType, + session: "AsyncIOMotorClientSession", + ) -> None: + # reverse so that the last collections added, the ones for sub-documents, are + # saved first + for collection, updates in reversed(collection_updates.items()): + update_operations = [update[0] for update in updates] + update_instances = [update[1] for update in updates] + await collection.bulk_write( + update_operations, + session=session, + ) + for inst in update_instances: + object.__setattr__(inst, "__fields_modified__", set()) + async def _save( self, instance: ModelType, session: "AsyncIOMotorClientSession" ) -> ModelType: """Perform an atomic save operation in the specified session""" - for ref_field_name in instance.__references__: - sub_instance = cast(Model, getattr(instance, ref_field_name)) - await self._save(sub_instance, session) - - fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__ - if len(fields_to_update) > 0: - doc = instance.doc(include=fields_to_update) - collection = self.get_collection(type(instance)) - try: - await collection.update_one( - instance.doc(include={instance.__primary_field__}), - {"$set": doc}, - upsert=True, - session=session, - ) - except pymongo.errors.DuplicateKeyError as e: + collection_updates = self._prepare_document_updates(instance) + try: + await self._save_collection_updates( + collection_updates=collection_updates, session=session + ) + except pymongo.errors.DuplicateKeyError as e: + raise DuplicateKeyError(instance, e) + except pymongo.errors.BulkWriteError as e: + if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) - object.__setattr__(instance, "__fields_modified__", set()) + raise return instance async def save( @@ -611,17 +696,22 @@ async def save_all( #noqa: DAR402 DuplicateKeyError --> """ + collections_updates: AIOCollectionUpdatesType = {} + for instance in instances: + self._prepare_document_updates( + instance, collection_updates=collections_updates + ) if session: - added_instances = [ - await self._save(instance, self._get_session(session)) - for instance in instances - ] + await self._save_collection_updates( + collection_updates=collections_updates, + session=self._get_session(session), + ) else: async with await self.client.start_session() as local_session: - added_instances = [ - await self._save(instance, local_session) for instance in instances - ] - return added_instances + await self._save_collection_updates( + collection_updates=collections_updates, session=local_session + ) + return list(instances) async def delete( self, @@ -919,26 +1009,37 @@ def find_one( return None return results[0] + def _save_collection_updates( + self, + collection_updates: SyncCollectionUpdatesType, + session: ClientSession, + ) -> None: + # reverse so that the last collections added, the ones for sub-documents, are + # saved first + for collection, updates in reversed(collection_updates.items()): + update_operations = [update[0] for update in updates] + update_instances = [update[1] for update in updates] + collection.bulk_write( + update_operations, + session=session, + ) + for inst in update_instances: + object.__setattr__(inst, "__fields_modified__", set()) + def _save(self, instance: ModelType, session: "ClientSession") -> ModelType: """Perform an atomic save operation in the specified session""" - for ref_field_name in instance.__references__: - sub_instance = cast(Model, getattr(instance, ref_field_name)) - self._save(sub_instance, session) + collection_updates = self._prepare_document_updates(instance) - fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__ - if len(fields_to_update) > 0: - doc = instance.doc(include=fields_to_update) - collection = self.get_collection(type(instance)) - try: - collection.update_one( - instance.doc(include={instance.__primary_field__}), - {"$set": doc}, - upsert=True, - session=session, - ) - except pymongo.errors.DuplicateKeyError as e: + try: + self._save_collection_updates( + collection_updates=collection_updates, session=session + ) + except pymongo.errors.DuplicateKeyError as e: + raise DuplicateKeyError(instance, e) + except pymongo.errors.BulkWriteError as e: + if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) - object.__setattr__(instance, "__fields_modified__", set()) + raise return instance def save( @@ -1017,17 +1118,24 @@ def save_all( #noqa: DAR402 DuplicateKeyError --> """ + collections_updates: SyncCollectionUpdatesType = {} + for instance in instances: + self._prepare_document_updates( + instance, collection_updates=collections_updates + ) if session: - added_instances = [ - self._save(instance, self._get_session(session)) # type: ignore - for instance in instances - ] + mongo_session = self._get_session(session) + assert mongo_session + self._save_collection_updates( + collection_updates=collections_updates, + session=mongo_session, + ) else: with self.client.start_session() as local_session: - added_instances = [ - self._save(instance, local_session) for instance in instances - ] - return added_instances + self._save_collection_updates( + collection_updates=collections_updates, session=local_session + ) + return list(instances) def delete( self, diff --git a/odmantic/exceptions.py b/odmantic/exceptions.py index 89fb098f..c4e144c2 100644 --- a/odmantic/exceptions.py +++ b/odmantic/exceptions.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Sequence, Type, TypeVar, Union import pymongo +import pymongo.errors from pydantic.error_wrappers import ErrorWrapper, ValidationError if TYPE_CHECKING: @@ -43,7 +44,11 @@ class DuplicateKeyError(BaseEngineException): """ def __init__( - self, instance: "Model", driver_error: pymongo.errors.DuplicateKeyError + self, + instance: "Model", + driver_error: Union[ + pymongo.errors.DuplicateKeyError, pymongo.errors.BulkWriteError + ], ): self.instance: "Model" = instance self.driver_error = driver_error From 7f112cf3607bc451c3b0f04fb344dc4df3a61443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 17:16:16 +0200 Subject: [PATCH 02/13] =?UTF-8?q?=E2=9C=85=20Update=20tests=20accounting?= =?UTF-8?q?=20for=20new=20usage=20of=20bulk=5Fwrite?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/integration/conftest.py | 2 + tests/integration/test_engine.py | 72 ++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c8ca1a06..c96c5bae 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -96,6 +96,7 @@ def f(): collection = Mock() collection.update_one = AsyncMock() collection.aggregate = AsyncMock() + collection.bulk_write = AsyncMock() monkeypatch.setattr(aio_engine, "get_collection", lambda _: collection) return collection @@ -108,6 +109,7 @@ def f(): collection = Mock() collection.update_one = Mock() collection.aggregate = Mock() + collection.bulk_write = Mock() monkeypatch.setattr(sync_engine, "get_collection", lambda _: collection) return collection diff --git a/tests/integration/test_engine.py b/tests/integration/test_engine.py index 630047f8..5f61581a 100644 --- a/tests/integration/test_engine.py +++ b/tests/integration/test_engine.py @@ -2,7 +2,7 @@ import pytest from motor.motor_asyncio import AsyncIOMotorClient -from pymongo import MongoClient +from pymongo import MongoClient, UpdateOne from odmantic.bson import ObjectId from odmantic.engine import AIOEngine, SyncEngine @@ -742,9 +742,10 @@ async def test_only_modified_set_on_save(aio_engine: AIOEngine, aio_mock_collect instance.first_name = "John" collection = aio_mock_collection() await aio_engine.save(instance) - collection.update_one.assert_awaited_once() - (_, set_arg), _ = collection.update_one.await_args - assert set_arg == {"$set": {"first_name": "John"}} + collection.bulk_write.assert_awaited_once() + (update_operations,), session = collection.bulk_write.await_args + update_operation: UpdateOne = update_operations[0] + assert update_operation._doc == {"$set": {"first_name": "John"}} @pytest.mark.usefixtures("engine_one_person") @@ -755,9 +756,10 @@ def test_sync_only_modified_set_on_save(sync_engine: SyncEngine, sync_mock_colle instance.first_name = "John" collection = sync_mock_collection() sync_engine.save(instance) - collection.update_one.assert_called_once() - (_, set_arg), _ = collection.update_one.call_args - assert set_arg == {"$set": {"first_name": "John"}} + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + assert update_operation._doc == {"$set": {"first_name": "John"}} async def test_only_mutable_list_set_on_save( @@ -772,9 +774,10 @@ class M(Model): collection = aio_mock_collection() await aio_engine.save(instance) - collection.update_one.assert_awaited_once() - (_, set_arg), _ = collection.update_one.await_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_awaited_once() + (update_operations,), session = collection.bulk_write.await_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert list(set_dict.keys()) == ["field"] @@ -790,9 +793,10 @@ class M(Model): collection = sync_mock_collection() sync_engine.save(instance) - collection.update_one.assert_called_once() - (_, set_arg), _ = collection.update_one.call_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert list(set_dict.keys()) == ["field"] @@ -810,9 +814,10 @@ class M(Model): collection = aio_mock_collection() await aio_engine.save(instance) - collection.update_one.assert_awaited_once() - (_, set_arg), _ = collection.update_one.await_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_awaited_once() + (update_operations,), session = collection.bulk_write.await_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == {"field": [{"a": "hello"}]} @@ -830,9 +835,10 @@ class M(Model): collection = sync_mock_collection() sync_engine.save(instance) - collection.update_one.assert_called_once() - (_, set_arg), _ = collection.update_one.call_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == {"field": [{"a": "hello"}]} @@ -850,9 +856,10 @@ class M(Model): collection = aio_mock_collection() await aio_engine.save(instance) - collection.update_one.assert_awaited_once() - (_, set_arg), _ = collection.update_one.await_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == {"field": {"hello": {"a": "world"}}} @@ -870,9 +877,10 @@ class M(Model): collection = sync_mock_collection() sync_engine.save(instance) - collection.update_one.assert_called_once() - (_, set_arg), _ = collection.update_one.call_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == {"field": {"hello": {"a": "world"}}} @@ -890,9 +898,10 @@ class M(Model): collection = aio_mock_collection() await aio_engine.save(instance) - collection.update_one.assert_awaited_once() - (_, set_arg), _ = collection.update_one.await_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_awaited_once() + (update_operations,), session = collection.bulk_write.await_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == { "field": [ {"a": "world"}, @@ -914,9 +923,10 @@ class M(Model): collection = sync_mock_collection() sync_engine.save(instance) - collection.update_one.assert_called_once() - (_, set_arg), _ = collection.update_one.call_args - set_dict = set_arg["$set"] + collection.bulk_write.assert_called_once() + (update_operations,), session = collection.bulk_write.call_args + update_operation: UpdateOne = update_operations[0] + set_dict = update_operation._doc["$set"] assert set_dict == { "field": [ {"a": "world"}, From 59dbcbf69b6cf32cec5b1eaa764214a5d2b4c258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 17:42:56 +0200 Subject: [PATCH 03/13] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20collectio?= =?UTF-8?q?ns-updates=20dict=20to=20use=20collection=20names,=20as=20some?= =?UTF-8?q?=20are=20not=20hashable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 74 ++++++++++++----------------------- tests/integration/conftest.py | 8 +++- 2 files changed, 30 insertions(+), 52 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index b8948a02..270edf63 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -16,7 +16,6 @@ TypeVar, Union, cast, - overload, ) import pymongo @@ -61,8 +60,7 @@ AIOSessionType = Union[AsyncIOMotorClientSession, AIOSession, AIOTransaction, None] SyncSessionType = Union[ClientSession, SyncSession, SyncTransaction, None] -AIOCollectionUpdatesType = Dict["AsyncIOMotorCollection", List[Tuple[UpdateOne, Model]]] -SyncCollectionUpdatesType = Dict[Collection, List[Tuple[UpdateOne, Model]]] +CollectionUpdatesType = Dict[str, List[Tuple[UpdateOne, Model]]] class BaseCursor(Generic[ModelType]): @@ -300,48 +298,15 @@ def get_collection( ) -> Union["Collection", "AsyncIOMotorCollection"]: raise NotImplementedError() - @overload - def _prepare_document_updates( # type: ignore - self: "AIOEngine", - instance: ModelType, - *, - collection_updates: Union[AIOCollectionUpdatesType, None] = None, - ) -> AIOCollectionUpdatesType: - ... - - @overload - def _prepare_document_updates( # type: ignore - self: "SyncEngine", - instance: ModelType, - *, - collection_updates: Union[SyncCollectionUpdatesType, None] = None, - ) -> SyncCollectionUpdatesType: - ... - - @overload - def _prepare_document_updates( - self: "BaseEngine", - instance: ModelType, - *, - collection_updates: Union[ - AIOCollectionUpdatesType, SyncCollectionUpdatesType, None - ] = None, - ) -> Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType]: - ... - def _prepare_document_updates( self, instance: ModelType, *, - collection_updates: Union[ - AIOCollectionUpdatesType, SyncCollectionUpdatesType, None - ] = None, - ) -> Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType]: + collection_updates: Union[CollectionUpdatesType, None] = None, + ) -> CollectionUpdatesType: """Perform an atomic save operation in the specified session""" if collection_updates is None: - current_collection_updates = cast( - Union[AIOCollectionUpdatesType, SyncCollectionUpdatesType], {} - ) + current_collection_updates = {} else: current_collection_updates = collection_updates for ref_field_name in instance.__references__: @@ -353,8 +318,8 @@ def _prepare_document_updates( fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__ if len(fields_to_update) > 0: doc = instance.doc(include=fields_to_update) - collection = self.get_collection(type(instance)) - current_collection_updates.setdefault(collection, []).append( + collection_name = type(instance).__collection__ + current_collection_updates.setdefault(collection_name, []).append( ( UpdateOne( filter=instance.doc(include={instance.__primary_field__}), @@ -400,6 +365,11 @@ def __init__( client = AsyncIOMotorClient() super().__init__(client=client, database=database) + def _get_collection_from_name( + self, collection_name: str + ) -> "AsyncIOMotorCollection": + return self.database[collection_name] + def get_collection(self, model: Type[ModelType]) -> "AsyncIOMotorCollection": """Get the motor collection associated to a Model. @@ -409,7 +379,7 @@ def get_collection(self, model: Type[ModelType]) -> "AsyncIOMotorCollection": Returns: the AsyncIO motor collection object """ - return self.database[model.__collection__] + return self._get_collection_from_name(model.__collection__) @staticmethod def _get_session( @@ -588,12 +558,13 @@ async def find_one( async def _save_collection_updates( self, - collection_updates: AIOCollectionUpdatesType, + collection_updates: CollectionUpdatesType, session: "AsyncIOMotorClientSession", ) -> None: # reverse so that the last collections added, the ones for sub-documents, are # saved first - for collection, updates in reversed(collection_updates.items()): + for collection_name, updates in reversed(collection_updates.items()): + collection = self._get_collection_from_name(collection_name) update_operations = [update[0] for update in updates] update_instances = [update[1] for update in updates] await collection.bulk_write( @@ -696,7 +667,7 @@ async def save_all( #noqa: DAR402 DuplicateKeyError --> """ - collections_updates: AIOCollectionUpdatesType = {} + collections_updates: CollectionUpdatesType = {} for instance in instances: self._prepare_document_updates( instance, collection_updates=collections_updates @@ -826,6 +797,9 @@ def __init__( client = MongoClient() super().__init__(client=client, database=database) + def _get_collection_from_name(self, collection_name: str) -> "Collection": + return self.database[collection_name] + def get_collection(self, model: Type[ModelType]) -> "Collection": """Get the pymongo collection associated to a Model. @@ -835,8 +809,7 @@ def get_collection(self, model: Type[ModelType]) -> "Collection": Returns: the pymongo collection object """ - collection = self.database[model.__collection__] - return collection + return self._get_collection_from_name(model.__collection__) @staticmethod def _get_session( @@ -1011,14 +984,15 @@ def find_one( def _save_collection_updates( self, - collection_updates: SyncCollectionUpdatesType, + collection_updates: CollectionUpdatesType, session: ClientSession, ) -> None: # reverse so that the last collections added, the ones for sub-documents, are # saved first - for collection, updates in reversed(collection_updates.items()): + for collection_name, updates in reversed(collection_updates.items()): update_operations = [update[0] for update in updates] update_instances = [update[1] for update in updates] + collection = self._get_collection_from_name(collection_name) collection.bulk_write( update_operations, session=session, @@ -1118,7 +1092,7 @@ def save_all( #noqa: DAR402 DuplicateKeyError --> """ - collections_updates: SyncCollectionUpdatesType = {} + collections_updates: CollectionUpdatesType = {} for instance in instances: self._prepare_document_updates( instance, collection_updates=collections_updates diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c96c5bae..585d08e1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -97,7 +97,9 @@ def f(): collection.update_one = AsyncMock() collection.aggregate = AsyncMock() collection.bulk_write = AsyncMock() - monkeypatch.setattr(aio_engine, "get_collection", lambda _: collection) + monkeypatch.setattr( + aio_engine, "_get_collection_from_name", lambda _: collection + ) return collection return f @@ -110,7 +112,9 @@ def f(): collection.update_one = Mock() collection.aggregate = Mock() collection.bulk_write = Mock() - monkeypatch.setattr(sync_engine, "get_collection", lambda _: collection) + monkeypatch.setattr( + sync_engine, "_get_collection_from_name", lambda _: collection + ) return collection return f From 437b382fca8144db09f7261298c5cc0eefa8be8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 17:51:01 +0200 Subject: [PATCH 04/13] =?UTF-8?q?=F0=9F=90=9B=20Fix=20reverse=20of=20dict?= =?UTF-8?q?=20when=20creating=20collections?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index 270edf63..c6a429e0 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -563,7 +563,7 @@ async def _save_collection_updates( ) -> None: # reverse so that the last collections added, the ones for sub-documents, are # saved first - for collection_name, updates in reversed(collection_updates.items()): + for collection_name, updates in reversed(list(collection_updates.items())): collection = self._get_collection_from_name(collection_name) update_operations = [update[0] for update in updates] update_instances = [update[1] for update in updates] @@ -989,7 +989,7 @@ def _save_collection_updates( ) -> None: # reverse so that the last collections added, the ones for sub-documents, are # saved first - for collection_name, updates in reversed(collection_updates.items()): + for collection_name, updates in reversed(list(collection_updates.items())): update_operations = [update[0] for update in updates] update_instances = [update[1] for update in updates] collection = self._get_collection_from_name(collection_name) From 7a233e056a15817ad43abe6b300518ad02ed97c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 18:00:15 +0200 Subject: [PATCH 05/13] =?UTF-8?q?=E2=9C=85=20Remove=20unnecessary=20raise?= =?UTF-8?q?=20statements,=20not=20used=20nor=20tested?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index c6a429e0..01f482de 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -293,11 +293,6 @@ def _prepare_find_pipeline( pipeline.extend(BaseEngine._cascade_find_pipeline(model)) return pipeline - def get_collection( - self, model: Type[ModelType] - ) -> Union["Collection", "AsyncIOMotorCollection"]: - raise NotImplementedError() - def _prepare_document_updates( self, instance: ModelType, @@ -583,8 +578,6 @@ async def _save( await self._save_collection_updates( collection_updates=collection_updates, session=session ) - except pymongo.errors.DuplicateKeyError as e: - raise DuplicateKeyError(instance, e) except pymongo.errors.BulkWriteError as e: if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) @@ -1008,8 +1001,6 @@ def _save(self, instance: ModelType, session: "ClientSession") -> ModelType: self._save_collection_updates( collection_updates=collection_updates, session=session ) - except pymongo.errors.DuplicateKeyError as e: - raise DuplicateKeyError(instance, e) except pymongo.errors.BulkWriteError as e: if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) From 14b84c38b6d91366f452242f2c913dc8f87c67bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 28 Sep 2022 18:13:54 +0200 Subject: [PATCH 06/13] =?UTF-8?q?=E2=9C=85=20Add=20pragma=20no=20cover=20t?= =?UTF-8?q?o=20raise=20for=20other=20general=20BulkWrite=20excetions=20not?= =?UTF-8?q?=20known=20yet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index 01f482de..f23da811 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -581,7 +581,7 @@ async def _save( except pymongo.errors.BulkWriteError as e: if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) - raise + raise # pragma: no cover return instance async def save( @@ -1004,7 +1004,7 @@ def _save(self, instance: ModelType, session: "ClientSession") -> ModelType: except pymongo.errors.BulkWriteError as e: if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) - raise + raise # pragma: no cover return instance def save( From 17037b86177d2acf582da18148636cef0c401cfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 13 Dec 2022 18:27:34 +0400 Subject: [PATCH 07/13] =?UTF-8?q?=F0=9F=93=9D=20Update=20odmantic/engine.p?= =?UTF-8?q?y?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Arthur Pastel --- odmantic/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index f23da811..e42d4d1e 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -299,7 +299,7 @@ def _prepare_document_updates( *, collection_updates: Union[CollectionUpdatesType, None] = None, ) -> CollectionUpdatesType: - """Perform an atomic save operation in the specified session""" + """Generate update operations""" if collection_updates is None: current_collection_updates = {} else: From 22009cb5ace1d0425f8126be3ecf1fe7404aa7ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 13 Dec 2022 18:38:04 +0400 Subject: [PATCH 08/13] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20collection?= =?UTF-8?q?=5Fupdate=20with=20a=20dataclass=20instead=20of=20tuples=20for?= =?UTF-8?q?=20clarity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/odmantic/engine.py b/odmantic/engine.py index e42d4d1e..fad516a1 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import ( Any, AsyncGenerator, @@ -60,7 +61,15 @@ AIOSessionType = Union[AsyncIOMotorClientSession, AIOSession, AIOTransaction, None] SyncSessionType = Union[ClientSession, SyncSession, SyncTransaction, None] -CollectionUpdatesType = Dict[str, List[Tuple[UpdateOne, Model]]] + + +@dataclass() +class ModelUpdateOne: + instance: Model + update_one: UpdateOne + + +CollectionUpdatesType = Dict[str, List[ModelUpdateOne]] class BaseCursor(Generic[ModelType]): @@ -315,13 +324,13 @@ def _prepare_document_updates( doc = instance.doc(include=fields_to_update) collection_name = type(instance).__collection__ current_collection_updates.setdefault(collection_name, []).append( - ( - UpdateOne( + ModelUpdateOne( + instance=instance, + update_one=UpdateOne( filter=instance.doc(include={instance.__primary_field__}), update={"$set": doc}, upsert=True, ), - instance, ) ) return current_collection_updates @@ -560,8 +569,8 @@ async def _save_collection_updates( # saved first for collection_name, updates in reversed(list(collection_updates.items())): collection = self._get_collection_from_name(collection_name) - update_operations = [update[0] for update in updates] - update_instances = [update[1] for update in updates] + update_operations = [update.update_one for update in updates] + update_instances = [update.instance for update in updates] await collection.bulk_write( update_operations, session=session, @@ -983,8 +992,8 @@ def _save_collection_updates( # reverse so that the last collections added, the ones for sub-documents, are # saved first for collection_name, updates in reversed(list(collection_updates.items())): - update_operations = [update[0] for update in updates] - update_instances = [update[1] for update in updates] + update_operations = [update.update_one for update in updates] + update_instances = [update.instance for update in updates] collection = self._get_collection_from_name(collection_name) collection.bulk_write( update_operations, From 52d6388deffba67f36078ef079679e6048a2d324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 14 Dec 2022 18:03:20 +0400 Subject: [PATCH 09/13] =?UTF-8?q?=F0=9F=92=A1=20Add=20comment=20reference?= =?UTF-8?q?=20to=20DuplicateKey=20Error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- odmantic/engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/odmantic/engine.py b/odmantic/engine.py index fad516a1..a48c9707 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -587,6 +587,9 @@ async def _save( await self._save_collection_updates( collection_updates=collection_updates, session=session ) + # Ref: + # https://github.com/mongodb/mongo/blob/master/src/mongo/base/error_codes.yml + # DuplicateKey Error except pymongo.errors.BulkWriteError as e: if e.details["writeErrors"][0]["code"] == 11000: raise DuplicateKeyError(instance, e) From 27a88e0eab460efc71d5f5ae2f9fbb23757b3c90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 14 Dec 2022 18:07:05 +0400 Subject: [PATCH 10/13] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Simplify=20type=20an?= =?UTF-8?q?notation=20in=20odmantic/exceptions.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Arthur Pastel --- odmantic/exceptions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/odmantic/exceptions.py b/odmantic/exceptions.py index c4e144c2..7ed5834f 100644 --- a/odmantic/exceptions.py +++ b/odmantic/exceptions.py @@ -46,9 +46,7 @@ class DuplicateKeyError(BaseEngineException): def __init__( self, instance: "Model", - driver_error: Union[ - pymongo.errors.DuplicateKeyError, pymongo.errors.BulkWriteError - ], + driver_error: pymongo.errors.BulkWriteError, ): self.instance: "Model" = instance self.driver_error = driver_error From cffca45d98411824544d8ad2e3151297e87e4ff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 14 Dec 2022 18:21:04 +0400 Subject: [PATCH 11/13] =?UTF-8?q?=F0=9F=91=B7=20Trigger=20CI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From c17348e12c24c6b9dffd64c355aec9bbfe460600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 14 Dec 2022 18:36:48 +0400 Subject: [PATCH 12/13] =?UTF-8?q?=F0=9F=93=8C=20Pin=20tox,=20to=20see=20if?= =?UTF-8?q?=20that's=20what's=20breaking=20CI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ef8ada74..9be0eef0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,7 +52,7 @@ jobs: .tox key: env-compatibility-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Install tox - run: pip install tox flit + run: pip install "tox~=3.27.1" flit - name: Run compatibility checks. run: | export VERSION_STR=$(echo ${{ matrix.python-version }} | sed -e "s/\.//g") From 867bfaff9517cd5d06cb55e8f741a7986df4fe1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 14 Nov 2023 20:38:00 +0100 Subject: [PATCH 13/13] =?UTF-8?q?=F0=9F=91=B7=20Trigger=20CI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit