Skip to content

Commit

Permalink
Add bulk_get_or_create (#296)
Browse files Browse the repository at this point in the history
* Add bulk_get_or_create
  • Loading branch information
tarsil authored Feb 28, 2025
1 parent 4767ea9 commit b95912a
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 1 deletion.
22 changes: 22 additions & 0 deletions docs/queries/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,28 @@ for user in users:
await User.query.bulk_update(users, fields=['is_active'])
```

### Bulk Get or Create

When you need to perform in bulk a `get_or_create` in your models. The normal behavior would
be like the `bulk_create` but this bring an additional `unique_fields` where we can make sure
we do not insert duplicates by filtering the unique keys of the model data being inserted.

```python
await User.query.bulk_get_or_create([
{"email": "[email protected]", "first_name": "Foo", "last_name": "Bar", "is_active": True},
{"email": "[email protected]", "first_name": "Bar", "last_name": "Foo", "is_active": True},
], unique_fields=["email"])

# Try to reinsert the same values
await User.query.bulk_get_or_create([
{"email": "[email protected]", "first_name": "Foo", "last_name": "Bar", "is_active": True},
{"email": "[email protected]", "first_name": "Bar", "last_name": "Foo", "is_active": True},
], unique_fields=["email"])


users = await User.query.all() # 2 as total
```

## Operators

There are sometimes the need of adding some extra conditions like `AND`, or `OR` or even the `NOT`
Expand Down
4 changes: 4 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ hide:

## 0.27.3

### Added

- [bulk_get_or_create](./queries/queries.md#bulk-get-or-create) to queryset allowing to bulk inserting or getting existing objects.

### Fixed

- BooleanField typing. Thanks to Izcarmt95.
Expand Down
2 changes: 1 addition & 1 deletion edgy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__version__ = "0.27.2"
__version__ = "0.27.3"
from typing import TYPE_CHECKING

from ._monkay import Instance, create_monkay
Expand Down
127 changes: 127 additions & 0 deletions edgy/core/db/querysets/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import warnings
from collections.abc import AsyncIterator, Awaitable, Generator, Iterable, Sequence
from functools import cached_property
Expand Down Expand Up @@ -1588,6 +1589,132 @@ async def bulk_update(self, objs: list[EdgyModel], fields: list[str]) -> None:
finally:
CURRENT_INSTANCE.reset(token)

import json

async def bulk_get_or_create(
self,
objs: list[Union[dict[str, Any], EdgyModel]],
unique_fields: Union[list[str], None] = None,
) -> list[EdgyModel]:
"""
Bulk gets or creates records in a table.
If records exist based on unique fields, they are retrieved.
Otherwise, new records are created.
Args:
objs (list[Union[dict[str, Any], EdgyModel]]): A list of objects or dictionaries.
unique_fields (list[str] | None): Fields that determine uniqueness. If None, all records are treated as new.
Returns:
list[EdgyModel]: A list of retrieved or newly created objects.
"""
queryset: QuerySet = self._clone()
new_objs: list[EdgyModel] = []
retrieved_objs: list[EdgyModel] = []

if unique_fields:
existing_records = {}
for obj in objs:
if isinstance(obj, dict):
filter_kwargs = {}
dict_fields = {}
for field, value in obj.items():
if field in unique_fields:
if isinstance(value, dict):
dict_fields[field] = value
else:
filter_kwargs[field] = value
db_records = [record async for record in queryset.filter(**filter_kwargs)]
found = False
for record in db_records:
record_dict_fields = {k: getattr(record, k) for k in dict_fields}
if dict_fields == record_dict_fields:
lookup_key = tuple(
json.dumps(getattr(record, field))
if isinstance(getattr(record, field), dict)
else getattr(record, field)
for field in unique_fields
)
existing_records[lookup_key] = record
retrieved_objs.append(record)
found = True
break
if found is False:
new_objs.append(queryset.model_class(**obj))
else:
filter_kwargs = {}
dict_fields = {}
for field in unique_fields:
value = getattr(obj, field)
if isinstance(value, dict):
dict_fields[field] = value
else:
filter_kwargs[field] = value
db_records = [record async for record in queryset.filter(**filter_kwargs)]
found = False
for record in db_records:
record_dict_fields = {
k: getattr(record, k) for k, _ in dict_fields.items()
}
if dict_fields == record_dict_fields:
lookup_key = tuple(
json.dumps(getattr(record, field))
if isinstance(getattr(record, field), dict)
else getattr(record, field)
for field in unique_fields
)
existing_records[lookup_key] = record
retrieved_objs.append(record)
found = True
break
if found is False:
new_objs.append(obj)

else:
new_objs.extend(
[queryset.model_class(**obj) if isinstance(obj, dict) else obj for obj in objs]
)
existing_records = {}

async def _prepare_obj(obj_or_dict: Union[EdgyModel, dict[str, Any]]) -> EdgyModel:
if isinstance(obj_or_dict, dict):
obj: EdgyModel = queryset.model_class(**obj_or_dict)
else:
obj = obj_or_dict
return obj

async def _iterate(obj: EdgyModel) -> dict[str, Any]:
original = obj.extract_db_fields()
col_values: dict[str, Any] = obj.extract_column_values(
original, phase="prepare_insert", instance=self
)
col_values.update(
await obj.execute_pre_save_hooks(col_values, original, force_insert=True)
)
return col_values

check_db_connection(queryset.database)
token = CURRENT_INSTANCE.set(self)

try:
async with queryset.database as database, database.transaction():
if new_objs:
new_obj_values = [await _iterate(obj) for obj in new_objs]
expression = queryset.table.insert().values(new_obj_values)
await database.execute_many(expression)
retrieved_objs.extend(new_objs)

self._clear_cache()
for obj in new_objs:
await obj.execute_post_save_hooks(
self.model_class.meta.fields.keys(), force_insert=True
)
finally:
CURRENT_INSTANCE.reset(token)

return retrieved_objs

async def delete(self, use_models: bool = False) -> int:
if (
self.model_class.__require_model_based_deletion__
Expand Down
131 changes: 131 additions & 0 deletions tests/models/test_bulk_get_or_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import decimal
from datetime import date, datetime
from enum import Enum
from typing import Any
from uuid import UUID

import pytest

import edgy
from edgy.core.db import fields
from edgy.testclient import DatabaseTestClient
from tests.settings import DATABASE_URL

pytestmark = pytest.mark.anyio

database = DatabaseTestClient(DATABASE_URL)
models = edgy.Registry(database=edgy.Database(database, force_rollback=True))


def time():
return datetime.now().time()


class StatusEnum(Enum):
DRAFT = "Draft"
RELEASED = "Released"


class Product(edgy.StrictModel):
id: int = fields.IntegerField(primary_key=True, autoincrement=True)
uuid: UUID = fields.UUIDField(null=True)
created: datetime = fields.DateTimeField(default=datetime.now)
created_day: datetime = fields.DateField(default=date.today)
created_time: datetime = fields.TimeField(default=time)
created_date: datetime = fields.DateField(auto_now_add=True)
created_datetime: datetime = fields.DateTimeField(auto_now_add=True)
updated_datetime: datetime = fields.DateTimeField(auto_now=True)
updated_date: datetime = fields.DateField(auto_now=True)
data: dict[Any, Any] = fields.JSONField(default=dict)
description: str = fields.CharField(null=True, max_length=255)
huge_number: int = fields.BigIntegerField(default=0)
price: decimal.Decimal = fields.DecimalField(max_digits=9, decimal_places=2, null=True)
status: str = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT)
value: float = fields.FloatField(null=True)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
async with database:
await models.create_all()
yield
if not database.drop:
await models.drop_all()


@pytest.fixture(autouse=True, scope="function")
async def rollback_transactions():
async with models.database:
yield


async def test_bulk_bulk_get_or_create():
await Product.query.bulk_get_or_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
]
)
products = await Product.query.all()
assert len(products) == 2
assert products[0].data == {"foo": 123}
assert products[0].value == 123.456
assert products[0].status == StatusEnum.RELEASED
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT


async def test_bulk_get_or_create_no_duplicates():
await Product.query.bulk_get_or_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
]
)

await Product.query.bulk_get_or_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
],
unique_fields=["value", "status"],
)

products = await Product.query.all()
assert len(products) == 2
assert products[0].data == {"foo": 123}
assert products[0].value == 123.456
assert products[0].status == StatusEnum.RELEASED
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT


async def test_bulk_get_or_create_no_duplicates_filter_by_dict():
await Product.query.bulk_get_or_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
]
)

await Product.query.bulk_get_or_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
],
unique_fields=["data"],
)

products = await Product.query.all()
assert len(products) == 2
assert products[0].data == {"foo": 123}
assert products[0].value == 123.456
assert products[0].status == StatusEnum.RELEASED
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT

0 comments on commit b95912a

Please sign in to comment.