Supporting client-side gRPC cancellation
tanmayv25 committed Sep 7, 2023
1 parent 8ecca20 commit de2b4c8
Showing 5 changed files with 119 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/python/library/tritonclient/grpc/__init__.py
@@ -36,7 +36,7 @@
from ._infer_input import InferInput
from ._infer_result import InferResult
from ._requested_output import InferRequestedOutput
from ._utils import raise_error, raise_error_grpc
from ._utils import CancelledError, raise_error, raise_error_grpc
except ModuleNotFoundError as error:
raise RuntimeError(
"The installation does not include grpc support. "
61 changes: 50 additions & 11 deletions src/python/library/tritonclient/grpc/_client.py
@@ -39,6 +39,7 @@
from ._infer_result import InferResult
from ._infer_stream import _InferStream, _RequestIterator
from ._utils import (
CancelledError,
_get_inference_request,
_grpc_compression_type,
get_error_grpc,
@@ -1391,10 +1392,13 @@ def async_infer(
callback : function
Python function that is invoked once the request is completed.
The function must reserve the last two arguments (result, error)
to hold InferResult and InferenceServerException objects
respectively which will be provided to the function when executing
the callback. The ownership of these objects will be given to the
user. The 'error' would be None for a successful inference.
to hold InferResult and InferenceServerException (or CancelledError)
objects respectively, which will be provided to the function when
executing the callback. The ownership of these objects will be given
to the user. The 'error' would be None for a successful inference.
Note that if the request is cancelled using the returned future
object, the error provided to the callback will be a CancelledError
exception object.
model_version: str
The version of the model to run inference. The default value
is an empty string which means then the server will choose
@@ -1451,6 +1455,26 @@
Optional custom parameters to be included in the inference
request.
Returns
-------
grpc.Future
A representation of a computation in another control flow.
Computations represented by a Future may be yet to begin,
may be ongoing, or may have already completed.
This object can be used to cancel the inference request as
below:
----------
future = async_infer(...)
ret = future.cancel()
----------
See here for more details on the Future object:
https://grpc.github.io/grpc/python/grpc.html#grpc.Future
The callback will be invoked with
(result=None, error=CancelledError) for requests that
were successfully cancelled.
Raises
------
InferenceServerException
@@ -1466,6 +1490,8 @@ def wrapped_callback(call_future):
result = InferResult(response)
except grpc.RpcError as rpc_error:
error = get_error_grpc(rpc_error)
except grpc.FutureCancelledError:
error = CancelledError()
callback(result=result, error=error)

metadata = self._get_metadata(headers)
@@ -1502,6 +1528,7 @@ def wrapped_callback(call_future):
if request_id != "":
verbose_message = verbose_message + " '{}'".format(request_id)
print(verbose_message)
return self._call_future
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)
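[Editor's note] Taken together, these changes let a caller cancel an in-flight async_infer request through the returned future. Below is a minimal sketch of that flow; the server address, the model name "simple", and the input shape are illustrative assumptions, not part of this commit:

import queue

import numpy as np
import tritonclient.grpc as grpcclient

responses = queue.Queue()

def callback(result, error):
    # On a successful cancel, error is tritonclient.grpc.CancelledError.
    responses.put((result, error))

client = grpcclient.InferenceServerClient("localhost:8001")
inp = grpcclient.InferInput("INPUT0", [1, 16], "INT32")
inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.int32))

future = client.async_infer("simple", inputs=[inp], callback=callback)
future.cancel()  # ask gRPC to cancel the in-flight request

result, error = responses.get()
if isinstance(error, grpcclient.CancelledError):
    print("inference request was cancelled")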

@@ -1518,10 +1545,13 @@ def start_stream(
Python function that is invoked upon receiving response from
the underlying stream. The function must reserve the last two
arguments (result, error) to hold InferResult and
InferenceServerException objects respectively which will be
provided to the function when executing the callback. The
ownership of these objects will be given to the user. The
'error' would be None for a successful inference.
InferenceServerException (or CancelledError) objects respectively,
which will be provided to the function when executing the callback.
The ownership of these objects will be given to the user. The 'error'
would be None for a successful inference.
Note that if the stream is closed with cancel_requests set to True,
then the error provided to the callback will be a CancelledError
object.
stream_timeout : float
Optional stream timeout (in seconds). The stream will be closed
once the specified timeout expires.
@@ -1561,10 +1591,19 @@ def start_stream(
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

def stop_stream(self):
"""Stops a stream if one available."""
def stop_stream(self, cancel_requests=False):
"""Stops a stream if one available.
Parameters
----------
cancel_requests : bool
If set True, then client cancels all the pending requests
and closes the stream. If set False, the call blocks till
all the pending requests on the stream are processed.
"""
if self._stream is not None:
self._stream.close()
self._stream.close(cancel_requests)
self._stream = None

def async_stream_infer(
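[Editor's note] The new cancel_requests flag gives streaming users the same escape hatch. A sketch of closing a stream without waiting for pending responses, under the same illustrative model and address assumptions as above:

import numpy as np
import tritonclient.grpc as grpcclient

def callback(result, error):
    if isinstance(error, grpcclient.CancelledError):
        print("pending stream request was cancelled")

client = grpcclient.InferenceServerClient("localhost:8001")
client.start_stream(callback=callback)

inp = grpcclient.InferInput("INPUT0", [1, 16], "INT32")
inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.int32))
client.async_stream_infer("simple", inputs=[inp])

# Cancel all pending requests and close the stream instead of blocking.
client.stop_stream(cancel_requests=True)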
35 changes: 19 additions & 16 deletions src/python/library/tritonclient/grpc/_infer_stream.py
@@ -33,7 +33,7 @@
from tritonclient.utils import *

from ._infer_result import InferResult
from ._utils import get_error_grpc, raise_error
from ._utils import CancelledError, get_error_grpc, raise_error


class _InferStream:
@@ -57,18 +57,25 @@ def __init__(self, callback, verbose):
self._verbose = verbose
self._request_queue = queue.Queue()
self._handler = None
self._cancelled = False
self._active = True
self._response_iterator = None

def __del__(self):
self.close()

def close(self):
def close(self, cancel_requests=False):
"""Gracefully close underlying gRPC streams. Note that this call
blocks till response of all currently enqueued requests are not
received.
"""
if cancel_requests and self._response_iterator:
self._response_iterator.cancel()
self._cancelled = True

if self._handler is not None:
self._request_queue.put(None)
if not self._cancelled:
self._request_queue.put(None)
if self._handler.is_alive():
self._handler.join()
if self._verbose:
@@ -85,12 +92,11 @@ def _init_handler(self, response_iterator):
The iterator over the gRPC response stream.
"""
self._response_iterator = response_iterator
if self._handler is not None:
raise_error("Attempted to initialize already initialized InferStream")
# Create a new thread to handle the gRPC response stream
self._handler = threading.Thread(
target=self._process_response, args=(response_iterator,)
)
self._handler = threading.Thread(target=self._process_response)
self._handler.start()
if self._verbose:
print("stream started...")
@@ -129,19 +135,13 @@ def _get_request(self):
request = self._request_queue.get()
return request

def _process_response(self, responses):
def _process_response(self):
"""Worker thread function to iterate through the response stream and
execute the provided callbacks.
Parameters
----------
responses : iterator
The iterator to the response from the server for the
requests in the stream.
"""
try:
for response in responses:
for response in self._response_iterator:
if self._verbose:
print(response)
result = error = None
Expand All @@ -155,8 +155,11 @@ def _process_response(self, responses):
# can still be used. The stream won't be closed here as the thread
# executing this function is managed by stream and may cause
# circular wait
self._active = responses.is_active()
error = get_error_grpc(rpc_error)
self._active = self._response_iterator.is_active()
if rpc_error.cancelled():
error = CancelledError()
else:
error = get_error_grpc(rpc_error)
self._callback(result=None, error=error)
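[Editor's note] Since the same callback now sees two error types, user code will typically branch on the error class. A small sketch of that pattern (editor's illustration, not from this commit):

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

def callback(result, error):
    if error is None:
        print("response:", result.get_response())
    elif isinstance(error, grpcclient.CancelledError):
        print("request was cancelled on the client side")
    elif isinstance(error, InferenceServerException):
        print("server-side failure:", error.message())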


4 changes: 4 additions & 0 deletions src/python/library/tritonclient/grpc/_utils.py
@@ -31,6 +31,10 @@
from tritonclient.utils import *


class CancelledError(Exception):
"""Indicates that the issued operation was cancelled."""


def get_error_grpc(rpc_error):
"""Convert a gRPC error to an InferenceServerException.
51 changes: 45 additions & 6 deletions src/python/library/tritonclient/grpc/aio/__init__.py
@@ -586,8 +586,28 @@ async def infer(
headers=None,
compression_algorithm=None,
parameters=None,
get_call_obj=False,
):
"""Refer to tritonclient.grpc.InferenceServerClient"""
"""Refer to tritonclient.grpc.InferenceServerClient
The additional parameters for this functions are
described below:
Parameters
----------
get_call_obj : bool
If set True, then this function will yield
grpc.aio.call object first bfore the
InferResult.
This object can be used to issue request
cancellation if required. This can be attained
by following:
-------
call = await client.infer(..., get_call_obj=True)
call.cancel()
-------
"""

metadata = self._get_metadata(headers)

@@ -609,18 +629,20 @@
)
if self._verbose:
print("infer, metadata {}\n{}".format(metadata, request))

try:
response = await self._client_stub.ModelInfer(
call = self._client_stub.ModelInfer(
request=request,
metadata=metadata,
timeout=client_timeout,
compression=_grpc_compression_type(compression_algorithm),
)
if get_call_obj:
yield call
response = await call
if self._verbose:
print(response)
result = InferResult(response)
return result
yield result
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)
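[Editor's note] For the asyncio client, get_call_obj=True turns infer into an async generator that yields the call handle before the result. A minimal cancellation sketch; the server address, model name, and input shape are illustrative assumptions:

import asyncio

import numpy as np
import tritonclient.grpc.aio as aioclient
from tritonclient.grpc import InferInput

async def main():
    client = aioclient.InferenceServerClient("localhost:8001")
    inp = InferInput("INPUT0", [1, 16], "INT32")
    inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.int32))

    generator = client.infer("simple", inputs=[inp], get_call_obj=True)
    call = await generator.__anext__()  # first yield: the grpc.aio call
    call.cancel()  # cancel instead of awaiting the InferResult
    await client.close()

asyncio.run(main())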

@@ -630,6 +652,7 @@ async def stream_infer(
stream_timeout=None,
headers=None,
compression_algorithm=None,
get_call_obj=False,
):
"""Runs an asynchronous inference over gRPC bi-directional streaming
API.
@@ -650,11 +673,23 @@
Optional grpc compression algorithm to be used on client side.
Currently supports "deflate", "gzip" and None. By default, no
compression is used.
get_call_obj : bool
If set to True, then the async generator will first yield the
grpc.aio.Call object and then yield the rest of the results.
The call object can be used to cancel the execution of the
ongoing stream and exit. This can be done as below:
-------
response_iterator = client.stream_infer(..., get_call_obj=True)
streaming_call = await response_iterator.__anext__()
streaming_call.cancel()
-------
Returns
-------
async_generator
Yields tuples holding (InferResult, InferenceServerException) objects.
If get_call_obj is set to True, then it yields the streaming_call
object before yielding the tuples.
Raises
------
@@ -709,13 +744,17 @@ async def _request_iterator(inputs_iterator):
)

try:
response_iterator = self._client_stub.ModelStreamInfer(
streaming_call = self._client_stub.ModelStreamInfer(
_request_iterator(inputs_iterator),
metadata=metadata,
timeout=stream_timeout,
compression=_grpc_compression_type(compression_algorithm),
)
async for response in response_iterator:

if get_call_obj:
yield streaming_call

async for response in streaming_call:
if self._verbose:
print(response)
result = error = None
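[Editor's note] And the streaming variant: with get_call_obj=True the generator yields the streaming call first, which can then be cancelled to tear down the stream. A sketch under the same illustrative assumptions; the dictionary keys follow the inputs_iterator format documented for stream_infer:

import asyncio

import numpy as np
import tritonclient.grpc.aio as aioclient
from tritonclient.grpc import InferInput

async def main():
    client = aioclient.InferenceServerClient("localhost:8001")

    async def inputs_iterator():
        inp = InferInput("INPUT0", [1, 16], "INT32")
        inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.int32))
        yield {"model_name": "simple", "inputs": [inp]}

    responses = client.stream_infer(inputs_iterator(), get_call_obj=True)
    streaming_call = await responses.__anext__()  # first yield: the call
    streaming_call.cancel()  # cancels pending requests and ends the stream
    await client.close()

asyncio.run(main())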
