Skip to content

Commit

Permalink
➕ bulk_update, tests and raw_sql
Browse files Browse the repository at this point in the history
* The bulk_update is an adaptation of
  encode/orm#148
  • Loading branch information
tarsil committed Feb 16, 2023
1 parent 640fb03 commit 494ee68
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 22 deletions.
14 changes: 14 additions & 0 deletions saffier/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import typing
from enum import Enum
from inspect import isclass

from orjson import OPT_OMIT_MICROSECONDS # noqa
from orjson import OPT_SERIALIZE_NUMPY # noqa
from orjson import dumps
from typing_extensions import get_origin

from saffier.fields import DateField, DateTimeField
Expand All @@ -21,6 +25,16 @@ def _update_auto_now_fields(self, values: DictAny, fields: DictAny) -> DictAny:
values[k] = v.validator.get_default_value()
return values

def _resolve_value(self, value: typing.Any):
if isinstance(value, dict):
return dumps(
value,
option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS,
).decode("utf-8")
elif isinstance(value, Enum):
return value.name
return value


def is_class_and_subclass(value: typing.Any, _type: typing.Any) -> bool:
original = get_origin(value)
Expand Down
1 change: 0 additions & 1 deletion saffier/db/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def validate(self, value: typing.Any) -> typing.Any:
elif not isinstance(value, str):
raise self.validation_error("type")

# The null character is always invalid.
value = value.replace("\0", "")

if self.trim_whitespace:
Expand Down
82 changes: 78 additions & 4 deletions saffier/db/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from saffier.types import DictAny

if typing.TYPE_CHECKING: # pragma: no cover
from saffier.db.connection import Database
from saffier.models import Model


Expand All @@ -25,7 +26,7 @@ class QuerySetProps:
"""

@property
def database(self):
def database(self) -> "Database":
return self.model_class._meta.registry.database

@property
Expand Down Expand Up @@ -117,6 +118,7 @@ def _build_select(self):
if self.distinct_on:
expression = self._build_select_distinct(self.distinct_on, expression=expression)

setattr(self, "_expression", expression)
return expression

def _filter_query(self, exclude: bool = False, **kwargs):
Expand Down Expand Up @@ -232,6 +234,7 @@ def _clone(self) -> "QuerySet[SaffierModel]":
queryset._order_by = copy.copy(self._order_by)
queryset._group_by = copy.copy(self._group_by)
queryset.distinct_on = copy.copy(self.distinct_on)
queryset._expression = self._expression
return queryset


Expand Down Expand Up @@ -262,14 +265,30 @@ def __init__(
self._order_by = [] if order_by is None else order_by
self._group_by = [] if group_by is None else group_by
self.distinct_on = [] if distinct_on is None else distinct_on
self._expression = None

def __get__(self, instance, owner):
return self.__class__(model_class=owner)

@property
def sql(self):
return str(self._expression)

@sql.setter
def sql(self, value):
setattr(self, "_expression", value)

async def __aiter__(self) -> typing.AsyncIterator[SaffierModel]:
for value in await self:
yield value

def _set_query_expression(self, expression: typing.Any) -> None:
"""
Sets the value of the sql property to the expression used.
"""
self.sql = expression
self.model_class.raw_query = self.sql

def _filter_or_exclude(
self,
clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None,
Expand Down Expand Up @@ -389,6 +408,7 @@ async def exists(self) -> bool:
"""
expression = self._build_select()
expression = sqlalchemy.exists(expression).select()
self._set_query_expression(expression)
return await self.database.fetch_val(expression)

async def count(self) -> int:
Expand All @@ -397,6 +417,7 @@ async def count(self) -> int:
"""
expression = self._build_select().alias("subquery_for_count")
expression = sqlalchemy.func.count().select().select_from(expression)
self._set_query_expression(expression)
return await self.database.fetch_val(expression)

async def get_or_none(self, **kwargs):
Expand All @@ -405,6 +426,7 @@ async def get_or_none(self, **kwargs):
"""
queryset = self.filter(**kwargs)
expression = queryset._build_select().limit(2)
self._set_query_expression(expression)
rows = await self.database.fetch_all(expression)

if not rows:
Expand All @@ -422,7 +444,12 @@ async def all(self, **kwargs):
return await queryset.filter(**kwargs).all()

expression = queryset._build_select()
self._set_query_expression(expression)

rows = await queryset.database.fetch_all(expression)

