From bd4be5e67de5f31e9336ba0fdcd545e88d70b954 Mon Sep 17 00:00:00 2001
From: Lee Dong Wook
Date: Sat, 2 Nov 2024 23:07:32 +0900
Subject: [PATCH] gh-126317: Simplify pickle code by using itertools.batched() (GH-126323)

---
 Lib/pickle.py | 61 +++++++++++++++++++--------------------------------
 1 file changed, 22 insertions(+), 39 deletions(-)

diff --git a/Lib/pickle.py b/Lib/pickle.py
index ed8138beb908ee..965e1952fb8c5e 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -26,7 +26,7 @@
 from types import FunctionType
 from copyreg import dispatch_table
 from copyreg import _extension_registry, _inverted_registry, _extension_cache
-from itertools import islice
+from itertools import batched
 from functools import partial
 import sys
 from sys import maxsize
@@ -1033,31 +1033,26 @@ def _batch_appends(self, items, obj):
                 write(APPEND)
             return
 
-        it = iter(items)
         start = 0
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
+            batch_len = len(batch)
+            if batch_len != 1:
                 write(MARK)
-                for i, x in enumerate(tmp, start):
+                for i, x in enumerate(batch, start):
                     try:
                         save(x)
                     except BaseException as exc:
                         exc.add_note(f'when serializing {_T(obj)} item {i}')
                         raise
                 write(APPENDS)
-            elif n:
+            else:
                 try:
-                    save(tmp[0])
+                    save(batch[0])
                 except BaseException as exc:
                     exc.add_note(f'when serializing {_T(obj)} item {start}')
                     raise
                 write(APPEND)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
-            start += n
+            start += batch_len
 
     def save_dict(self, obj):
         if self.bin:
@@ -1086,13 +1081,10 @@ def _batch_setitems(self, items, obj):
                 write(SETITEM)
             return
 
-        it = iter(items)
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
+            if len(batch) != 1:
                 write(MARK)
-                for k, v in tmp:
+                for k, v in batch:
                     save(k)
                     try:
                         save(v)
@@ -1100,8 +1092,8 @@ def _batch_setitems(self, items, obj):
                         exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                         raise
                 write(SETITEMS)
-            elif n:
-                k, v = tmp[0]
+            else:
+                k, v = batch[0]
                 save(k)
                 try:
                     save(v)
@@ -1109,9 +1101,6 @@ def _batch_setitems(self, items, obj):
                     exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                     raise
                 write(SETITEM)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
 
     def save_set(self, obj):
         save = self.save
@@ -1124,21 +1113,15 @@ def save_set(self, obj):
         write(EMPTY_SET)
         self.memoize(obj)
 
-        it = iter(obj)
-        while True:
-            batch = list(islice(it, self._BATCHSIZE))
-            n = len(batch)
-            if n > 0:
-                write(MARK)
-                try:
-                    for item in batch:
-                        save(item)
-                except BaseException as exc:
-                    exc.add_note(f'when serializing {_T(obj)} element')
-                    raise
-                write(ADDITEMS)
-            if n < self._BATCHSIZE:
-                return
+        for batch in batched(obj, self._BATCHSIZE):
+            write(MARK)
+            try:
+                for item in batch:
+                    save(item)
+            except BaseException as exc:
+                exc.add_note(f'when serializing {_T(obj)} element')
+                raise
+            write(ADDITEMS)
     dispatch[set] = save_set
 
     def save_frozenset(self, obj):