-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for bulk_update #148
base: master
Are you sure you want to change the base?
Changes from 5 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import enum | ||
import json | ||
import typing | ||
|
||
import databases | ||
|
@@ -20,6 +22,8 @@ | |
"lte": "__le__", | ||
} | ||
|
||
MODEL = typing.TypeVar("MODEL", bound="Model") | ||
|
||
|
||
def _update_auto_now_fields(values, fields): | ||
for key, value in fields.items(): | ||
|
@@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields): | |
return values | ||
|
||
|
||
def _convert_value(value): | ||
if isinstance(value, dict): | ||
return json.dumps(value) | ||
elif isinstance(value, enum.Enum): | ||
return value.name | ||
else: | ||
return value | ||
|
||
|
||
class ModelRegistry: | ||
def __init__(self, database: databases.Database) -> None: | ||
self.database = database | ||
|
@@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None: | |
|
||
await self.database.execute(expr) | ||
|
||
async def bulk_update( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I should've noticed this earlier, apologies for that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I agree with you it needs to be more readable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aminalaee Any updates ? |
||
self, objs: typing.List[MODEL], fields: typing.List[str] | ||
) -> None: | ||
fields = { | ||
key: field.validator | ||
for key, field in self.model_cls.fields.items() | ||
if key in fields | ||
} | ||
validator = typesystem.Schema(fields=fields) | ||
new_objs = [ | ||
_update_auto_now_fields(validator.validate(value), self.model_cls.fields) | ||
for value in [ | ||
{ | ||
key: _convert_value(value) | ||
for key, value in obj.__dict__.items() | ||
if key in fields | ||
} | ||
for obj in objs | ||
] | ||
] | ||
expr = ( | ||
self.table.update() | ||
.where(self.table.c.id == sqlalchemy.bindparam(self.pkname)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to get primary key column dynamically. The |
||
.values( | ||
{ | ||
field: sqlalchemy.bindparam(field) | ||
for obj in new_objs | ||
for field in obj.keys() | ||
} | ||
) | ||
) | ||
pk_list = [{self.pkname: obj.pk} for obj in objs] | ||
joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] | ||
await self.database.execute_many(str(expr), joined_list) | ||
|
||
async def get_or_create( | ||
self, defaults: typing.Dict[str, typing.Any], **kwargs | ||
) -> typing.Tuple[typing.Any, bool]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,3 +278,22 @@ async def test_nullable_foreign_key(): | |
|
||
assert member.email == "[email protected]" | ||
assert member.team.pk is None | ||
|
||
|
||
async def test_bulk_update_with_relation(): | ||
album = await Album.objects.create(name="foo") | ||
album2 = await Album.objects.create(name="bar") | ||
|
||
await Track.objects.bulk_create( | ||
[ | ||
{"name": "foo", "album": album, "position": 1, "title": "foo"}, | ||
{"name": "bar", "album": album, "position": 2, "title": "bar"}, | ||
] | ||
) | ||
tracks = await Track.objects.all() | ||
for track in tracks: | ||
track.album = album2 | ||
await Track.objects.bulk_update(tracks, fields=["album"]) | ||
tracks = await Track.objects.all() | ||
assert tracks[0].album.pk == album2.pk | ||
assert tracks[1].album.pk == album2.pk |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What value does this bring? I mean we could call bulk_update with
Model
itself. right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean? We use it as a type annotation for obj
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I understand. I mean you've done this:
What would be the difference if we did:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case Model will be undefined because it has been defined after bulk_update
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe "Model" ?