From 2862e0f0ac41efcd6aca25f1fe82ca70c4399ad5 Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 12:00:26 +0800 Subject: [PATCH 01/22] add onupdate option in field --- ormar/fields/base.py | 32 ++++++++- ormar/models/mixins/save_mixin.py | 16 +++++ ormar/models/model.py | 26 ++++--- .../test_populate_default_values.py | 1 - .../test_populate_onupdate_values.py | 71 +++++++++++++++++++ 5 files changed, 135 insertions(+), 11 deletions(-) create mode 100644 tests/test_model_methods/test_populate_onupdate_values.py diff --git a/ormar/fields/base.py b/ormar/fields/base.py index f7eaff5c5..5b7188e78 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,5 +1,6 @@ import warnings -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Type, \ + Union import sqlalchemy from pydantic import Json, typing @@ -47,6 +48,8 @@ def __init__(self, **kwargs: Any) -> None: self.index: bool = kwargs.pop("index", False) self.unique: bool = kwargs.pop("unique", False) self.pydantic_only: bool = kwargs.pop("pydantic_only", False) + self.onupdate: Union[Callable[..., Any], Any, None] = kwargs.pop( + "onupdate", None) if self.pydantic_only: warnings.warn( "Parameter `pydantic_only` is deprecated and will " @@ -217,6 +220,33 @@ def has_default(self, use_server: bool = True) -> bool: self.server_default is not None and use_server ) + def has_onupdate(self) -> bool: + """ + Checks if the field has onupdate value set. + :return: result of the check if onupdate value is set + rtype: bool + """ + if self.__pydantic_type__ is None: + return self.onupdate is not None + + if self.onupdate is not None and not callable(self.onupdate): + if isinstance(self.onupdate, self.__pydantic_type__): + return True + elif self.onupdate is not None and callable(self.onupdate): + return True + return False + + def get_onupdate(self) -> Union[None, Any]: + """ + Get onupdate value if set + + :return: result of the onupdate + rtype: Any + """ + if callable(self.onupdate): + return self.onupdate() + return self.onupdate + def is_auto_primary_key(self) -> bool: """ Checks if field is first a primary key and if it, diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 89ddbfed9..6ce709bf2 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -77,6 +77,7 @@ def prepare_model_to_update(cls, new_kwargs: dict) -> dict: new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) + new_kwargs = cls.populate_onupdate_value(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs @@ -238,6 +239,21 @@ def populate_default_values(cls, new_kwargs: Dict) -> Dict: new_kwargs.pop(field_name, None) return new_kwargs + @classmethod + def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: + """ + Populate value which from onupdate options in field + + :param new_kwargs: dictionary of model that is about to be saved + :type new_kwargs: Dict + :return: dictionary of model that is about to be saved + :rtype: Dict + """ + for field_name, field in cls.Meta.model_fields.items(): + if field.has_onupdate() and not field.pydantic_only: + new_kwargs[field_name] = field.get_onupdate() + return new_kwargs + @classmethod def validate_choices(cls, new_kwargs: Dict) -> Dict: """ diff --git a/ormar/models/model.py b/ormar/models/model.py index d17cafded..b7aa11663 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, TypeVar, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, TypeVar, \ + Union import ormar.queryset # noqa I100 from ormar.exceptions import ModelPersistenceError, NoMatch @@ -67,7 +68,8 @@ async def save(self: T) -> T: await self.signals.pre_save.send(sender=self.__class__, instance=self) self_fields = self._extract_model_db_fields() - if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: + if not self.pk and self.Meta.model_fields[ + self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.populate_default_values(self_fields) self.update_from_dict( @@ -90,8 +92,8 @@ async def save(self: T) -> T: # refresh server side defaults if any( field.server_default is not None - for name, field in self.Meta.model_fields.items() - if name not in self_fields + for name, field in self.Meta.model_fields.items() + if name not in self_fields ): await self.load() @@ -220,6 +222,7 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: :rtype: Model """ if kwargs: + kwargs = self.populate_onupdate_value(kwargs) self.update_from_dict(kwargs) if not self.pk: @@ -233,14 +236,16 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) if _columns: - self_fields = {k: v for k, v in self_fields.items() if k in _columns} + self_fields = {k: v for k, v in self_fields.items() if + k in _columns} self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) await self.Meta.database.execute(expr) self.set_save_status(True) - await self.signals.post_update.send(sender=self.__class__, instance=self) + await self.signals.post_update.send(sender=self.__class__, + instance=self) return self async def delete(self) -> int: @@ -258,12 +263,14 @@ async def delete(self) -> int: :return: number of deleted rows (for some backends) :rtype: int """ - await self.signals.pre_delete.send(sender=self.__class__, instance=self) + await self.signals.pre_delete.send(sender=self.__class__, + instance=self) expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) self.set_save_status(False) - await self.signals.post_delete.send(sender=self.__class__, instance=self) + await self.signals.post_delete.send(sender=self.__class__, + instance=self) return result async def load(self: T) -> T: @@ -280,7 +287,8 @@ async def load(self: T) -> T: expr = self.Meta.table.select().where(self.pk_column == self.pk) row = await self.Meta.database.fetch_one(expr) if not row: # pragma nocover - raise NoMatch("Instance was deleted from database and cannot be refreshed") + raise NoMatch( + "Instance was deleted from database and cannot be refreshed") kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.update_from_dict(kwargs) diff --git a/tests/test_model_methods/test_populate_default_values.py b/tests/test_model_methods/test_populate_default_values.py index b3cd2b65d..e883023bd 100644 --- a/tests/test_model_methods/test_populate_default_values.py +++ b/tests/test_model_methods/test_populate_default_values.py @@ -1,5 +1,4 @@ import databases -import pytest import sqlalchemy from sqlalchemy import text diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py new file mode 100644 index 000000000..616a2cb0d --- /dev/null +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -0,0 +1,71 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + database = database + metadata = metadata + + +class Task(ormar.Model): + class Meta(BaseMeta): + tablename = "tasks" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String( + max_length=255, onupdate=lambda: "hello", + ) + age: int = ormar.Integer() + points: int = ormar.Integer( + default=0, minimum=0, onupdate=1 + ) + year = ormar.Integer(onupdate=2, default=1) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_populate_onupdate_values(): + task = Task(name="123", age=1, points=1) + await task.save() + + assert task.year == 1 + await task.update(age=2) + + t = await Task.objects.filter(age=2).first() + assert t.name == "hello" + assert t.points == 1 + assert t.year == 2 + + +@pytest.mark.asyncio +async def test_bulk_update_populate_onupdate_values(): + task1 = await Task(name="123", age=1, points=2).save() + task2 = await Task(name="123", age=2, points=3).save() + task3 = await Task(name="123", age=3, points=4).save() + + tasks = [task1, task2, task3] + + for task in tasks: + task.age += 1 + + await Task.objects.bulk_update(tasks) + + for task in await Task.objects.all(): + assert task.points == 1 + assert task.name == "hello" + assert task.year == 2 From 083780b622a2c601a2f8b4a38ea8b396fb0d6b98 Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 12:38:51 +0800 Subject: [PATCH 02/22] with database --- .../test_populate_onupdate_values.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index 616a2cb0d..6678ba0b8 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -40,32 +40,34 @@ def create_test_database(): @pytest.mark.asyncio async def test_populate_onupdate_values(): - task = Task(name="123", age=1, points=1) - await task.save() + async with database: + task = Task(name="123", age=1, points=1) + await task.save() - assert task.year == 1 - await task.update(age=2) + assert task.year == 1 + await task.update(age=2) - t = await Task.objects.filter(age=2).first() - assert t.name == "hello" - assert t.points == 1 - assert t.year == 2 + t = await Task.objects.filter(age=2).first() + assert t.name == "hello" + assert t.points == 1 + assert t.year == 2 @pytest.mark.asyncio async def test_bulk_update_populate_onupdate_values(): - task1 = await Task(name="123", age=1, points=2).save() - task2 = await Task(name="123", age=2, points=3).save() - task3 = await Task(name="123", age=3, points=4).save() + async with database: + task1 = await Task(name="123", age=1, points=2).save() + task2 = await Task(name="123", age=2, points=3).save() + task3 = await Task(name="123", age=3, points=4).save() - tasks = [task1, task2, task3] + tasks = [task1, task2, task3] - for task in tasks: - task.age += 1 + for task in tasks: + task.age += 1 - await Task.objects.bulk_update(tasks) + await Task.objects.bulk_update(tasks) - for task in await Task.objects.all(): - assert task.points == 1 - assert task.name == "hello" - assert task.year == 2 + for task in await Task.objects.all(): + assert task.points == 1 + assert task.name == "hello" + assert task.year == 2 From 88cdf15fc88f8dd4f9b5d511c78ebd91f1612ae7 Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 12:46:35 +0800 Subject: [PATCH 03/22] fix lint --- ormar/models/model.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/ormar/models/model.py b/ormar/models/model.py index b7aa11663..2f0520f28 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, TypeVar, \ - Union +from typing import ( + Any, Dict, List, Optional, Set, + TYPE_CHECKING, TypeVar, Union +) import ormar.queryset # noqa I100 from ormar.exceptions import ModelPersistenceError, NoMatch @@ -68,8 +70,7 @@ async def save(self: T) -> T: await self.signals.pre_save.send(sender=self.__class__, instance=self) self_fields = self._extract_model_db_fields() - if not self.pk and self.Meta.model_fields[ - self.Meta.pkname].autoincrement: + if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.populate_default_values(self_fields) self.update_from_dict( @@ -92,8 +93,8 @@ async def save(self: T) -> T: # refresh server side defaults if any( field.server_default is not None - for name, field in self.Meta.model_fields.items() - if name not in self_fields + for name, field in self.Meta.model_fields.items() + if name not in self_fields ): await self.load() @@ -236,8 +237,7 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) if _columns: - self_fields = {k: v for k, v in self_fields.items() if - k in _columns} + self_fields = {k: v for k, v in self_fields.items() if k in _columns} self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) @@ -263,14 +263,12 @@ async def delete(self) -> int: :return: number of deleted rows (for some backends) :rtype: int """ - await self.signals.pre_delete.send(sender=self.__class__, - instance=self) + await self.signals.pre_delete.send(sender=self.__class__, instance=self) expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) self.set_save_status(False) - await self.signals.post_delete.send(sender=self.__class__, - instance=self) + await self.signals.post_delete.send(sender=self.__class__, instance=self) return result async def load(self: T) -> T: @@ -287,8 +285,7 @@ async def load(self: T) -> T: expr = self.Meta.table.select().where(self.pk_column == self.pk) row = await self.Meta.database.fetch_one(expr) if not row: # pragma nocover - raise NoMatch( - "Instance was deleted from database and cannot be refreshed") + raise NoMatch("Instance was deleted from database and cannot be refreshed") kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.update_from_dict(kwargs) From 00b48244169b3e19a8a66aac21916a12fef2983d Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 17:44:24 +0800 Subject: [PATCH 04/22] fix some lint and check the array in --- ormar/queryset/queryset.py | 6 ++++-- tests/test_queries/test_queryset_level_methods.py | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index c57654718..11235388b 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1064,9 +1064,11 @@ async def bulk_create(self, objects: List["T"]) -> None: :param objects: list of ormar models already initialized and ready to save. :type objects: List[Model] """ + if not objects: + raise ModelListEmptyError("Bulk create objects are empty!") + ready_objects = [obj.prepare_model_to_save(obj.dict()) for obj in objects] - - # don't use execute_many, as in databases it's executed in a loop + # don't use execute_many, as in databases it's executed in a loop # instead of using execute_many from drivers expr = self.table.insert().values(ready_objects) await self.database.execute(expr) diff --git a/tests/test_queries/test_queryset_level_methods.py b/tests/test_queries/test_queryset_level_methods.py index 449c7dfd6..f53cc28f2 100644 --- a/tests/test_queries/test_queryset_level_methods.py +++ b/tests/test_queries/test_queryset_level_methods.py @@ -208,6 +208,9 @@ async def test_bulk_create(): completed = await ToDo.objects.filter(completed=True).all() assert len(completed) == 2 + with pytest.raises(ModelListEmptyError): + await ToDo.objects.bulk_create([]) + @pytest.mark.asyncio async def test_bulk_create_with_relation(): From 4357b6ca871af3ec8a408681c7e3ecf4ebb5ce4a Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 17:46:41 +0800 Subject: [PATCH 05/22] fix unnecessary lint --- ormar/models/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ormar/models/model.py b/ormar/models/model.py index 2f0520f28..4ac5be1b0 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -244,8 +244,7 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: await self.Meta.database.execute(expr) self.set_save_status(True) - await self.signals.post_update.send(sender=self.__class__, - instance=self) + await self.signals.post_update.send(sender=self.__class__, instance=self) return self async def delete(self) -> int: From 93edcbf615d432f92592c4f6fbbaa96854f5d046 Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 17:55:57 +0800 Subject: [PATCH 06/22] lazy doc --- README.md | 1 + docs/fields/common-parameters.md | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/README.md b/README.md index 8183bbbbc..72417d34e 100644 --- a/README.md +++ b/README.md @@ -654,6 +654,7 @@ The following keyword arguments are supported on all field types. * `choices: typing.Sequence` * `name: str` * `pydantic_only: bool` +* `onupdate: Any/callable` All fields are required unless one of the following is set: diff --git a/docs/fields/common-parameters.md b/docs/fields/common-parameters.md index 79ce92016..78c1eed8b 100644 --- a/docs/fields/common-parameters.md +++ b/docs/fields/common-parameters.md @@ -216,6 +216,35 @@ class OverwriteTest(ormar.Model): `choices`: `Sequence` = `[]` +## onupdate + +when the object update or bulk_update, if you don't update the field which has the onupdate option, +its value will be changed from `onupdate definition` + +```python + +class ToDo(ormar.Model): + class Meta: + tablename = "todo" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=32) + my_timestamp: int = ormar.Integer(onupdate=time.time) + is_dirty: bool = ormar.Boolean(default=False, onupdate=True) + + +todo = await ToDo.objects.get(id=1) +await todo.update(name="test") + +assert todo.is_dirty +assert todo.my_timestamp == now + +``` + + + A set of choices allowed to be used for given field. Used for data validation on pydantic side. From 5cff483ab1b1ff3c1a213bf88b32c5ded6473b0d Mon Sep 17 00:00:00 2001 From: huangsong Date: Mon, 17 Jan 2022 19:01:30 +0800 Subject: [PATCH 07/22] fix _columns args --- ormar/models/mixins/save_mixin.py | 8 +++++ ormar/models/model.py | 6 +++- ormar/queryset/queryset.py | 13 ++++++-- .../test_populate_onupdate_values.py | 30 ++++++++++++++++--- 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 6ce709bf2..fb8d3e247 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -418,3 +418,11 @@ def _get_field_values(self, name: str) -> List: if not isinstance(values, list): values = [values] return values + + @classmethod + def get_fields_has_onupdate(cls) -> List[str]: + return [ + field_name + for field_name, field in cls.Meta.model_fields.items() + if field.has_onupdate() and not field.pydantic_only + ] diff --git a/ormar/models/model.py b/ormar/models/model.py index 4ac5be1b0..bd3e7c5dd 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -237,7 +237,11 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) if _columns: - self_fields = {k: v for k, v in self_fields.items() if k in _columns} + onupdate_fields = self.get_fields_has_onupdate() + self_fields = { + k: v for k, v in self_fields.items() + if k in _columns or k in onupdate_fields + } self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 11235388b..9b05277fd 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1112,6 +1112,11 @@ async def bulk_update( # noqa: CCR001 columns.append(pk_name) columns = [self.model.get_column_alias(k) for k in columns] + onupdate_fields = [ + self.model.get_column_alias(k) + for k in cast(Type["Model"], self.model_cls).get_fields_has_onupdate() + ] + updated_columns = list(set(columns + onupdate_fields)) for obj in objects: new_kwargs = obj.dict() @@ -1122,9 +1127,11 @@ async def bulk_update( # noqa: CCR001 ) new_kwargs = obj.prepare_model_to_update(new_kwargs) ready_objects.append( - {"new_" + k: v for k, v in new_kwargs.items() if k in columns} + { + "new_" + k: v for k, v in new_kwargs.items() + if k in updated_columns + } ) - pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) pk_column_name = self.model.get_column_alias(pk_name) table_columns = [c.name for c in self.model_meta.table.c] @@ -1134,7 +1141,7 @@ async def bulk_update( # noqa: CCR001 expr = expr.values( **{ k: bindparam("new_" + k) - for k in columns + for k in updated_columns if k != pk_column_name and k in table_columns } ) diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index 6678ba0b8..232b40454 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -42,32 +42,54 @@ def create_test_database(): async def test_populate_onupdate_values(): async with database: task = Task(name="123", age=1, points=1) + task2 = Task(name="123", age=1, points=1) await task.save() + await task2.save() assert task.year == 1 + assert task2.year == 1 + await task.update(age=2) + await task2.update(_columns=["age"], age=3) t = await Task.objects.filter(age=2).first() assert t.name == "hello" assert t.points == 1 assert t.year == 2 + t = await Task.objects.filter(age=3).first() + assert t.name == "hello" + assert t.points == 1 + assert t.year == 2 + @pytest.mark.asyncio async def test_bulk_update_populate_onupdate_values(): async with database: task1 = await Task(name="123", age=1, points=2).save() task2 = await Task(name="123", age=2, points=3).save() - task3 = await Task(name="123", age=3, points=4).save() + task3 = await Task(name="345", age=5, points=4).save() + task4 = await Task(name="345", age=6, points=5).save() + + tasks = [task1, task2] - tasks = [task1, task2, task3] + tasks_ = [task3, task4] - for task in tasks: + for task in tasks_ + tasks: task.age += 1 await Task.objects.bulk_update(tasks) + assert 2 == await Task.objects.filter(Task.age <= 3).count() + + for task in await Task.objects.filter(Task.age <= 3).all(): + assert task.points == 1 + assert task.name == "hello" + assert task.year == 2 + + await Task.objects.bulk_update(tasks_, columns=["age"]) - for task in await Task.objects.all(): + assert 2 == await Task.objects.filter(Task.age > 3).count() + for task in await Task.objects.filter(Task.age > 3).all(): assert task.points == 1 assert task.name == "hello" assert task.year == 2 From 4db16eb8089f5b32546d9adca6391c6e527c7e9c Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 10:33:22 +0800 Subject: [PATCH 08/22] fix --- ormar/fields/base.py | 14 ++------------ ormar/models/mixins/save_mixin.py | 5 +++-- ormar/models/model.py | 2 +- ormar/queryset/queryset.py | 2 +- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 5b7188e78..0c8458961 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -226,15 +226,7 @@ def has_onupdate(self) -> bool: :return: result of the check if onupdate value is set rtype: bool """ - if self.__pydantic_type__ is None: - return self.onupdate is not None - - if self.onupdate is not None and not callable(self.onupdate): - if isinstance(self.onupdate, self.__pydantic_type__): - return True - elif self.onupdate is not None and callable(self.onupdate): - return True - return False + return self.onupdate is not None def get_onupdate(self) -> Union[None, Any]: """ @@ -243,9 +235,7 @@ def get_onupdate(self) -> Union[None, Any]: :return: result of the onupdate rtype: Any """ - if callable(self.onupdate): - return self.onupdate() - return self.onupdate + return self.onupdate() if callable(self.onupdate) else self.onupdate def is_auto_primary_key(self) -> bool: """ diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index fb8d3e247..96765bb67 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -249,7 +249,8 @@ def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: :return: dictionary of model that is about to be saved :rtype: Dict """ - for field_name, field in cls.Meta.model_fields.items(): + for field_name in cls.get_fields_with_onupdate(): + field = cls.Meta.model_fields[field_name] if field.has_onupdate() and not field.pydantic_only: new_kwargs[field_name] = field.get_onupdate() return new_kwargs @@ -420,7 +421,7 @@ def _get_field_values(self, name: str) -> List: return values @classmethod - def get_fields_has_onupdate(cls) -> List[str]: + def get_fields_with_onupdate(cls) -> List[str]: return [ field_name for field_name, field in cls.Meta.model_fields.items() diff --git a/ormar/models/model.py b/ormar/models/model.py index bd3e7c5dd..ca37ccda4 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -237,7 +237,7 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) if _columns: - onupdate_fields = self.get_fields_has_onupdate() + onupdate_fields = self.get_fields_with_onupdate() self_fields = { k: v for k, v in self_fields.items() if k in _columns or k in onupdate_fields diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 9b05277fd..1d78949aa 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1114,7 +1114,7 @@ async def bulk_update( # noqa: CCR001 columns = [self.model.get_column_alias(k) for k in columns] onupdate_fields = [ self.model.get_column_alias(k) - for k in cast(Type["Model"], self.model_cls).get_fields_has_onupdate() + for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() ] updated_columns = list(set(columns + onupdate_fields)) From 544b6aaf4450cf88065c6da23452414967b238de Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 17:26:27 +0800 Subject: [PATCH 09/22] fix --- ormar/models/mixins/save_mixin.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 96765bb67..a2e954695 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -38,6 +38,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): _skip_ellipsis: Callable _json_fields: Set[str] _bytes_fields: Set[str] + _onupdate_fields: Set[str] __fields__: Dict[str, pydantic.fields.ModelField] @classmethod @@ -251,8 +252,7 @@ def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: """ for field_name in cls.get_fields_with_onupdate(): field = cls.Meta.model_fields[field_name] - if field.has_onupdate() and not field.pydantic_only: - new_kwargs[field_name] = field.get_onupdate() + new_kwargs[field_name] = field.get_onupdate() return new_kwargs @classmethod @@ -421,9 +421,12 @@ def _get_field_values(self, name: str) -> List: return values @classmethod - def get_fields_with_onupdate(cls) -> List[str]: - return [ - field_name - for field_name, field in cls.Meta.model_fields.items() - if field.has_onupdate() and not field.pydantic_only - ] + def get_fields_with_onupdate(cls) -> Set[str]: + if not cls._onupdate_fields: + cls._onupdate_fields = { + field_name + for field_name, field in cls.Meta.model_fields.items() + if field.has_onupdate() and not field.pydantic_only + } + return cls._onupdate_fields + From 2c0188d3f80d73d8c9777ef9a77916564988e862 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 17:33:38 +0800 Subject: [PATCH 10/22] fix lint --- ormar/models/metaclass.py | 1 + ormar/models/newbasemodel.py | 1 + 2 files changed, 2 insertions(+) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index cc1dede39..094369aa7 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -108,6 +108,7 @@ def add_cached_properties(new_model: Type["Model"]) -> None: new_model._pydantic_fields = {name for name in new_model.__fields__} new_model._json_fields = set() new_model._bytes_fields = set() + new_model._onupdate_fields = set() def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001 diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index d05498e12..c59a88d1b 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -94,6 +94,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass _quick_access_fields: Set _json_fields: Set _bytes_fields: Set + _onupdate_fields: Set Meta: ModelMeta # noinspection PyMissingConstructor From 4721e9c03d48204b6211e302cb5b3e25f2153ab6 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 18:59:57 +0800 Subject: [PATCH 11/22] fix naming --- ormar/fields/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 0c8458961..4a4414a57 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -48,8 +48,8 @@ def __init__(self, **kwargs: Any) -> None: self.index: bool = kwargs.pop("index", False) self.unique: bool = kwargs.pop("unique", False) self.pydantic_only: bool = kwargs.pop("pydantic_only", False) - self.onupdate: Union[Callable[..., Any], Any, None] = kwargs.pop( - "onupdate", None) + self.on_update: Union[Callable[..., Any], Any, None] = kwargs.pop( + "on_update", None) if self.pydantic_only: warnings.warn( "Parameter `pydantic_only` is deprecated and will " @@ -226,7 +226,7 @@ def has_onupdate(self) -> bool: :return: result of the check if onupdate value is set rtype: bool """ - return self.onupdate is not None + return self.on_update is not None def get_onupdate(self) -> Union[None, Any]: """ @@ -235,7 +235,7 @@ def get_onupdate(self) -> Union[None, Any]: :return: result of the onupdate rtype: Any """ - return self.onupdate() if callable(self.onupdate) else self.onupdate + return self.on_update() if callable(self.on_update) else self.on_update def is_auto_primary_key(self) -> bool: """ From 30d09672c29811f3b60650e69d6ba08fe277fe7e Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 18 Jan 2022 19:00:34 +0800 Subject: [PATCH 12/22] fix ut --- tests/test_model_methods/test_populate_onupdate_values.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index 232b40454..ed74eb0e2 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -20,13 +20,13 @@ class Meta(BaseMeta): id: int = ormar.Integer(primary_key=True) name: str = ormar.String( - max_length=255, onupdate=lambda: "hello", + max_length=255, on_update=lambda: "hello", ) age: int = ormar.Integer() points: int = ormar.Integer( - default=0, minimum=0, onupdate=1 + default=0, minimum=0, on_update=1 ) - year = ormar.Integer(onupdate=2, default=1) + year = ormar.Integer(on_update=2, default=1) @pytest.fixture(autouse=True, scope="module") From 60342b2e25b25ed0d65df06f1a04402f19038f67 Mon Sep 17 00:00:00 2001 From: huangsong Date: Wed, 19 Jan 2022 10:20:07 +0800 Subject: [PATCH 13/22] remove on bulk-update --- ormar/models/mixins/save_mixin.py | 5 +- ormar/queryset/queryset.py | 16 +++-- .../test_populate_onupdate_values.py | 66 ++++++++++--------- 3 files changed, 47 insertions(+), 40 deletions(-) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index a2e954695..71ec9f246 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -78,7 +78,7 @@ def prepare_model_to_update(cls, new_kwargs: dict) -> dict: new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) - new_kwargs = cls.populate_onupdate_value(new_kwargs) + # new_kwargs = cls.populate_onupdate_value(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) return new_kwargs @@ -252,7 +252,8 @@ def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: """ for field_name in cls.get_fields_with_onupdate(): field = cls.Meta.model_fields[field_name] - new_kwargs[field_name] = field.get_onupdate() + if field_name not in new_kwargs: + new_kwargs[field_name] = field.get_onupdate() return new_kwargs @classmethod diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1d78949aa..6dd56153e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1112,13 +1112,15 @@ async def bulk_update( # noqa: CCR001 columns.append(pk_name) columns = [self.model.get_column_alias(k) for k in columns] - onupdate_fields = [ - self.model.get_column_alias(k) - for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() - ] - updated_columns = list(set(columns + onupdate_fields)) + # on_update_fields = [ + # self.model.get_column_alias(k) + # for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() + # ] + # updated_columns = list(set(columns + on_update_fields)) for obj in objects: + # when the obj.__setattr__, should be dirty for column + # only load the kv from dirty fields new_kwargs = obj.dict() if new_kwargs.get(pk_name) is None: raise ModelPersistenceError( @@ -1129,7 +1131,7 @@ async def bulk_update( # noqa: CCR001 ready_objects.append( { "new_" + k: v for k, v in new_kwargs.items() - if k in updated_columns + if k in columns } ) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) @@ -1141,7 +1143,7 @@ async def bulk_update( # noqa: CCR001 expr = expr.values( **{ k: bindparam("new_" + k) - for k in updated_columns + for k in columns if k != pk_column_name and k in table_columns } ) diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index ed74eb0e2..dce8f6051 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -62,34 +62,38 @@ async def test_populate_onupdate_values(): assert t.points == 1 assert t.year == 2 - -@pytest.mark.asyncio -async def test_bulk_update_populate_onupdate_values(): - async with database: - task1 = await Task(name="123", age=1, points=2).save() - task2 = await Task(name="123", age=2, points=3).save() - task3 = await Task(name="345", age=5, points=4).save() - task4 = await Task(name="345", age=6, points=5).save() - - tasks = [task1, task2] - - tasks_ = [task3, task4] - - for task in tasks_ + tasks: - task.age += 1 - - await Task.objects.bulk_update(tasks) - assert 2 == await Task.objects.filter(Task.age <= 3).count() - - for task in await Task.objects.filter(Task.age <= 3).all(): - assert task.points == 1 - assert task.name == "hello" - assert task.year == 2 - - await Task.objects.bulk_update(tasks_, columns=["age"]) - - assert 2 == await Task.objects.filter(Task.age > 3).count() - for task in await Task.objects.filter(Task.age > 3).all(): - assert task.points == 1 - assert task.name == "hello" - assert task.year == 2 + await task.update(points=3) + t = await Task.objects.get_or_none(id=task.id) + assert t.points == 3 + + +# @pytest.mark.asyncio +# async def test_bulk_update_populate_onupdate_values(): +# async with database: +# task1 = await Task(name="123", age=1, points=2).save() +# task2 = await Task(name="123", age=2, points=3).save() +# task3 = await Task(name="345", age=5, points=4).save() +# task4 = await Task(name="345", age=6, points=5).save() +# +# tasks = [task1, task2] +# +# tasks_ = [task3, task4] +# +# for task in tasks_ + tasks: +# task.age += 1 +# +# await Task.objects.bulk_update(tasks) +# assert 2 == await Task.objects.filter(Task.age <= 3).count() +# +# for task in await Task.objects.filter(Task.age <= 3).all(): +# assert task.points == 1 +# assert task.name == "hello" +# assert task.year == 2 +# +# await Task.objects.bulk_update(tasks_, columns=["age"]) +# +# assert 2 == await Task.objects.filter(Task.age > 3).count() +# for task in await Task.objects.filter(Task.age > 3).all(): +# assert task.points == 1 +# assert task.name == "hello" +# assert task.year == 2 From d30919c4308a327cc70f9bab26b5afeb302e3890 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 8 Feb 2022 10:26:25 +0800 Subject: [PATCH 14/22] fix --- ormar/queryset/queryset.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index e79d7cbd9..d221c5ce4 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1112,12 +1112,12 @@ async def bulk_update( # noqa: CCR001 if pk_name not in columns: columns.append(pk_name) - columns = [self.model.get_column_alias(k) for k in columns] - # on_update_fields = [ - # self.model.get_column_alias(k) - # for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() - # ] - # updated_columns = list(set(columns + on_update_fields)) + columns = {self.model.get_column_alias(k) for k in columns} + on_update_fields = { + self.model.get_column_alias(k) + for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() + } + updated_columns = columns | on_update_fields for obj in objects: # when the obj.__setattr__, should be dirty for column @@ -1132,19 +1132,20 @@ async def bulk_update( # noqa: CCR001 ready_objects.append( { "new_" + k: v for k, v in new_kwargs.items() - if k in columns + if k in updated_columns } ) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) pk_column_name = self.model.get_column_alias(pk_name) table_columns = [c.name for c in self.model_meta.table.c] + # Make the expr expr = self.table.update().where( pk_column == bindparam("new_" + pk_column_name) ) expr = expr.values( **{ k: bindparam("new_" + k) - for k in columns + for k in updated_columns if k != pk_column_name and k in table_columns } ) From 362ba26503cdb4f9ba89b4e506b899ce15c69611 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 8 Feb 2022 10:38:36 +0800 Subject: [PATCH 15/22] fix lint --- ormar/queryset/queryset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index d221c5ce4..01e2a72c6 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1112,12 +1112,13 @@ async def bulk_update( # noqa: CCR001 if pk_name not in columns: columns.append(pk_name) - columns = {self.model.get_column_alias(k) for k in columns} + new_columns: Set[str] = {self.model.get_column_alias(k) for k in columns} + on_update_fields = { self.model.get_column_alias(k) for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() } - updated_columns = columns | on_update_fields + updated_columns = new_columns | on_update_fields for obj in objects: # when the obj.__setattr__, should be dirty for column From 5ee21967bfa5e5798d05b0653150f984c9ad41ff Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 8 Feb 2022 11:14:44 +0800 Subject: [PATCH 16/22] fix --- ormar/queryset/queryset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 01e2a72c6..67fa73d7f 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1112,13 +1112,13 @@ async def bulk_update( # noqa: CCR001 if pk_name not in columns: columns.append(pk_name) - new_columns: Set[str] = {self.model.get_column_alias(k) for k in columns} + columns = [self.model.get_column_alias(k) for k in columns] - on_update_fields = { - self.model.get_column_alias(k) - for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() - } - updated_columns = new_columns | on_update_fields + # on_update_fields = { + # self.model.get_column_alias(k) + # for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() + # } + # updated_columns = new_columns | on_update_fields for obj in objects: # when the obj.__setattr__, should be dirty for column @@ -1133,7 +1133,7 @@ async def bulk_update( # noqa: CCR001 ready_objects.append( { "new_" + k: v for k, v in new_kwargs.items() - if k in updated_columns + if k in columns } ) pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) @@ -1146,7 +1146,7 @@ async def bulk_update( # noqa: CCR001 expr = expr.values( **{ k: bindparam("new_" + k) - for k in updated_columns + for k in columns if k != pk_column_name and k in table_columns } ) From 739fa5f3df6f9572d4273f64ec79b0eaae820b5e Mon Sep 17 00:00:00 2001 From: vvanglro Date: Tue, 27 Feb 2024 16:26:36 +0800 Subject: [PATCH 17/22] feat: add on_update to the field --- ormar/fields/base.py | 6 +- ormar/models/mixins/save_mixin.py | 8 +- ormar/models/model.py | 15 +- ormar/models/newbasemodel.py | 4 + ormar/queryset/join.py | 17 +- ormar/queryset/queryset.py | 2 + .../test_field_quoting.py | 33 ++-- .../test_populate_onupdate_values.py | 151 +++++++++++------- 8 files changed, 151 insertions(+), 85 deletions(-) diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 4a4414a57..c47349461 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,6 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Type, \ - Union +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from pydantic import Json, typing @@ -49,7 +48,8 @@ def __init__(self, **kwargs: Any) -> None: self.unique: bool = kwargs.pop("unique", False) self.pydantic_only: bool = kwargs.pop("pydantic_only", False) self.on_update: Union[Callable[..., Any], Any, None] = kwargs.pop( - "on_update", None) + "on_update", None + ) if self.pydantic_only: warnings.warn( "Parameter `pydantic_only` is deprecated and will " diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 3cf40db70..78cb564cf 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -22,6 +22,7 @@ from ormar.models.mixins.relation_mixin import RelationMixin if TYPE_CHECKING: # pragma: no cover + from ormar.models import T from ormar import ForeignKeyField, Model @@ -74,7 +75,6 @@ def prepare_model_to_update(cls, new_kwargs: dict) -> dict: new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.reconvert_str_to_bytes(new_kwargs) new_kwargs = cls.dump_all_json_fields_to_str(new_kwargs) - # new_kwargs = cls.populate_onupdate_value(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs) new_kwargs = cls.translate_enum_columns(new_kwargs) return new_kwargs @@ -245,7 +245,7 @@ def populate_default_values(cls, new_kwargs: Dict) -> Dict: return new_kwargs @classmethod - def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: + def populate_onupdate_value(cls, new_kwargs: Dict, obj: "T" = None) -> Dict: """ Populate value which from onupdate options in field @@ -258,6 +258,9 @@ def populate_onupdate_value(cls, new_kwargs: Dict) -> Dict: field = cls.Meta.model_fields[field_name] if field_name not in new_kwargs: new_kwargs[field_name] = field.get_onupdate() + if obj: + if field_name not in obj.__setattr_fields__: + new_kwargs[field_name] = field.get_onupdate() return new_kwargs @classmethod @@ -434,4 +437,3 @@ def get_fields_with_onupdate(cls) -> Set[str]: if field.has_onupdate() and not field.pydantic_only } return cls._onupdate_fields - diff --git a/ormar/models/model.py b/ormar/models/model.py index c205408f4..05c3aba0b 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,7 +1,4 @@ -from typing import ( - Any, Dict, List, Optional, Set, - TYPE_CHECKING, TypeVar, Union -) +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, TypeVar, Union import ormar.queryset # noqa I100 from ormar.exceptions import ModelPersistenceError, NoMatch @@ -248,9 +245,17 @@ async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: if _columns: onupdate_fields = self.get_fields_with_onupdate() self_fields = { - k: v for k, v in self_fields.items() + k: v + for k, v in self_fields.items() if k in _columns or k in onupdate_fields } + if not kwargs and not _columns: + for field_name in self.get_fields_with_onupdate(): + if field_name not in self.__setattr_fields__: + field = self.Meta.model_fields[field_name] + onupdate_field_value = {field_name: field.get_onupdate()} + self_fields.update(onupdate_field_value) + self.update_from_dict(onupdate_field_value) self_fields = self.translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 7d74188ee..8bc43ebb2 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -74,6 +74,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass "_pk_column", "__pk_only__", "__cached_hash__", + "__setattr_fields__", ) if TYPE_CHECKING: # pragma no cover @@ -87,6 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass __database__: databases.Database __relation_map__: Optional[List[str]] __cached_hash__: Optional[int] + __setattr_fields__: Set _orm_relationship_manager: AliasManager _orm: RelationsManager _orm_id: int @@ -169,6 +171,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore if hasattr(self, "_init_private_attributes"): # introduced in pydantic 1.7 self._init_private_attributes() + object.__setattr__(self, "__setattr_fields__", set()) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 """ @@ -196,6 +199,7 @@ def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 if prev_hash != new_hash: self._update_relation_cache(prev_hash, new_hash) + self.__setattr_fields__.add(name) def __getattr__(self, item: str) -> Any: """ diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 29d061429..01354fe50 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -92,7 +92,14 @@ def to_table(self) -> sqlalchemy.Table: """ return self.next_model.Meta.table - def _on_clause(self, previous_alias: str, from_table_name:str, from_column_name: str, to_table_name: str, to_column_name: str) -> text: + def _on_clause( + self, + previous_alias: str, + from_table_name: str, + from_column_name: str, + to_table_name: str, + to_column_name: str, + ) -> text: """ Receives aliases and names of both ends of the join and combines them into one text clause used in joins. @@ -112,11 +119,15 @@ def _on_clause(self, previous_alias: str, from_table_name:str, from_column_name: """ dialect = self.main_model.Meta.database._backend._dialect quoter = dialect.identifier_preparer.quote - left_part = f"{quoter(f'{self.next_alias}_{to_table_name}')}.{quoter(to_column_name)}" + left_part = ( + f"{quoter(f'{self.next_alias}_{to_table_name}')}.{quoter(to_column_name)}" + ) if not previous_alias: right_part = f"{quoter(from_table_name)}.{quoter(from_column_name)}" else: - right_part = f"{quoter(f'{previous_alias}_{from_table_name}')}.{from_column_name}" + right_part = ( + f"{quoter(f'{previous_alias}_{from_table_name}')}.{from_column_name}" + ) return text(f"{left_part}={right_part}") diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index c958729cd..44df41f30 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -796,6 +796,7 @@ async def update(self, each: bool = False, **kwargs: Any) -> int: self.model.extract_related_names() ) updates = {k: v for k, v in kwargs.items() if k in self_fields} + updates = self.model.populate_onupdate_value(updates) updates = self.model.validate_choices(updates) updates = self.model.translate_columns_to_aliases(updates) @@ -1199,6 +1200,7 @@ async def bulk_update( # noqa: CCR001 "You cannot update unsaved objects. " f"{self.model.__name__} has to have {pk_name} filled." ) + new_kwargs = obj.populate_onupdate_value(new_kwargs, obj) new_kwargs = obj.prepare_model_to_update(new_kwargs) ready_objects.append( {"new_" + k: v for k, v in new_kwargs.items() if k in columns} diff --git a/tests/test_model_definition/test_field_quoting.py b/tests/test_model_definition/test_field_quoting.py index 4cf656816..ae1105580 100644 --- a/tests/test_model_definition/test_field_quoting.py +++ b/tests/test_model_definition/test_field_quoting.py @@ -41,8 +41,12 @@ class Meta: id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) gpa: float = ormar.Float() - schoolclass: Optional[SchoolClass] = ormar.ForeignKey(SchoolClass, related_name="students") - category: Optional[Category] = ormar.ForeignKey(Category, nullable=True, related_name="students") + schoolclass: Optional[SchoolClass] = ormar.ForeignKey( + SchoolClass, related_name="students" + ) + category: Optional[Category] = ormar.ForeignKey( + Category, nullable=True, related_name="students" + ) @pytest.fixture(autouse=True, scope="module") @@ -59,9 +63,15 @@ async def create_data(): class2 = await SchoolClass.objects.create(name="Logic") category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1, gpa=3.2) - await Student.objects.create(name="Judy", category=category2, schoolclass=class1, gpa=2.6) - await Student.objects.create(name="Jack", category=category2, schoolclass=class2, gpa=3.8) + await Student.objects.create( + name="Jane", category=category, schoolclass=class1, gpa=3.2 + ) + await Student.objects.create( + name="Judy", category=category2, schoolclass=class1, gpa=2.6 + ) + await Student.objects.create( + name="Jack", category=category2, schoolclass=class2, gpa=3.8 + ) @pytest.mark.asyncio @@ -70,10 +80,14 @@ async def test_quotes_left_join(): async with database.transaction(force_rollback=True): await create_data() students = await Student.objects.filter( - (Student.schoolclass.name == "Math") | (Student.category.name == "Foreign") + (Student.schoolclass.name == "Math") + | (Student.category.name == "Foreign") ).all() for student in students: - assert student.schoolclass.name == "Math" or student.category.name == "Foreign" + assert ( + student.schoolclass.name == "Math" + or student.category.name == "Foreign" + ) @pytest.mark.asyncio @@ -92,8 +106,9 @@ async def test_quotes_deep_join(): async with database: async with database.transaction(force_rollback=True): await create_data() - schoolclasses = await SchoolClass.objects.filter(students__category__name="Domestic").all() + schoolclasses = await SchoolClass.objects.filter( + students__category__name="Domestic" + ).all() for schoolclass in schoolclasses: for student in schoolclass.students: assert student.category.name == "Domestic" - diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index dce8f6051..0830c4fc7 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -1,6 +1,9 @@ +from datetime import datetime + import databases import pytest import sqlalchemy +from sqlalchemy import func import ormar from tests.settings import DATABASE_URL @@ -20,13 +23,14 @@ class Meta(BaseMeta): id: int = ormar.Integer(primary_key=True) name: str = ormar.String( - max_length=255, on_update=lambda: "hello", - ) - age: int = ormar.Integer() - points: int = ormar.Integer( - default=0, minimum=0, on_update=1 + max_length=255, + on_update=lambda: "hello", ) + points: int = ormar.Integer(default=0, minimum=0, on_update=1) year = ormar.Integer(on_update=2, default=1) + updated_at: datetime = ormar.DateTime( + default=datetime.now, server_default=func.now(), on_update=datetime.now + ) @pytest.fixture(autouse=True, scope="module") @@ -39,61 +43,84 @@ def create_test_database(): @pytest.mark.asyncio -async def test_populate_onupdate_values(): +async def test_onupdate_use_setattr_to_update(): + async with database: + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + + t2 = await Task.objects.get(name="123") + t2.name = "hello" + t2.year = 2024 + await t2.update() + assert t2.name == "hello" + assert t2.points == 1 + assert t2.year == 2024 + assert t2.updated_at > t1.updated_at + + +@pytest.mark.asyncio +async def test_onupdate_use_update_func_kwargs(): + async with database: + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + + t2 = await Task.objects.get(name="123") + await t2.update(name="hello") + assert t2.name == "hello" + assert t2.points == 1 + assert t2.year == 2 + assert t2.updated_at > t1.updated_at + + +@pytest.mark.asyncio +async def test_onupdate_use_update_func_columns(): + async with database: + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + + t2 = await Task.objects.get(name="123") + await t2.update(_columns=["year"], year=2024) + assert t2.name == "hello" + assert t2.points == 1 + assert t2.year == 2024 + assert t2.updated_at > t1.updated_at + + +@pytest.mark.asyncio +async def test_onupdate_queryset_update(): + async with database: + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + + await Task.objects.filter(name="123").update(name="hello") + t2 = await Task.objects.get(name="hello") + assert t2.name == "hello" + assert t2.points == 1 + assert t2.year == 2 + assert t2.updated_at > t1.updated_at + + +@pytest.mark.asyncio +async def test_onupdate_bulk_update(): async with database: - task = Task(name="123", age=1, points=1) - task2 = Task(name="123", age=1, points=1) - await task.save() - await task2.save() - - assert task.year == 1 - assert task2.year == 1 - - await task.update(age=2) - await task2.update(_columns=["age"], age=3) - - t = await Task.objects.filter(age=2).first() - assert t.name == "hello" - assert t.points == 1 - assert t.year == 2 - - t = await Task.objects.filter(age=3).first() - assert t.name == "hello" - assert t.points == 1 - assert t.year == 2 - - await task.update(points=3) - t = await Task.objects.get_or_none(id=task.id) - assert t.points == 3 - - -# @pytest.mark.asyncio -# async def test_bulk_update_populate_onupdate_values(): -# async with database: -# task1 = await Task(name="123", age=1, points=2).save() -# task2 = await Task(name="123", age=2, points=3).save() -# task3 = await Task(name="345", age=5, points=4).save() -# task4 = await Task(name="345", age=6, points=5).save() -# -# tasks = [task1, task2] -# -# tasks_ = [task3, task4] -# -# for task in tasks_ + tasks: -# task.age += 1 -# -# await Task.objects.bulk_update(tasks) -# assert 2 == await Task.objects.filter(Task.age <= 3).count() -# -# for task in await Task.objects.filter(Task.age <= 3).all(): -# assert task.points == 1 -# assert task.name == "hello" -# assert task.year == 2 -# -# await Task.objects.bulk_update(tasks_, columns=["age"]) -# -# assert 2 == await Task.objects.filter(Task.age > 3).count() -# for task in await Task.objects.filter(Task.age > 3).all(): -# assert task.points == 1 -# assert task.name == "hello" -# assert task.year == 2 + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + + t2 = await Task.objects.get(name="123") + t2.name = "bulk_update" + await Task.objects.bulk_update([t2]) + t3 = await Task.objects.get(name="bulk_update") + assert t3.name == "bulk_update" + assert t3.points == 1 + assert t3.year == 2 + assert t3.updated_at > t2.updated_at From b79b552e41e864230092c4e41613d48ae61d3f6c Mon Sep 17 00:00:00 2001 From: vvanglro Date: Tue, 27 Feb 2024 17:03:13 +0800 Subject: [PATCH 18/22] fix: bulk_update has columns --- ormar/queryset/queryset.py | 7 ++++++- tests/test_model_methods/test_populate_onupdate_values.py | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 44df41f30..5bd9a7c33 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1191,7 +1191,12 @@ async def bulk_update( # noqa: CCR001 if pk_name not in columns: columns.append(pk_name) - columns = [self.model.get_column_alias(k) for k in columns] + columns = {self.model.get_column_alias(k) for k in columns} + on_update_fields = { + self.model.get_column_alias(k) + for k in cast(Type["Model"], self.model_cls).get_fields_with_onupdate() + } + columns |= on_update_fields for obj in objects: new_kwargs = obj.dict() diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index 0830c4fc7..60a32616d 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -124,3 +124,10 @@ async def test_onupdate_bulk_update(): assert t3.points == 1 assert t3.year == 2 assert t3.updated_at > t2.updated_at + + t4 = await Task.objects.get(name="bulk_update") + t4.year = 2024 + await Task.objects.bulk_update([t4], columns=["year"]) + t5 = await Task.objects.get(name="hello") + assert t5.year == 2024 + assert t5.points == 1 From e810644d1e8f1a47ebb4d065ecde0660f1ff32c4 Mon Sep 17 00:00:00 2001 From: vvanglro Date: Tue, 27 Feb 2024 17:03:34 +0800 Subject: [PATCH 19/22] docs: update --- README.md | 2 +- docs/fields/common-parameters.md | 25 +++++++++++++++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 1c71dc9be..52babd336 100644 --- a/README.md +++ b/README.md @@ -655,7 +655,7 @@ The following keyword arguments are supported on all field types. * `choices: typing.Sequence` * `name: str` * `pydantic_only: bool` -* `onupdate: Any/callable` +* `on_update: Any/callable` All fields are required unless one of the following is set: diff --git a/docs/fields/common-parameters.md b/docs/fields/common-parameters.md index 78c1eed8b..e2613d351 100644 --- a/docs/fields/common-parameters.md +++ b/docs/fields/common-parameters.md @@ -216,10 +216,10 @@ class OverwriteTest(ormar.Model): `choices`: `Sequence` = `[]` -## onupdate +## on_update -when the object update or bulk_update, if you don't update the field which has the onupdate option, -its value will be changed from `onupdate definition` +when the object update or bulk_update, if you don't update the field which has the on_update option, +its value will be changed from `on_update definition` ```python @@ -230,16 +230,21 @@ class ToDo(ormar.Model): database = database id: int = ormar.Integer(primary_key=True) - name: str = ormar.String(max_length=32) - my_timestamp: int = ormar.Integer(onupdate=time.time) - is_dirty: bool = ormar.Boolean(default=False, onupdate=True) - - + name: str = ormar.String( + max_length=255, + on_update=lambda: "hello", + ) + is_dirty: bool = ormar.Boolean(default=False, on_update=True) + updated_at: datetime = ormar.DateTime( + default=datetime.now, server_default=func.now(), on_update=datetime.now + ) + +await ToDo.objects.create(name="test") todo = await ToDo.objects.get(id=1) -await todo.update(name="test") +await todo.update() assert todo.is_dirty -assert todo.my_timestamp == now +assert todo.name == "hello" ``` From 20e90edc0c70fcbaa9d9b5057515cb3c42d73b81 Mon Sep 17 00:00:00 2001 From: vvanglro Date: Tue, 27 Feb 2024 17:07:28 +0800 Subject: [PATCH 20/22] docs: update --- ormar/models/mixins/save_mixin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 78cb564cf..a126acdf6 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -251,6 +251,8 @@ def populate_onupdate_value(cls, new_kwargs: Dict, obj: "T" = None) -> Dict: :param new_kwargs: dictionary of model that is about to be saved :type new_kwargs: Dict + :param obj: ormar models + :type obj: Model :return: dictionary of model that is about to be saved :rtype: Dict """ From 06bc1b207551b331ce115977649a7e9ce0fe37a3 Mon Sep 17 00:00:00 2001 From: vvanglro Date: Tue, 27 Feb 2024 17:55:08 +0800 Subject: [PATCH 21/22] fix: test case --- ormar/models/newbasemodel.py | 2 ++ tests/test_model_methods/test_populate_onupdate_values.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 8bc43ebb2..bdb940bb1 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -212,6 +212,8 @@ def __getattr__(self, item: str) -> Any: :return: Any :rtype: Any """ + if item == "__setattr_fields__": + return set() return super().__getattribute__(item) def __getstate__(self) -> Dict[Any, Any]: diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py index 60a32616d..674edfd01 100644 --- a/tests/test_model_methods/test_populate_onupdate_values.py +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime import databases @@ -118,6 +119,7 @@ async def test_onupdate_bulk_update(): t2 = await Task.objects.get(name="123") t2.name = "bulk_update" + await asyncio.sleep(0.1) await Task.objects.bulk_update([t2]) t3 = await Task.objects.get(name="bulk_update") assert t3.name == "bulk_update" @@ -127,7 +129,9 @@ async def test_onupdate_bulk_update(): t4 = await Task.objects.get(name="bulk_update") t4.year = 2024 + await asyncio.sleep(0.1) await Task.objects.bulk_update([t4], columns=["year"]) t5 = await Task.objects.get(name="hello") assert t5.year == 2024 assert t5.points == 1 + assert t5.updated_at > t4.updated_at From fdc41a4ddb38a80f0d8c3f982ac54e3986dbe831 Mon Sep 17 00:00:00 2001 From: vvanglro Date: Wed, 28 Feb 2024 11:33:22 +0800 Subject: [PATCH 22/22] fix: move __setattr_fields__ to _initialize_internal_attributes --- ormar/models/newbasemodel.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index bdb940bb1..5d0b83965 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -171,7 +171,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore if hasattr(self, "_init_private_attributes"): # introduced in pydantic 1.7 self._init_private_attributes() - object.__setattr__(self, "__setattr_fields__", set()) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 """ @@ -212,8 +211,6 @@ def __getattr__(self, item: str) -> Any: :return: Any :rtype: Any """ - if item == "__setattr_fields__": - return set() return super().__getattribute__(item) def __getstate__(self) -> Dict[Any, Any]: @@ -381,6 +378,7 @@ def _initialize_internal_attributes(self) -> None: :rtype: None """ # object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "__setattr_fields__", set()) object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_pk_column", None) object.__setattr__(