diff --git a/storey/flow.py b/storey/flow.py index f161e232..22fccd57 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -840,9 +840,12 @@ def _init(self): self._lazy_init_complete = False async def _worker(self): - event = None try: while True: + # Allow event to be garbage collected + job = None # noqa + event = None + completed = None # noqa try: # If we don't handle the event before we remove it from the queue, the effective max_in_flight will # be 1 higher than requested. Hence, we peek. diff --git a/tests/test_flow.py b/tests/test_flow.py index a596219c..07263f92 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -71,6 +71,7 @@ build_flow, ) from storey.flow import ( + ConcurrentExecution, Context, ParallelExecution, ParallelExecutionRunnable, @@ -360,6 +361,50 @@ def test_async_offset_commit_before_termination_with_nosqltarget(): asyncio.run(async_offset_commit_before_termination_with_nosqltarget()) +async def async_offset_commit_before_termination_with_concurrent_execution(): + platform = Committer() + context = CommitterContext(platform) + + max_wait_before_commit = 1 + + controller = build_flow( + [ + AsyncEmitSource(context=context, explicit_ack=True, max_wait_before_commit=max_wait_before_commit), + ConcurrentExecution(event_processor=lambda x: x + 1), + Filter(lambda x: x < 3), + FlatMap(lambda x: [x, x * 10]), + Reduce(0, lambda acc, x: acc + x), + ] + ).run() + + num_shards = 10 + num_records_per_shard = 10 + + for offset in range(1, num_records_per_shard + 1): + for shard in range(num_shards): + event = Event(shard) + event.shard_id = shard + event.offset = offset + await controller.emit(event) + + del event + + await asyncio.sleep(max_wait_before_commit + 1) + + try: + offsets = copy.copy(platform.offsets) + assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)} + finally: + await controller.terminate() + termination_result = await controller.await_termination() + assert termination_result == 330 + + +# ML-8799 +def test_async_offset_commit_before_termination_with_concurrent_execution(): + asyncio.run(async_offset_commit_before_termination_with_concurrent_execution()) + + def test_offset_not_committed_prematurely(): platform = Committer() context = CommitterContext(platform)