# Attach the raw query to the object
queryset.model_class.raw_query = self.sql
return [
queryset.model_class._from_row(row, select_related=self._select_related)
for row in rows
Expand All @@ -437,6 +464,7 @@ async def get(self, **kwargs):

expression = self._build_select().limit(2)
rows = await self.database.fetch_all(expression)
self._set_query_expression(expression)

if not rows:
raise DoesNotFound()
Expand Down Expand Up @@ -475,6 +503,7 @@ async def create(self, **kwargs):
kwargs = self._validate_kwargs(**kwargs)
instance = self.model_class(**kwargs)
expression = self.table.insert().values(**kwargs)
self._set_query_expression(expression)

if self.pkname not in kwargs:
instance.pk = await self.database.execute(expression)
Expand All @@ -490,13 +519,57 @@ async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
new_objs = [self._validate_kwargs(**obj) for obj in objs]

expression = self.table.insert().values(new_objs)
self._set_query_expression(expression)
await self.database.execute(expression)

async def bulk_update(self, objs: typing.List[SaffierModel], fields: typing.List[str]) -> None:
"""
Bulk updates records in a table.
A similar solution was suggested here: https://github.com/encode/orm/pull/148
It is thought to be a clean approach to a simple problem so it was added here and
refactored to be compatible with Saffier.
"""
new_fields = {}
for key, field in self.model_class.fields.items():
if key in fields:
new_fields[key] = field.validator

validator = Schema(fields=new_fields)

new_objs = []
for obj in objs:
new_obj = {}
for key, value in obj.__dict__.items():
if key in fields:
new_obj[key] = self._resolve_value(value)
new_objs.append(new_obj)

new_objs = [
self._update_auto_now_fields(validator.validate(obj), self.model_class.fields)
for obj in new_objs
]

pk = getattr(self.table.c, self.pkname)
expression = self.table.update().where(pk == sqlalchemy.bindparam(self.pkname))
kwargs = {field: sqlalchemy.bindparam(field) for obj in new_objs for field in obj.keys()}
pks = [{self.pkname: getattr(obj, self.pkname)} for obj in objs]

query_list = []
for pk, value in zip(pks, new_objs):
query_list.append({**pk, **value})

expression = expression.values(kwargs)
self._set_query_expression(expression)
await self.database.execute_many(str(expression), query_list)

async def delete(self) -> None:
expression = self.table.delete()
for filter_clause in self.filter_clauses:
expression = expression.where(filter_clause)

self._set_query_expression(expression)
await self.database.execute(expression)

async def update(self, **kwargs) -> None:
Expand All @@ -509,12 +582,13 @@ async def update(self, **kwargs) -> None:

validator = Schema(fields=fields)
kwargs = self._update_auto_now_fields(validator.validate(kwargs), self.model_class.fields)
expr = self.table.update().values(**kwargs)
expression = self.table.update().values(**kwargs)

for filter_clause in self.filter_clauses:
expr = expr.where(filter_clause)
expression = expression.where(filter_clause)

await self.database.execute(expr)
self._set_query_expression(expression)
await self.database.execute(expression)

async def get_or_create(
self, defaults: typing.Dict[str, typing.Any], **kwargs
Expand Down
9 changes: 9 additions & 0 deletions saffier/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Model(ModelMeta, ModelUtil):
query = Manager()
_meta = MetaInfo(None)
_db_model: bool = False
_raw_query: str = None

def __init__(self, **kwargs: DictAny) -> None:
if "pk" in kwargs:
Expand Down Expand Up @@ -55,6 +56,14 @@ def pk(self):
def pk(self, value):
setattr(self, self.pkname, value)

@property
def raw_query(self):
return getattr(self, self._raw_query)

@raw_query.setter
def raw_query(self, value):
setattr(self, self.raw_query, value)

def __repr__(self):
return f"<{self.__class__.__name__}: {self}>"

Expand Down
75 changes: 75 additions & 0 deletions tests/models/test_bulk_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import datetime
from enum import Enum

import pytest
from tests.settings import DATABASE_URL

import saffier
from saffier import fields
from saffier.db.connection import Database

pytestmark = pytest.mark.anyio

database = Database(DATABASE_URL)
models = saffier.Registry(database=database)


def time():
return datetime.datetime.now().time()


class StatusEnum(Enum):
DRAFT = "Draft"
RELEASED = "Released"


class Product(saffier.Model):
id = fields.IntegerField(primary_key=True)
uuid = fields.UUIDField(null=True)
created = fields.DateTimeField(default=datetime.datetime.now)
created_day = fields.DateField(default=datetime.date.today)
created_time = fields.TimeField(default=time)
created_date = fields.DateField(auto_now_add=True)
created_datetime = fields.DateTimeField(auto_now_add=True)
updated_datetime = fields.DateTimeField(auto_now=True)
updated_date = fields.DateField(auto_now=True)
data = fields.JSONField(default={})
description = fields.CharField(blank=True, max_length=255)
huge_number = fields.BigIntegerField(default=0)
price = fields.DecimalField(max_digits=5, decimal_places=2, null=True)
status = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT)
value = fields.FloatField(null=True)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
await models.create_all()
yield
await models.drop_all()


@pytest.fixture(autouse=True)
async def rollback_transactions():
with database.force_rollback():
async with database:
yield


async def test_bulk_create():
await Product.query.bulk_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
]
)
products = await Product.query.all()
assert len(products) == 2
assert products[0].data == {"foo": 123}
assert products[0].value == 123.456
assert products[0].status == StatusEnum.RELEASED
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT
Loading

0 comments on commit 494ee68

Please sign in to comment.