From eaad323320ba8951bd17d9154a1d0adcf3e5c9d0 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Wed, 25 Sep 2024 16:57:35 +0200 Subject: [PATCH] test: add extra test for dyn batching (#6205) --- .../dynamic_batching/test_dynamic_batching.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 0e42785d1b8be..018d50e381626 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -289,6 +289,45 @@ def test_timeout_no_use(add_parameters, use_stream, use_dynamic_batching): assert time_taken < 2 +@pytest.mark.asyncio +@pytest.mark.parametrize('use_custom_metric', [False, True]) +@pytest.mark.parametrize('use_dynamic_batching', [False, True]) +async def test_timeout_no_use_custom(use_dynamic_batching, use_custom_metric): + class TextUseCustomDynBatch(Executor): + @requests(on='/foo') + @dynamic_batching(custom_metric=lambda d: len(d.text)) + def fun(self, docs, **kwargs): + if use_custom_metric: + self.logger.debug(f'Received {len(docs)} in "/foo" call with with custom metric and sum of text lengths? {sum([len(d.text) for d in docs])}') + else: + self.logger.debug( + f'Received {len(docs)} in "/foo" call with sum of text lengths? {sum([len(d.text) for d in docs])}') + time.sleep(1) + for doc in docs: + doc.text += FOO_SUCCESS_MSG + + d = Deployment(uses=TextUseCustomDynBatch, uses_dynamic_batching={'/foo': {'timeout': 2000, "preferred_batch_size": 10, 'use_dynamic_batching': use_dynamic_batching, 'use_custom_metric': use_custom_metric}}) + with d: + start_time = time.time() + inputs = DocumentArray([Document(text='ab') for _ in range(8)]) + client = Client(port=d.port, asyncio=True, protocol=d.protocol) + async for _ in client.post('/foo', inputs=inputs, request_size=1): + pass + time_taken = time.time() - start_time + if not use_dynamic_batching: + # in this case it should simply call once for each + assert time_taken > 8, 'Timeout ended too fast' + assert time_taken < 8 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' + elif not use_custom_metric: + # in this case it should accumulate all in 2 seconds, and spend only 1 second inside call + assert time_taken > 3, 'Timeout ended too fast' + assert time_taken < 3 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' + elif use_custom_metric: + # in this case it should accumulate all before 2 seconds, and divide the call in 2 calls + assert time_taken > 2, 'Timeout ended too fast' + assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' + + @pytest.mark.parametrize( 'add_parameters', [