Skip to content

Commit

Permalink
Merge pull request #2 from fopina/yield
Browse files Browse the repository at this point in the history
yield update/created objects (optional)
  • Loading branch information
fopina authored Jul 23, 2020
2 parents d9e498f + 272ad5e commit 332c0d6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
37 changes: 32 additions & 5 deletions bulk_update_or_create/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,39 @@ def bulk_update_or_create(
match_field='pk',
batch_size=None,
case_insensitive_match=False,
yield_objects=False,
):
"""
:param objs:
:param match_field:
:param update_fields:
:param batch_size:
:param case_insensitive_match:
:param objs: model instances to be updated or created
:param update_fields: fields that will be updated if record already exists (passed on to bulk_update)
:param match_field: model field that will match existing records (defaults to "pk")
:param batch_size: number of records to process in each batch (defaults to len(objs))
:param case_insensitive_match: set to True if using MySQL with "ci" collations (defaults to False)
:param yield_objects: if True, method becomes a generator that will yield a tuple of lists with ([created], [updated]) objects
"""

r = self.__bulk_update_or_create(
objs,
update_fields,
match_field,
batch_size,
case_insensitive_match,
yield_objects,
)
if yield_objects:
return r
return list(r)

def __bulk_update_or_create(
self,
objs,
update_fields,
match_field='pk',
batch_size=None,
case_insensitive_match=False,
yield_objects=False,
):
if not objs:
raise ValueError('no objects to update_or_create...')
if not update_fields:
Expand Down Expand Up @@ -65,8 +88,12 @@ def _cased_key(obj): # no-op
self.bulk_update(to_update, update_fields)

# .create on the remaining (bulk_create won't work on multi-table inheritance models...)
created_objs = []
for obj in obj_map.values():
obj.save()
created_objs.append(obj)
if yield_objects:
yield created_objs, to_update


class BulkUpdateOrCreateQuerySet(BulkUpdateOrCreateMixin, models.QuerySet):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = django-bulk-update-or-create
version = 0.1.2
version = 0.1.3
description = bulk_update_or_create for Django model managers
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down
33 changes: 33 additions & 0 deletions tests/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,39 @@ def test_update_some(self):
list(range(5)) + list(range(10, 20)),
)

def test_update_some_generator(self):
self.test_all_create()
items = [RandomData(uuid=i + 5, data=i + 10) for i in range(10)]
updated_items = RandomData.objects.bulk_update_or_create(
items, ['data'], match_field='uuid', yield_objects=True
)
# not executed yet, just generator
self.assertEqual(RandomData.objects.count(), 10)
updated_items = list(updated_items)
self.assertEqual(RandomData.objects.count(), 15)
self.assertEqual(
sorted(int(x.data) for x in RandomData.objects.all()),
list(range(5)) + list(range(10, 20)),
)
# one batch
self.assertEqual(len(updated_items), 1)
# tuple with (created, updated)
self.assertEqual(len(updated_items[0]), 2)
# 5 were created - 15 to 19
self.assertEqual(len(updated_items[0][0]), 5)
self.assertEqual(
sorted(int(x.data) for x in updated_items[0][0]), list(range(15, 20)),
)
for x in updated_items[0][0]:
self.assertIsNotNone(x.pk)
# 5 were updated - 10 to 14 (from 5 to 9)
self.assertEqual(len(updated_items[0][1]), 5)
self.assertEqual(
sorted(int(x.data) for x in updated_items[0][1]), list(range(10, 15)),
)
for x in updated_items[0][1]:
self.assertIsNotNone(x.pk)

def test_errors(self):
with self.assertRaises(ValueError) as cm:
RandomData.objects.bulk_update_or_create([], [])
Expand Down

0 comments on commit 332c0d6

Please sign in to comment.