Skip to content

Commit

Permalink
expose requested_output_names in Request (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell authored Oct 12, 2024
1 parent 26dc739 commit d0cc359
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pytriton/proxy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,8 @@ def serialize_requests(self, requests: Requests) -> bytes:
serialized_request = {"data": serialized_request, "parameters": request.parameters}
if request.span is not None:
serialized_request["span"] = get_span_dict(request.span)
if request.requested_output_names is not None:
serialized_request["requested_output_names"] = request.requested_output_names
requests_list.append(serialized_request)

requests = {"requests": requests_list}
Expand All @@ -1015,6 +1017,8 @@ def deserialize_requests(self, requests_payload: bytes) -> Requests:
span_dict = request["span"]
span = start_span_from_remote(span_dict, "proxy_inference_callable")
kwargs["span"] = span
if "requested_output_names" in request:
kwargs["requested_output_names"] = request["requested_output_names"]
request_data = {
input_name: self._tensor_store.get(tensor_id)
for input_name, tensor_id in request.get("data", {}).items()
Expand Down
7 changes: 6 additions & 1 deletion pytriton/proxy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ def _wrap_request(self, triton_request, inputs, span=None) -> Request:
kwargs = {}
if span is not None:
kwargs["span"] = span
return Request(data=request, parameters=json.loads(triton_request.parameters()), **kwargs)
return Request(
data=request,
parameters=json.loads(triton_request.parameters()),
requested_output_names=list(triton_request.requested_output_names()),
**kwargs,
)

async def _send_requests(self, requests_id: bytes, triton_requests, spans=None) -> ConcurrentFuture:
requests = triton_requests
Expand Down
2 changes: 2 additions & 0 deletions pytriton/proxy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class Request:
"""Parameters for the request."""
span: Optional[Any] = None
"""Telemetry span for request"""
requested_output_names: Optional[List[str]] = None
"""Requested output names for the request."""

def __getitem__(self, input_name: str) -> np.ndarray:
"""Get input data."""
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_model_proxy_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class InferenceRequest:
# Build the fake inference request used by these unit tests: it keeps the model
# name, the inputs, the requested output names, and the parameters (an empty
# dict is substituted when `parameters` is None or otherwise falsy).
# NOTE(review): both a public `requested_output_names` attribute and a private
# `_requested_output_names` attribute are assigned below. This text comes from a
# diff view, so one of the two is presumably the removed line — confirm against
# the actual file, since a public instance attribute of that name would shadow
# the `requested_output_names()` method.
def __init__(self, model_name, inputs, requested_output_names, parameters=None):
self.model_name = model_name
self._inputs = inputs
self.requested_output_names = requested_output_names
self._requested_output_names = requested_output_names
self._parameters = parameters or {}

def inputs(self):
Expand All @@ -66,6 +66,9 @@ def parameters(self):
# Stub accessor on the fake request: it takes no arguments and always returns None.
def get_response_sender(self):
return None

# Return the requested output names stored on the instance as
# `_requested_output_names`. Exposed as a method (not an attribute) because the
# proxy code calls `triton_request.requested_output_names()`.
def requested_output_names(self):
return self._requested_output_names


def _error_infer_fn(*_, **__):
# Wrapper raises division by zero error
Expand Down

0 comments on commit d0cc359

Please sign in to comment.