Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Use bulk_write for save operations, including referenced documents #271

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 138 additions & 53 deletions odmantic/engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Expand All @@ -19,7 +20,8 @@
)

import pymongo
from pymongo import MongoClient
import pymongo.errors
from pymongo import MongoClient, UpdateOne
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
Expand Down Expand Up @@ -61,6 +63,15 @@
SyncSessionType = Union[ClientSession, SyncSession, SyncTransaction, None]


@dataclass()
class ModelUpdateOne:
instance: Model
update_one: UpdateOne


CollectionUpdatesType = Dict[str, List[ModelUpdateOne]]


class BaseCursor(Generic[ModelType]):
"""This object has to be built from the [odmantic.engine.AIOEngine.find][] method.

Expand Down Expand Up @@ -291,6 +302,39 @@ def _prepare_find_pipeline(
pipeline.extend(BaseEngine._cascade_find_pipeline(model))
return pipeline

def _prepare_document_updates(
self,
instance: ModelType,
*,
collection_updates: Union[CollectionUpdatesType, None] = None,
) -> CollectionUpdatesType:
"""Generate update operations"""
if collection_updates is None:
current_collection_updates = {}
else:
current_collection_updates = collection_updates
for ref_field_name in instance.__references__:
sub_instance = cast(Model, getattr(instance, ref_field_name))
self._prepare_document_updates(
sub_instance, collection_updates=current_collection_updates
)

fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__
if len(fields_to_update) > 0:
doc = instance.doc(include=fields_to_update)
collection_name = type(instance).__collection__
current_collection_updates.setdefault(collection_name, []).append(
ModelUpdateOne(
instance=instance,
update_one=UpdateOne(
filter=instance.doc(include={instance.__primary_field__}),
update={"$set": doc},
upsert=True,
),
)
)
return current_collection_updates


class AIOEngine(BaseEngine):
"""The AIOEngine object is responsible for handling database operations with MongoDB
Expand Down Expand Up @@ -325,6 +369,11 @@ def __init__(
client = AsyncIOMotorClient()
super().__init__(client=client, database=database)

def _get_collection_from_name(
self, collection_name: str
) -> "AsyncIOMotorCollection":
return self.database[collection_name]

def get_collection(self, model: Type[ModelType]) -> "AsyncIOMotorCollection":
"""Get the motor collection associated to a Model.

Expand All @@ -334,7 +383,7 @@ def get_collection(self, model: Type[ModelType]) -> "AsyncIOMotorCollection":
Returns:
the AsyncIO motor collection object
"""
return self.database[model.__collection__]
return self._get_collection_from_name(model.__collection__)

@staticmethod
def _get_session(
Expand Down Expand Up @@ -511,28 +560,40 @@ async def find_one(
return None
return results[0]

async def _save_collection_updates(
self,
collection_updates: CollectionUpdatesType,
session: "AsyncIOMotorClientSession",
) -> None:
# reverse so that the last collections added, the ones for sub-documents, are
# saved first
for collection_name, updates in reversed(list(collection_updates.items())):
collection = self._get_collection_from_name(collection_name)
update_operations = [update.update_one for update in updates]
update_instances = [update.instance for update in updates]
await collection.bulk_write(
update_operations,
session=session,
)
for inst in update_instances:
object.__setattr__(inst, "__fields_modified__", set())

async def _save(
self, instance: ModelType, session: "AsyncIOMotorClientSession"
) -> ModelType:
"""Perform an atomic save operation in the specified session"""
for ref_field_name in instance.__references__:
sub_instance = cast(Model, getattr(instance, ref_field_name))
await self._save(sub_instance, session)

fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__
if len(fields_to_update) > 0:
doc = instance.doc(include=fields_to_update)
collection = self.get_collection(type(instance))
try:
await collection.update_one(
instance.doc(include={instance.__primary_field__}),
{"$set": doc},
upsert=True,
session=session,
)
except pymongo.errors.DuplicateKeyError as e:
collection_updates = self._prepare_document_updates(instance)
try:
await self._save_collection_updates(
collection_updates=collection_updates, session=session
)
# Ref:
# https://github.com/mongodb/mongo/blob/master/src/mongo/base/error_codes.yml
# DuplicateKey Error
except pymongo.errors.BulkWriteError as e:
if e.details["writeErrors"][0]["code"] == 11000:
tiangolo marked this conversation as resolved.
Show resolved Hide resolved
raise DuplicateKeyError(instance, e)
object.__setattr__(instance, "__fields_modified__", set())
raise # pragma: no cover
art049 marked this conversation as resolved.
Show resolved Hide resolved
return instance

async def save(
Expand Down Expand Up @@ -611,17 +672,22 @@ async def save_all(
#noqa: DAR402 DuplicateKeyError
-->
"""
collections_updates: CollectionUpdatesType = {}
for instance in instances:
self._prepare_document_updates(
instance, collection_updates=collections_updates
)
if session:
added_instances = [
await self._save(instance, self._get_session(session))
for instance in instances
]
await self._save_collection_updates(
collection_updates=collections_updates,
session=self._get_session(session),
)
else:
async with await self.client.start_session() as local_session:
added_instances = [
await self._save(instance, local_session) for instance in instances
]
return added_instances
await self._save_collection_updates(
collection_updates=collections_updates, session=local_session
)
return list(instances)

async def delete(
self,
Expand Down Expand Up @@ -736,6 +802,9 @@ def __init__(
client = MongoClient()
super().__init__(client=client, database=database)

def _get_collection_from_name(self, collection_name: str) -> "Collection":
return self.database[collection_name]

def get_collection(self, model: Type[ModelType]) -> "Collection":
"""Get the pymongo collection associated to a Model.

Expand All @@ -745,8 +814,7 @@ def get_collection(self, model: Type[ModelType]) -> "Collection":
Returns:
the pymongo collection object
"""
collection = self.database[model.__collection__]
return collection
return self._get_collection_from_name(model.__collection__)

@staticmethod
def _get_session(
Expand Down Expand Up @@ -919,26 +987,36 @@ def find_one(
return None
return results[0]

def _save_collection_updates(
self,
collection_updates: CollectionUpdatesType,
session: ClientSession,
) -> None:
# reverse so that the last collections added, the ones for sub-documents, are
# saved first
for collection_name, updates in reversed(list(collection_updates.items())):
update_operations = [update.update_one for update in updates]
update_instances = [update.instance for update in updates]
collection = self._get_collection_from_name(collection_name)
collection.bulk_write(
update_operations,
session=session,
)
for inst in update_instances:
object.__setattr__(inst, "__fields_modified__", set())

def _save(self, instance: ModelType, session: "ClientSession") -> ModelType:
"""Perform an atomic save operation in the specified session"""
for ref_field_name in instance.__references__:
sub_instance = cast(Model, getattr(instance, ref_field_name))
self._save(sub_instance, session)
collection_updates = self._prepare_document_updates(instance)

fields_to_update = instance.__fields_modified__ | instance.__mutable_fields__
if len(fields_to_update) > 0:
doc = instance.doc(include=fields_to_update)
collection = self.get_collection(type(instance))
try:
collection.update_one(
instance.doc(include={instance.__primary_field__}),
{"$set": doc},
upsert=True,
session=session,
)
except pymongo.errors.DuplicateKeyError as e:
try:
self._save_collection_updates(
collection_updates=collection_updates, session=session
)
except pymongo.errors.BulkWriteError as e:
if e.details["writeErrors"][0]["code"] == 11000:
raise DuplicateKeyError(instance, e)
object.__setattr__(instance, "__fields_modified__", set())
raise # pragma: no cover
return instance

def save(
Expand Down Expand Up @@ -1017,17 +1095,24 @@ def save_all(
#noqa: DAR402 DuplicateKeyError
-->
"""
collections_updates: CollectionUpdatesType = {}
for instance in instances:
self._prepare_document_updates(
instance, collection_updates=collections_updates
)
if session:
added_instances = [
self._save(instance, self._get_session(session)) # type: ignore
for instance in instances
]
mongo_session = self._get_session(session)
assert mongo_session
self._save_collection_updates(
collection_updates=collections_updates,
session=mongo_session,
)
else:
with self.client.start_session() as local_session:
added_instances = [
self._save(instance, local_session) for instance in instances
]
return added_instances
self._save_collection_updates(
collection_updates=collections_updates, session=local_session
)
return list(instances)

def delete(
self,
Expand Down
5 changes: 4 additions & 1 deletion odmantic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, List, Sequence, Type, TypeVar, Union

import pymongo
import pymongo.errors
from pydantic.error_wrappers import ErrorWrapper, ValidationError

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,7 +44,9 @@ class DuplicateKeyError(BaseEngineException):
"""

def __init__(
self, instance: "Model", driver_error: pymongo.errors.DuplicateKeyError
self,
instance: "Model",
driver_error: pymongo.errors.BulkWriteError,
):
self.instance: "Model" = instance
self.driver_error = driver_error
Expand Down
10 changes: 8 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def f():
collection = Mock()
collection.update_one = AsyncMock()
collection.aggregate = AsyncMock()
monkeypatch.setattr(aio_engine, "get_collection", lambda _: collection)
collection.bulk_write = AsyncMock()
monkeypatch.setattr(
aio_engine, "_get_collection_from_name", lambda _: collection
)
return collection

return f
Expand All @@ -108,7 +111,10 @@ def f():
collection = Mock()
collection.update_one = Mock()
collection.aggregate = Mock()
monkeypatch.setattr(sync_engine, "get_collection", lambda _: collection)
collection.bulk_write = Mock()
monkeypatch.setattr(
sync_engine, "_get_collection_from_name", lambda _: collection
)
return collection

return f
Loading
Loading