From de2b4c88b8bd171afaed6d5a2a84c4b5b61ed04d Mon Sep 17 00:00:00 2001
From: tanmayv25
Date: Wed, 6 Sep 2023 19:52:55 -0700
Subject: [PATCH] Supporting client-side gRPC cancellation

---
 .../library/tritonclient/grpc/__init__.py     |  2 +-
 .../library/tritonclient/grpc/_client.py      | 61 +++++++++++++++----
 .../tritonclient/grpc/_infer_stream.py        | 35 ++++++-----
 .../library/tritonclient/grpc/_utils.py       |  4 ++
 .../library/tritonclient/grpc/aio/__init__.py | 51 ++++++++++++++--
 5 files changed, 119 insertions(+), 34 deletions(-)

diff --git a/src/python/library/tritonclient/grpc/__init__.py b/src/python/library/tritonclient/grpc/__init__.py
index 852d5f0d6..260d8147a 100755
--- a/src/python/library/tritonclient/grpc/__init__.py
+++ b/src/python/library/tritonclient/grpc/__init__.py
@@ -36,7 +36,7 @@
     from ._infer_input import InferInput
     from ._infer_result import InferResult
     from ._requested_output import InferRequestedOutput
-    from ._utils import raise_error, raise_error_grpc
+    from ._utils import CancelledError, raise_error, raise_error_grpc
 except ModuleNotFoundError as error:
     raise RuntimeError(
         "The installation does not include grpc support. "
diff --git a/src/python/library/tritonclient/grpc/_client.py b/src/python/library/tritonclient/grpc/_client.py
index 1c1a63799..a8751b473 100755
--- a/src/python/library/tritonclient/grpc/_client.py
+++ b/src/python/library/tritonclient/grpc/_client.py
@@ -39,6 +39,7 @@
 from ._infer_result import InferResult
 from ._infer_stream import _InferStream, _RequestIterator
 from ._utils import (
+    CancelledError,
     _get_inference_request,
     _grpc_compression_type,
     get_error_grpc,
@@ -1391,10 +1392,13 @@ def async_infer(
         callback : function
             Python function that is invoked once the request is completed.
             The function must reserve the last two arguments (result, error)
-            to hold InferResult and InferenceServerException objects
-            respectively which will be provided to the function when executing
-            the callback. The ownership of these objects will be given to the
-            user. The 'error' would be None for a successful inference.
+            to hold InferResult and InferenceServerException (or
+            CancelledError) objects respectively which will be provided to
+            the function when executing the callback. The ownership of these
+            objects will be given to the user. The 'error' would be None for
+            a successful inference.
+            Note that if the request is cancelled using the returned future
+            object, the error provided to the callback will be a
+            CancelledError exception object.
         model_version: str
             The version of the model to run inference. The default value
             is an empty string which means then the server will choose
@@ -1451,6 +1455,26 @@ def async_infer(
             Optional custom parameters to be included in the inference
             request.
 
+        Returns
+        -------
+        grpc.Future
+            A representation of a computation in another control flow.
+            Computations represented by a Future may be yet to begin,
+            may be ongoing, or may have already completed.
+
+            This object can be used to cancel the inference request as
+            shown below:
+            ----------
+            future = async_infer(...)
+            ret = future.cancel()
+            ----------
+            See here for more details on the Future object:
+            https://grpc.github.io/grpc/python/grpc.html#grpc.Future
+
+            The callback will be invoked with
+            (result=None, error=CancelledError) for the requests that
+            were successfully cancelled.
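+
+            As a minimal illustrative sketch (the client handle, model
+            name, inputs, and callback below are hypothetical placeholders,
+            not part of this API), a callback can distinguish cancellation
+            from other failures like this:
+            ----------
+            def my_callback(result, error):
+                if isinstance(error, CancelledError):
+                    print("request was cancelled")
+                elif error is not None:
+                    print("inference failed: {}".format(error))
+                else:
+                    print(result.get_response())
+
+            future = client.async_infer("my_model", inputs, my_callback)
+            if not future.cancel():
+                print("request could not be cancelled")
+            ----------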
+
         Raises
         ------
         InferenceServerException
@@ -1466,6 +1490,8 @@ def wrapped_callback(call_future):
                 result = InferResult(response)
             except grpc.RpcError as rpc_error:
                 error = get_error_grpc(rpc_error)
+            except grpc.FutureCancelledError:
+                error = CancelledError()
             callback(result=result, error=error)
 
         metadata = self._get_metadata(headers)
@@ -1502,6 +1528,7 @@ def wrapped_callback(call_future):
                 if request_id != "":
                     verbose_message = verbose_message + " '{}'".format(request_id)
                 print(verbose_message)
+            return self._call_future
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
@@ -1518,10 +1545,13 @@ def start_stream(
             Python function that is invoked upon receiving response from
             the underlying stream. The function must reserve the last two
             arguments (result, error) to hold InferResult and
-            InferenceServerException objects respectively which will be
-            provided to the function when executing the callback. The
-            ownership of these objects will be given to the user. The
-            'error' would be None for a successful inference.
+            InferenceServerException (or CancelledError) objects
+            respectively which will be provided to the function when
+            executing the callback. The ownership of these objects will be
+            given to the user. The 'error' would be None for a successful
+            inference.
+            Note that if the stream is closed with cancel_requests set
+            True, the error provided to the callback will be a
+            CancelledError object.
+
         stream_timeout : float
             Optional stream timeout (in seconds). The stream will be closed
             once the specified timeout expires.
@@ -1561,10 +1591,19 @@ def start_stream(
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
 
-    def stop_stream(self):
-        """Stops a stream if one available."""
+    def stop_stream(self, cancel_requests=False):
+        """Stops a stream if one is available.
+
+        Parameters
+        ----------
+        cancel_requests : bool
+            If set True, the client cancels all the pending requests and
+            closes the stream. If set False, the call blocks until all
+            the pending requests on the stream have been processed.
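+
+            For illustration (the client handle, model name, inputs, and
+            callback are placeholders, not part of this API), cancelling
+            everything still in flight on a stream could look like this:
+            ----------
+            client.start_stream(callback=my_callback)
+            client.async_stream_infer("my_model", inputs)
+            # Cancelled requests surface CancelledError through the
+            # stream callback instead of a normal result.
+            client.stop_stream(cancel_requests=True)
+            ----------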
""" + self._response_iterator = response_iterator if self._handler is not None: raise_error("Attempted to initialize already initialized InferStream") # Create a new thread to handle the gRPC response stream - self._handler = threading.Thread( - target=self._process_response, args=(response_iterator,) - ) + self._handler = threading.Thread(target=self._process_response) self._handler.start() if self._verbose: print("stream started...") @@ -129,19 +135,13 @@ def _get_request(self): request = self._request_queue.get() return request - def _process_response(self, responses): + def _process_response(self): """Worker thread function to iterate through the response stream and executes the provided callbacks. - Parameters - ---------- - responses : iterator - The iterator to the response from the server for the - requests in the stream. - """ try: - for response in responses: + for response in self._response_iterator: if self._verbose: print(response) result = error = None @@ -155,8 +155,11 @@ def _process_response(self, responses): # can still be used. The stream won't be closed here as the thread # executing this function is managed by stream and may cause # circular wait - self._active = responses.is_active() - error = get_error_grpc(rpc_error) + self._active = self._response_iterator.is_active() + if rpc_error.cancelled: + error = CancelledError() + else: + error = get_error_grpc(rpc_error) self._callback(result=None, error=error) diff --git a/src/python/library/tritonclient/grpc/_utils.py b/src/python/library/tritonclient/grpc/_utils.py index 4496a1981..2f75323b9 100755 --- a/src/python/library/tritonclient/grpc/_utils.py +++ b/src/python/library/tritonclient/grpc/_utils.py @@ -31,6 +31,10 @@ from tritonclient.utils import * +class CancelledError(Exception): + """Indicates that the issued operation was cancelled.""" + + def get_error_grpc(rpc_error): """Convert a gRPC error to an InferenceServerException. diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py index fc5eaccdb..60ad4c127 100755 --- a/src/python/library/tritonclient/grpc/aio/__init__.py +++ b/src/python/library/tritonclient/grpc/aio/__init__.py @@ -586,8 +586,28 @@ async def infer( headers=None, compression_algorithm=None, parameters=None, + get_call_obj=False, ): - """Refer to tritonclient.grpc.InferenceServerClient""" + """Refer to tritonclient.grpc.InferenceServerClient + The additional parameters for this functions are + described below: + + Parameters + ---------- + get_call_obj : bool + If set True, then this function will yield + grpc.aio.call object first bfore the + InferResult. + This object can be used to issue request + cancellation if required. 
+
+        """
 
         metadata = self._get_metadata(headers)
@@ -609,18 +629,20 @@ async def infer(
         if self._verbose:
             print("infer, metadata {}\n{}".format(metadata, request))
-
         try:
-            response = await self._client_stub.ModelInfer(
+            call = self._client_stub.ModelInfer(
                 request=request,
                 metadata=metadata,
                 timeout=client_timeout,
                 compression=_grpc_compression_type(compression_algorithm),
             )
+            if get_call_obj:
+                yield call
+            response = await call
             if self._verbose:
                 print(response)
             result = InferResult(response)
-            return result
+            yield result
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
 
@@ -630,6 +652,7 @@ async def stream_infer(
         stream_timeout=None,
         headers=None,
         compression_algorithm=None,
+        get_call_obj=False,
     ):
         """Runs an asynchronous inference over gRPC bi-directional streaming
         API.
@@ -650,11 +673,23 @@ async def stream_infer(
             Optional grpc compression algorithm to be used on client side.
             Currently supports "deflate", "gzip" and None. By default, no
             compression is used.
+        get_call_obj : bool
+            If set True, then the async generator will first yield the
+            grpc.aio.Call object and then yield the rest of the results.
+            The call object can be used to cancel the execution of the
+            ongoing stream and exit. This can be done as below:
+            -------
+            response_iterator = client.stream_infer(..., get_call_obj=True)
+            streaming_call = await response_iterator.__anext__()
+            streaming_call.cancel()
+            -------
 
         Returns
         -------
         async_generator
             Yield tuple holding (InferResult, InferenceServerException)
             objects.
+            If get_call_obj is set True, it yields the streaming_call
+            object before yielding the tuples.
 
         Raises
         ------
@@ -709,13 +744,17 @@ async def _request_iterator(inputs_iterator):
             )
 
         try:
-            response_iterator = self._client_stub.ModelStreamInfer(
+            streaming_call = self._client_stub.ModelStreamInfer(
                 _request_iterator(inputs_iterator),
                 metadata=metadata,
                 timeout=stream_timeout,
                 compression=_grpc_compression_type(compression_algorithm),
             )
+
+            if get_call_obj:
+                yield streaming_call
+
             async for response in streaming_call:
                 if self._verbose:
                     print(response)
                 result = error = None
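
For reference, a minimal end-to-end sketch of exercising the patched aio
streaming path with cancellation follows. This is an illustration only, not
part of the patch: the server address, model name, tensor name/shape, and
the request-dictionary keys yielded by the inputs iterator are assumptions
based on the existing aio client API.

-------
import asyncio

import numpy as np
import tritonclient.grpc.aio as aioclient
from tritonclient.grpc import InferInput


async def main():
    client = aioclient.InferenceServerClient("localhost:8001")

    # Hypothetical single input tensor; name and shape depend on the model.
    inputs = [InferInput("INPUT0", [1, 16], "FP32")]
    inputs[0].set_data_from_numpy(np.zeros([1, 16], dtype=np.float32))

    async def inputs_iterator():
        # Each yielded dict carries the arguments of one streaming request.
        yield {"model_name": "my_model", "inputs": inputs}

    # With get_call_obj=True, the first item the generator yields is the
    # streaming call object; subsequent items are (result, error) tuples.
    response_iterator = client.stream_infer(inputs_iterator(), get_call_obj=True)
    streaming_call = await response_iterator.__anext__()

    # Cancel the ongoing stream; pending requests will not be completed.
    streaming_call.cancel()


asyncio.run(main())
-------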