From 272ad5efdb74ef725555f1b284b0e87cb37f632b Mon Sep 17 00:00:00 2001 From: Pina Date: Thu, 23 Jul 2020 01:00:43 +0100 Subject: [PATCH] yield update/created objects (optional) --- bulk_update_or_create/query.py | 37 +++++++++++++++++++++++++++++----- setup.cfg | 2 +- tests/tests/tests.py | 33 ++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/bulk_update_or_create/query.py b/bulk_update_or_create/query.py index 3b280f7..7cb1ac1 100644 --- a/bulk_update_or_create/query.py +++ b/bulk_update_or_create/query.py @@ -9,16 +9,39 @@ def bulk_update_or_create( match_field='pk', batch_size=None, case_insensitive_match=False, + yield_objects=False, ): """ - :param objs: - :param match_field: - :param update_fields: - :param batch_size: - :param case_insensitive_match: + :param objs: model instances to be updated or created + :param update_fields: fields that will be updated if record already exists (passed on to bulk_update) + :param match_field: model field that will match existing records (defaults to "pk") + :param batch_size: number of records to process in each batch (defaults to len(objs)) + :param case_insensitive_match: set to True if using MySQL with "ci" collations (defaults to False) + :param yield_objects: if True, method becomes a generator that will yield a tuple of lists with ([created], [updated]) objects """ + r = self.__bulk_update_or_create( + objs, + update_fields, + match_field, + batch_size, + case_insensitive_match, + yield_objects, + ) + if yield_objects: + return r + return list(r) + + def __bulk_update_or_create( + self, + objs, + update_fields, + match_field='pk', + batch_size=None, + case_insensitive_match=False, + yield_objects=False, + ): if not objs: raise ValueError('no objects to update_or_create...') if not update_fields: @@ -65,8 +88,12 @@ def _cased_key(obj): # no-op self.bulk_update(to_update, update_fields) # .create on the remaining (bulk_create won't work on multi-table inheritance models...) + created_objs = [] for obj in obj_map.values(): obj.save() + created_objs.append(obj) + if yield_objects: + yield created_objs, to_update class BulkUpdateOrCreateQuerySet(BulkUpdateOrCreateMixin, models.QuerySet): diff --git a/setup.cfg b/setup.cfg index b208fad..d2371bc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = django-bulk-update-or-create -version = 0.1.2 +version = 0.1.3 description = bulk_update_or_create for Django model managers long_description = file: README.md long_description_content_type = text/markdown diff --git a/tests/tests/tests.py b/tests/tests/tests.py index 35dcf58..b4693fa 100644 --- a/tests/tests/tests.py +++ b/tests/tests/tests.py @@ -24,6 +24,39 @@ def test_update_some(self): list(range(5)) + list(range(10, 20)), ) + def test_update_some_generator(self): + self.test_all_create() + items = [RandomData(uuid=i + 5, data=i + 10) for i in range(10)] + updated_items = RandomData.objects.bulk_update_or_create( + items, ['data'], match_field='uuid', yield_objects=True + ) + # not executed yet, just generator + self.assertEqual(RandomData.objects.count(), 10) + updated_items = list(updated_items) + self.assertEqual(RandomData.objects.count(), 15) + self.assertEqual( + sorted(int(x.data) for x in RandomData.objects.all()), + list(range(5)) + list(range(10, 20)), + ) + # one batch + self.assertEqual(len(updated_items), 1) + # tuple with (created, updated) + self.assertEqual(len(updated_items[0]), 2) + # 5 were created - 15 to 19 + self.assertEqual(len(updated_items[0][0]), 5) + self.assertEqual( + sorted(int(x.data) for x in updated_items[0][0]), list(range(15, 20)), + ) + for x in updated_items[0][0]: + self.assertIsNotNone(x.pk) + # 5 were updated - 10 to 14 (from 5 to 9) + self.assertEqual(len(updated_items[0][1]), 5) + self.assertEqual( + sorted(int(x.data) for x in updated_items[0][1]), list(range(10, 15)), + ) + for x in updated_items[0][1]: + self.assertIsNotNone(x.pk) + def test_errors(self): with self.assertRaises(ValueError) as cm: RandomData.objects.bulk_update_or_create([], [])