From b95912aa177883bf0b63f7d47844ff7ced77af93 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Fri, 28 Feb 2025 13:00:03 +0000 Subject: [PATCH] Add bulk_get_or_create (#296) * Add bulk_get_or_create --- docs/queries/queries.md | 22 ++++ docs/release-notes.md | 4 + edgy/__init__.py | 2 +- edgy/core/db/querysets/base.py | 127 +++++++++++++++++++++++ tests/models/test_bulk_get_or_create.py | 131 ++++++++++++++++++++++++ 5 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_bulk_get_or_create.py diff --git a/docs/queries/queries.md b/docs/queries/queries.md index 3ad2b3a3..4d641786 100644 --- a/docs/queries/queries.md +++ b/docs/queries/queries.md @@ -866,6 +866,28 @@ for user in users: await User.query.bulk_update(users, fields=['is_active']) ``` +### Bulk Get or Create + +When you need to perform in bulk a `get_or_create` in your models. The normal behavior would +be like the `bulk_create` but this bring an additional `unique_fields` where we can make sure +we do not insert duplicates by filtering the unique keys of the model data being inserted. + +```python +await User.query.bulk_get_or_create([ + {"email": "foo@bar.com", "first_name": "Foo", "last_name": "Bar", "is_active": True}, + {"email": "bar@foo.com", "first_name": "Bar", "last_name": "Foo", "is_active": True}, +], unique_fields=["email"]) + +# Try to reinsert the same values +await User.query.bulk_get_or_create([ + {"email": "foo@bar.com", "first_name": "Foo", "last_name": "Bar", "is_active": True}, + {"email": "bar@foo.com", "first_name": "Bar", "last_name": "Foo", "is_active": True}, +], unique_fields=["email"]) + + +users = await User.query.all() # 2 as total +``` + ## Operators There are sometimes the need of adding some extra conditions like `AND`, or `OR` or even the `NOT` diff --git a/docs/release-notes.md b/docs/release-notes.md index ff0d9ad3..edbc184a 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -8,6 +8,10 @@ hide: ## 0.27.3 +### Added + +- [bulk_get_or_create](./queries/queries.md#bulk-get-or-create) to queryset allowing to bulk inserting or getting existing objects. + ### Fixed - BooleanField typing. Thanks to Izcarmt95. diff --git a/edgy/__init__.py b/edgy/__init__.py index 439981d5..51639326 100644 --- a/edgy/__init__.py +++ b/edgy/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__version__ = "0.27.2" +__version__ = "0.27.3" from typing import TYPE_CHECKING from ._monkay import Instance, create_monkay diff --git a/edgy/core/db/querysets/base.py b/edgy/core/db/querysets/base.py index 048874a6..4e477015 100644 --- a/edgy/core/db/querysets/base.py +++ b/edgy/core/db/querysets/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import warnings from collections.abc import AsyncIterator, Awaitable, Generator, Iterable, Sequence from functools import cached_property @@ -1588,6 +1589,132 @@ async def bulk_update(self, objs: list[EdgyModel], fields: list[str]) -> None: finally: CURRENT_INSTANCE.reset(token) + import json + + async def bulk_get_or_create( + self, + objs: list[Union[dict[str, Any], EdgyModel]], + unique_fields: Union[list[str], None] = None, + ) -> list[EdgyModel]: + """ + Bulk gets or creates records in a table. + + If records exist based on unique fields, they are retrieved. + Otherwise, new records are created. + + Args: + objs (list[Union[dict[str, Any], EdgyModel]]): A list of objects or dictionaries. + unique_fields (list[str] | None): Fields that determine uniqueness. If None, all records are treated as new. + + Returns: + list[EdgyModel]: A list of retrieved or newly created objects. + """ + queryset: QuerySet = self._clone() + new_objs: list[EdgyModel] = [] + retrieved_objs: list[EdgyModel] = [] + + if unique_fields: + existing_records = {} + for obj in objs: + if isinstance(obj, dict): + filter_kwargs = {} + dict_fields = {} + for field, value in obj.items(): + if field in unique_fields: + if isinstance(value, dict): + dict_fields[field] = value + else: + filter_kwargs[field] = value + db_records = [record async for record in queryset.filter(**filter_kwargs)] + found = False + for record in db_records: + record_dict_fields = {k: getattr(record, k) for k in dict_fields} + if dict_fields == record_dict_fields: + lookup_key = tuple( + json.dumps(getattr(record, field)) + if isinstance(getattr(record, field), dict) + else getattr(record, field) + for field in unique_fields + ) + existing_records[lookup_key] = record + retrieved_objs.append(record) + found = True + break + if found is False: + new_objs.append(queryset.model_class(**obj)) + else: + filter_kwargs = {} + dict_fields = {} + for field in unique_fields: + value = getattr(obj, field) + if isinstance(value, dict): + dict_fields[field] = value + else: + filter_kwargs[field] = value + db_records = [record async for record in queryset.filter(**filter_kwargs)] + found = False + for record in db_records: + record_dict_fields = { + k: getattr(record, k) for k, _ in dict_fields.items() + } + if dict_fields == record_dict_fields: + lookup_key = tuple( + json.dumps(getattr(record, field)) + if isinstance(getattr(record, field), dict) + else getattr(record, field) + for field in unique_fields + ) + existing_records[lookup_key] = record + retrieved_objs.append(record) + found = True + break + if found is False: + new_objs.append(obj) + + else: + new_objs.extend( + [queryset.model_class(**obj) if isinstance(obj, dict) else obj for obj in objs] + ) + existing_records = {} + + async def _prepare_obj(obj_or_dict: Union[EdgyModel, dict[str, Any]]) -> EdgyModel: + if isinstance(obj_or_dict, dict): + obj: EdgyModel = queryset.model_class(**obj_or_dict) + else: + obj = obj_or_dict + return obj + + async def _iterate(obj: EdgyModel) -> dict[str, Any]: + original = obj.extract_db_fields() + col_values: dict[str, Any] = obj.extract_column_values( + original, phase="prepare_insert", instance=self + ) + col_values.update( + await obj.execute_pre_save_hooks(col_values, original, force_insert=True) + ) + return col_values + + check_db_connection(queryset.database) + token = CURRENT_INSTANCE.set(self) + + try: + async with queryset.database as database, database.transaction(): + if new_objs: + new_obj_values = [await _iterate(obj) for obj in new_objs] + expression = queryset.table.insert().values(new_obj_values) + await database.execute_many(expression) + retrieved_objs.extend(new_objs) + + self._clear_cache() + for obj in new_objs: + await obj.execute_post_save_hooks( + self.model_class.meta.fields.keys(), force_insert=True + ) + finally: + CURRENT_INSTANCE.reset(token) + + return retrieved_objs + async def delete(self, use_models: bool = False) -> int: if ( self.model_class.__require_model_based_deletion__ diff --git a/tests/models/test_bulk_get_or_create.py b/tests/models/test_bulk_get_or_create.py new file mode 100644 index 00000000..cddcbd18 --- /dev/null +++ b/tests/models/test_bulk_get_or_create.py @@ -0,0 +1,131 @@ +import decimal +from datetime import date, datetime +from enum import Enum +from typing import Any +from uuid import UUID + +import pytest + +import edgy +from edgy.core.db import fields +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL) +models = edgy.Registry(database=edgy.Database(database, force_rollback=True)) + + +def time(): + return datetime.now().time() + + +class StatusEnum(Enum): + DRAFT = "Draft" + RELEASED = "Released" + + +class Product(edgy.StrictModel): + id: int = fields.IntegerField(primary_key=True, autoincrement=True) + uuid: UUID = fields.UUIDField(null=True) + created: datetime = fields.DateTimeField(default=datetime.now) + created_day: datetime = fields.DateField(default=date.today) + created_time: datetime = fields.TimeField(default=time) + created_date: datetime = fields.DateField(auto_now_add=True) + created_datetime: datetime = fields.DateTimeField(auto_now_add=True) + updated_datetime: datetime = fields.DateTimeField(auto_now=True) + updated_date: datetime = fields.DateField(auto_now=True) + data: dict[Any, Any] = fields.JSONField(default=dict) + description: str = fields.CharField(null=True, max_length=255) + huge_number: int = fields.BigIntegerField(default=0) + price: decimal.Decimal = fields.DecimalField(max_digits=9, decimal_places=2, null=True) + status: str = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT) + value: float = fields.FloatField(null=True) + + class Meta: + registry = models + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +@pytest.fixture(autouse=True, scope="function") +async def rollback_transactions(): + async with models.database: + yield + + +async def test_bulk_bulk_get_or_create(): + await Product.query.bulk_get_or_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ] + ) + products = await Product.query.all() + assert len(products) == 2 + assert products[0].data == {"foo": 123} + assert products[0].value == 123.456 + assert products[0].status == StatusEnum.RELEASED + assert products[1].data == {"foo": 456} + assert products[1].value == 456.789 + assert products[1].status == StatusEnum.DRAFT + + +async def test_bulk_get_or_create_no_duplicates(): + await Product.query.bulk_get_or_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ] + ) + + await Product.query.bulk_get_or_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ], + unique_fields=["value", "status"], + ) + + products = await Product.query.all() + assert len(products) == 2 + assert products[0].data == {"foo": 123} + assert products[0].value == 123.456 + assert products[0].status == StatusEnum.RELEASED + assert products[1].data == {"foo": 456} + assert products[1].value == 456.789 + assert products[1].status == StatusEnum.DRAFT + + +async def test_bulk_get_or_create_no_duplicates_filter_by_dict(): + await Product.query.bulk_get_or_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ] + ) + + await Product.query.bulk_get_or_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ], + unique_fields=["data"], + ) + + products = await Product.query.all() + assert len(products) == 2 + assert products[0].data == {"foo": 123} + assert products[0].value == 123.456 + assert products[0].status == StatusEnum.RELEASED + assert products[1].data == {"foo": 456} + assert products[1].value == 456.789 + assert products[1].status == StatusEnum.DRAFT