-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add bulk_get_or_create
- Loading branch information
Showing
5 changed files
with
285 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -866,6 +866,28 @@ for user in users: | |
await User.query.bulk_update(users, fields=['is_active']) | ||
``` | ||
|
||
### Bulk Get or Create | ||
|
||
When you need to perform in bulk a `get_or_create` in your models. The normal behavior would | ||
be like the `bulk_create` but this bring an additional `unique_fields` where we can make sure | ||
we do not insert duplicates by filtering the unique keys of the model data being inserted. | ||
|
||
```python | ||
await User.query.bulk_get_or_create([ | ||
{"email": "[email protected]", "first_name": "Foo", "last_name": "Bar", "is_active": True}, | ||
{"email": "[email protected]", "first_name": "Bar", "last_name": "Foo", "is_active": True}, | ||
], unique_fields=["email"]) | ||
|
||
# Try to reinsert the same values | ||
await User.query.bulk_get_or_create([ | ||
{"email": "[email protected]", "first_name": "Foo", "last_name": "Bar", "is_active": True}, | ||
{"email": "[email protected]", "first_name": "Bar", "last_name": "Foo", "is_active": True}, | ||
], unique_fields=["email"]) | ||
|
||
|
||
users = await User.query.all() # 2 as total | ||
``` | ||
|
||
## Operators | ||
|
||
There are sometimes the need of adding some extra conditions like `AND`, or `OR` or even the `NOT` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import decimal | ||
from datetime import date, datetime | ||
from enum import Enum | ||
from typing import Any | ||
from uuid import UUID | ||
|
||
import pytest | ||
|
||
import edgy | ||
from edgy.core.db import fields | ||
from edgy.testclient import DatabaseTestClient | ||
from tests.settings import DATABASE_URL | ||
|
||
pytestmark = pytest.mark.anyio | ||
|
||
database = DatabaseTestClient(DATABASE_URL) | ||
models = edgy.Registry(database=edgy.Database(database, force_rollback=True)) | ||
|
||
|
||
def time(): | ||
return datetime.now().time() | ||
|
||
|
||
class StatusEnum(Enum): | ||
DRAFT = "Draft" | ||
RELEASED = "Released" | ||
|
||
|
||
class Product(edgy.StrictModel): | ||
id: int = fields.IntegerField(primary_key=True, autoincrement=True) | ||
uuid: UUID = fields.UUIDField(null=True) | ||
created: datetime = fields.DateTimeField(default=datetime.now) | ||
created_day: datetime = fields.DateField(default=date.today) | ||
created_time: datetime = fields.TimeField(default=time) | ||
created_date: datetime = fields.DateField(auto_now_add=True) | ||
created_datetime: datetime = fields.DateTimeField(auto_now_add=True) | ||
updated_datetime: datetime = fields.DateTimeField(auto_now=True) | ||
updated_date: datetime = fields.DateField(auto_now=True) | ||
data: dict[Any, Any] = fields.JSONField(default=dict) | ||
description: str = fields.CharField(null=True, max_length=255) | ||
huge_number: int = fields.BigIntegerField(default=0) | ||
price: decimal.Decimal = fields.DecimalField(max_digits=9, decimal_places=2, null=True) | ||
status: str = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT) | ||
value: float = fields.FloatField(null=True) | ||
|
||
class Meta: | ||
registry = models | ||
|
||
|
||
@pytest.fixture(autouse=True, scope="module") | ||
async def create_test_database(): | ||
async with database: | ||
await models.create_all() | ||
yield | ||
if not database.drop: | ||
await models.drop_all() | ||
|
||
|
||
@pytest.fixture(autouse=True, scope="function") | ||
async def rollback_transactions(): | ||
async with models.database: | ||
yield | ||
|
||
|
||
async def test_bulk_bulk_get_or_create(): | ||
await Product.query.bulk_get_or_create( | ||
[ | ||
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, | ||
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, | ||
] | ||
) | ||
products = await Product.query.all() | ||
assert len(products) == 2 | ||
assert products[0].data == {"foo": 123} | ||
assert products[0].value == 123.456 | ||
assert products[0].status == StatusEnum.RELEASED | ||
assert products[1].data == {"foo": 456} | ||
assert products[1].value == 456.789 | ||
assert products[1].status == StatusEnum.DRAFT | ||
|
||
|
||
async def test_bulk_get_or_create_no_duplicates(): | ||
await Product.query.bulk_get_or_create( | ||
[ | ||
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, | ||
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, | ||
] | ||
) | ||
|
||
await Product.query.bulk_get_or_create( | ||
[ | ||
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, | ||
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, | ||
], | ||
unique_fields=["value", "status"], | ||
) | ||
|
||
products = await Product.query.all() | ||
assert len(products) == 2 | ||
assert products[0].data == {"foo": 123} | ||
assert products[0].value == 123.456 | ||
assert products[0].status == StatusEnum.RELEASED | ||
assert products[1].data == {"foo": 456} | ||
assert products[1].value == 456.789 | ||
assert products[1].status == StatusEnum.DRAFT | ||
|
||
|
||
async def test_bulk_get_or_create_no_duplicates_filter_by_dict(): | ||
await Product.query.bulk_get_or_create( | ||
[ | ||
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, | ||
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, | ||
] | ||
) | ||
|
||
await Product.query.bulk_get_or_create( | ||
[ | ||
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, | ||
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, | ||
], | ||
unique_fields=["data"], | ||
) | ||
|
||
products = await Product.query.all() | ||
assert len(products) == 2 | ||
assert products[0].data == {"foo": 123} | ||
assert products[0].value == 123.456 | ||
assert products[0].status == StatusEnum.RELEASED | ||
assert products[1].data == {"foo": 456} | ||
assert products[1].value == 456.789 | ||
assert products[1].status == StatusEnum.DRAFT |