Commit

fix: dyn batching configs (#6204)

JoanFM authored Sep 25, 2024
1 parent 450553a commit b1139bc
Showing 2 changed files with 64 additions and 8 deletions.
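For context, dynamic batching is configured per key on the Executor: keys starting with '/' are endpoints, other keys are function names, and each key maps to a dict of options such as preferred_batch_size, timeout, custom_metric, and the use_dynamic_batching flag whose handling this commit fixes. A minimal sketch of such a config (the Executor, endpoint, and parameter values below are illustrative, not taken from this commit):

```python
from jina import Executor, Flow, dynamic_batching, requests


class MyExecutor(Executor):  # illustrative Executor, not part of this commit
    @requests(on='/bar')
    @dynamic_batching(preferred_batch_size=8, timeout=2000)
    def bar(self, docs, **kwargs):
        # pretend this call benefits from batching, e.g. a model forward pass
        return docs


# The same per-endpoint options can be supplied (or overridden) from the Flow side:
f = Flow().add(
    uses=MyExecutor,
    uses_dynamic_batching={
        '/bar': {
            'preferred_batch_size': 8,
            'timeout': 2000,
            'use_dynamic_batching': False,  # the flag this commit fixes the handling of
        }
    },
)
```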
27 changes: 19 additions & 8 deletions jina/serve/runtimes/worker/request_handling.py
@@ -263,6 +263,9 @@ def _init_batchqueue_dict(self):
         if getattr(self._executor, 'dynamic_batching', None) is not None:
             # We need to sort the keys into endpoints and functions
             # Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function
+            self.logger.debug(
+                f'Executor Dynamic Batching configs: {self._executor.dynamic_batching}'
+            )
             dbatch_endpoints = []
             dbatch_functions = []
             request_models_map = self._executor._get_endpoint_models_dict()
@@ -275,11 +278,10 @@ def _init_batchqueue_dict(self):
                     )
                     raise Exception(error_msg)
 
-                if dbatch_config.get("use_dynamic_batching", True):
-                    if key.startswith('/'):
-                        dbatch_endpoints.append((key, dbatch_config))
-                    else:
-                        dbatch_functions.append((key, dbatch_config))
+                if key.startswith('/'):
+                    dbatch_endpoints.append((key, dbatch_config))
+                else:
+                    dbatch_functions.append((key, dbatch_config))
 
             # Specific endpoint configs take precedence over function configs
             for endpoint, dbatch_config in dbatch_endpoints:
@@ -295,10 +297,19 @@ def _init_batchqueue_dict(self):
                 for endpoint in func_endpoints[func_name]:
                     if endpoint not in self._batchqueue_config:
                         self._batchqueue_config[endpoint] = dbatch_config
+                    else:
+                        # we need to eventually copy the `custom_metric`
+                        if dbatch_config.get('custom_metric', None) is not None:
+                            self._batchqueue_config[endpoint]['custom_metric'] = dbatch_config.get('custom_metric')
 
+            keys_to_remove = []
+            for k, batch_config in self._batchqueue_config.items():
+                if not batch_config.get('use_dynamic_batching', True):
+                    keys_to_remove.append(k)
+
+            for k in keys_to_remove:
+                self._batchqueue_config.pop(k)
+
-            self.logger.debug(
-                f'Executor Dynamic Batching configs: {self._executor.dynamic_batching}'
-            )
             self.logger.debug(
                 f'Endpoint Batch Queue Configs: {self._batchqueue_config}'
             )
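Net effect of the handler change: before this commit, any key whose config carried use_dynamic_batching=False was skipped while sorting keys into endpoints and functions, so it never took part in the endpoint-versus-function merge; now every key is merged first (endpoint configs take precedence, and a function-level custom_metric is still copied onto an existing endpoint config), and only afterwards are endpoints whose final config disables dynamic batching dropped from the batch queue map. A simplified, standalone sketch of that merge-then-prune order (not the actual Jina implementation; func_endpoints maps function names to the endpoints they serve):

```python
def build_batchqueue_config(dynamic_batching: dict, func_endpoints: dict) -> dict:
    """Simplified sketch of the merge-then-prune order introduced by this commit."""
    endpoint_cfgs = {k: v for k, v in dynamic_batching.items() if k.startswith('/')}
    function_cfgs = {k: v for k, v in dynamic_batching.items() if not k.startswith('/')}

    # Specific endpoint configs take precedence over function configs
    batchqueue_config = dict(endpoint_cfgs)
    for func_name, cfg in function_cfgs.items():
        for endpoint in func_endpoints.get(func_name, []):
            if endpoint not in batchqueue_config:
                batchqueue_config[endpoint] = cfg
            elif cfg.get('custom_metric') is not None:
                # a function-level custom_metric is still copied onto the endpoint config
                batchqueue_config[endpoint]['custom_metric'] = cfg['custom_metric']

    # only now drop endpoints that explicitly opt out of dynamic batching
    return {
        k: v for k, v in batchqueue_config.items() if v.get('use_dynamic_batching', True)
    }


# A function-level config that disables batching for everything 'bar' serves:
print(build_batchqueue_config(
    {'bar': {'preferred_batch_size': 4, 'use_dynamic_batching': False}},
    {'bar': ['/bar']},
))  # -> {}
```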
45 changes: 45 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -244,6 +244,51 @@ def test_timeout(add_parameters, use_stream):
         assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly'
 
 
+@pytest.mark.parametrize(
+    'add_parameters',
+    [
+        {
+            'uses': PlaceholderExecutorWrongDecorator,
+            'uses_dynamic_batching': USES_DYNAMIC_BATCHING_PLACE_HOLDER_EXECUTOR,
+        }
+    ],
+)
+@pytest.mark.parametrize('use_stream', [False, True])
+@pytest.mark.parametrize('use_dynamic_batching', [False, True])
+def test_timeout_no_use(add_parameters, use_stream, use_dynamic_batching):
+    for k, v in add_parameters["uses_dynamic_batching"].items():
+        v["use_dynamic_batching"] = use_dynamic_batching
+    f = Flow().add(**add_parameters)
+    with f:
+        start_time = time.time()
+        f.post('/bar', inputs=DocumentArray.empty(2), stream=use_stream)
+        time_taken = time.time() - start_time
+        if use_dynamic_batching:
+            assert time_taken > 2, 'Timeout ended too fast'
+            assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly'
+        else:
+            assert time_taken < 2
+
+        with mp.Pool(3) as p:
+            start_time = time.time()
+            list(
+                p.map(
+                    call_api,
+                    [
+                        RequestStruct(f.port, '/bar', range(1), use_stream),
+                        RequestStruct(f.port, '/bar', range(1), not use_stream),
+                        RequestStruct(f.port, '/bar', range(1), use_stream),
+                    ],
+                )
+            )
+            time_taken = time.time() - start_time
+            if use_dynamic_batching:
+                assert time_taken > 2, 'Timeout ended too fast'
+                assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly'
+            else:
+                assert time_taken < 2
+
+
 @pytest.mark.parametrize(
     'add_parameters',
     [
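In user-facing terms, the behaviour the new test pins down is roughly the following (a hedged sketch reusing the illustrative MyExecutor and 2-second timeout from the example above; DocumentArray.empty is used the same way as in the test module):

```python
import time

from docarray import DocumentArray
from jina import Flow

for flag in (True, False):
    f = Flow().add(
        uses=MyExecutor,  # the illustrative Executor sketched earlier
        uses_dynamic_batching={
            '/bar': {'preferred_batch_size': 8, 'timeout': 2000, 'use_dynamic_batching': flag}
        },
    )
    with f:
        start = time.time()
        f.post('/bar', inputs=DocumentArray.empty(2))
        elapsed = time.time() - start
    # With the flag on, two docs are fewer than preferred_batch_size, so the call
    # waits out the ~2 s batching window; with it off, no batch queue is created
    # and the call returns almost immediately.
    print(flag, round(elapsed, 2))
```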
