diff --git a/pytriton/proxy/data.py b/pytriton/proxy/data.py
index d820173..6b28d88 100644
--- a/pytriton/proxy/data.py
+++ b/pytriton/proxy/data.py
@@ -993,6 +993,8 @@ def serialize_requests(self, requests: Requests) -> bytes:
             serialized_request = {"data": serialized_request, "parameters": request.parameters}
             if request.span is not None:
                 serialized_request["span"] = get_span_dict(request.span)
+            if request.requested_output_names is not None:
+                serialized_request["requested_output_names"] = request.requested_output_names
             requests_list.append(serialized_request)
 
         requests = {"requests": requests_list}
@@ -1015,6 +1017,8 @@ def deserialize_requests(self, requests_payload: bytes) -> Requests:
                 span_dict = request["span"]
                 span = start_span_from_remote(span_dict, "proxy_inference_callable")
                 kwargs["span"] = span
+            if "requested_output_names" in request:
+                kwargs["requested_output_names"] = request["requested_output_names"]
             request_data = {
                 input_name: self._tensor_store.get(tensor_id)
                 for input_name, tensor_id in request.get("data", {}).items()
diff --git a/pytriton/proxy/model.py b/pytriton/proxy/model.py
index d815a38..0e5e736 100644
--- a/pytriton/proxy/model.py
+++ b/pytriton/proxy/model.py
@@ -136,7 +136,12 @@ def _wrap_request(self, triton_request, inputs, span=None) -> Request:
         kwargs = {}
         if span is not None:
             kwargs["span"] = span
-        return Request(data=request, parameters=json.loads(triton_request.parameters()), **kwargs)
+        return Request(
+            data=request,
+            parameters=json.loads(triton_request.parameters()),
+            requested_output_names=list(triton_request.requested_output_names()),
+            **kwargs,
+        )
 
     async def _send_requests(self, requests_id: bytes, triton_requests, spans=None) -> ConcurrentFuture:
         requests = triton_requests
diff --git a/pytriton/proxy/types.py b/pytriton/proxy/types.py
index dd24e60..ca3a073 100644
--- a/pytriton/proxy/types.py
+++ b/pytriton/proxy/types.py
@@ -31,6 +31,8 @@ class Request:
     """Parameters for the request."""
     span: Optional[Any] = None
     """Telemetry span for request"""
+    requested_output_names: Optional[List[str]] = None
+    """Requested output names for the request."""
 
     def __getitem__(self, input_name: str) -> np.ndarray:
         """Get input data."""
diff --git a/tests/unit/test_model_proxy_communication.py b/tests/unit/test_model_proxy_communication.py
index 97bacf0..3962fd6 100644
--- a/tests/unit/test_model_proxy_communication.py
+++ b/tests/unit/test_model_proxy_communication.py
@@ -54,7 +54,7 @@ class InferenceRequest:
     def __init__(self, model_name, inputs, requested_output_names, parameters=None):
         self.model_name = model_name
         self._inputs = inputs
-        self.requested_output_names = requested_output_names
+        self._requested_output_names = requested_output_names
         self._parameters = parameters or {}
 
     def inputs(self):
@@ -66,6 +66,9 @@ def parameters(self):
         return self._parameters
 
     def get_response_sender(self):
         return None
 
+    def requested_output_names(self):
+        return self._requested_output_names
+
 
 def _error_infer_fn(*_, **__):
     # Wrapper raises division by zero error
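
Not part of the patch itself, but a minimal sketch of what the change enables on the inference-callable side: with `requested_output_names` now carried through the proxy serialization, a callable can skip computing outputs the client never asked for. The sketch assumes PyTriton's plain (undecorated) inference-callable interface, a list of `Request` objects in and a list of name-to-array dicts out; the `input`/`sum`/`product` tensor names are hypothetical, not taken from this diff.

```python
def _infer_fn(requests):
    responses = []
    for request in requests:
        data = request["input"]  # hypothetical input name
        # None means the client did not restrict outputs: compute all.
        wanted = request.requested_output_names or ("sum", "product")
        response = {}
        if "sum" in wanted:
            response["sum"] = data + data
        if "product" in wanted:  # skipped when the client did not ask for it
            response["product"] = data * data
        responses.append(response)
    return responses
```

Since `requested_output_names` defaults to `None` on the dataclass and is only serialized when set, callables written before this change keep working unchanged.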