Skip to content

Commit

Permalink
Results validator extracted as class
Browse files Browse the repository at this point in the history
  • Loading branch information
pziecina-nv committed Nov 17, 2023
1 parent f828f5d commit 41133b1
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 43 deletions.
4 changes: 3 additions & 1 deletion pytriton/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytriton.model_config.tensor import Tensor
from pytriton.model_config.triton_model_config import DeviceKind, ResponseCache, TensorSpec, TritonModelConfig
from pytriton.proxy.inference_handler import InferenceHandler, InferenceHandlerEvent
from pytriton.proxy.validators import TritonResultsValidator
from pytriton.utils.workspace import Workspace

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -192,6 +193,7 @@ def setup(self) -> None:
with self._inference_handlers_lock:
if not self._inference_handlers:
triton_model_config = self._get_triton_model_config()
validator = TritonResultsValidator(triton_model_config, self._strict)
for i, infer_function in enumerate(self.infer_functions):
self.triton_context.model_configs[infer_function] = copy.deepcopy(triton_model_config)
_inject_triton_context(self.triton_context, infer_function)
Expand All @@ -201,7 +203,7 @@ def setup(self) -> None:
shared_memory_socket=f"{self._shared_memory_socket}_{i}",
data_store_socket=self._data_store_socket.as_posix(),
zmq_context=self.zmq_context,
strict=self._strict,
validator=validator,
)
inference_handler.on_proxy_backend_event(self._on_proxy_backend_event)
inference_handler.start()
Expand Down
17 changes: 5 additions & 12 deletions pytriton/proxy/inference_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
TensorStore,
)
from pytriton.proxy.types import Request
from pytriton.proxy.validators import validate_outputs
from pytriton.proxy.validators import TritonResultsValidator

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
shared_memory_socket: str,
data_store_socket: str,
zmq_context: zmq.Context,
strict: bool,
validator: TritonResultsValidator,
):
"""Create a PythonBackend object.
Expand All @@ -117,13 +117,12 @@ def __init__(
shared_memory_socket: Socket path for shared memory communication
data_store_socket: Socket path for data store communication
zmq_context: zero mq context
strict: Enable strict validation for model callable outputs
validator: Result validator instance
"""
super().__init__()
self._model_config = model_config
self._model_callable = model_callable
self._model_outputs = {output.name: output for output in model_config.outputs}
self._strict = strict
self._validator = validator
self.stopped = False

self._tensor_store = TensorStore(data_store_socket)
Expand Down Expand Up @@ -167,13 +166,7 @@ def run(self) -> None:
responses_iterator = _ResponsesIterator(responses, decoupled=self._model_config.decoupled)
for responses in responses_iterator:
LOGGER.debug(f"Validating outputs for {self._model_config.model_name}.")
validate_outputs(
model_config=self._model_config,
model_outputs=self._model_outputs,
outputs=responses,
strict=self._strict,
requests_number=len(requests),
)
self._validator.validate_responses(inputs, responses)
LOGGER.debug(f"Copying outputs to shared memory for {model_name}.")
output_arrays_with_coords = [
(response_idx, output_name, tensor)
Expand Down
8 changes: 7 additions & 1 deletion pytriton/proxy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Common data structures and type used by proxy model and inference handler."""

import dataclasses
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -59,6 +59,9 @@ def values(self):
return self.data.values()


Requests = List[Request]


@dataclasses.dataclass
class Response:
"""Data class for response data including numpy array outputs."""
Expand Down Expand Up @@ -96,3 +99,6 @@ def keys(self):
def values(self):
"""Iterate over output data."""
return self.data.values()


Responses = List[Response]
40 changes: 35 additions & 5 deletions pytriton/proxy/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,40 @@

import numpy as np

from pytriton.proxy.types import Requests, Responses

LOGGER = logging.getLogger(__name__)


def validate_outputs(model_config, model_outputs, outputs, strict: bool, requests_number: int):
class TritonResultsValidator:
"""Validate results returned by inference callable against PyTriton and Triton requirements."""

def __init__(self, model_config, strict: bool):
"""Validate results returned by inference callable against PyTriton and Triton requirements.
Args:
model_config: Model configuration on Triton side
strict: Enable/disable strict validation against model config
"""
self._model_config = model_config
self._model_outputs = {output.name: output for output in model_config.outputs}
self._strict = strict

def validate_responses(self, requests: Requests, responses: Responses):
"""Validate responses returned by inference callable against PyTriton and Triton requirements.
Args:
requests: Requests received from Triton
responses: Responses returned by inference callable
Raises:
ValueError if responses are incorrect
"""
requests_number = len(requests)
_validate_outputs(self._model_config, self._model_outputs, responses, self._strict, requests_number)


def _validate_outputs(model_config, model_outputs, outputs, strict: bool, requests_number: int):
"""Validate outputs of model.
Args:
Expand Down Expand Up @@ -53,12 +83,12 @@ def validate_outputs(model_config, model_outputs, outputs, strict: bool, request
)
for name, value in response.items():
LOGGER.debug(f"{name}: {value}")
validate_output_data(model_config, name, value)
_validate_output_data(model_config, name, value)
if strict:
validate_output_dtype_and_shape(model_config, model_outputs, name, value)
_validate_output_dtype_and_shape(model_config, model_outputs, name, value)


def validate_output_data(model_config, name, value):
def _validate_output_data(model_config, name, value):
"""Validate output with given name and value.
Args:
Expand Down Expand Up @@ -96,7 +126,7 @@ def validate_output_data(model_config, name, value):
)


def validate_output_dtype_and_shape(model_config, model_outputs, name, value):
def _validate_output_dtype_and_shape(model_config, model_outputs, name, value):
"""Validate output with given name and value against the model config.
Args:
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_model_proxy_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytriton.model_config.generator import ModelConfigGenerator
from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig
from pytriton.proxy.communication import TensorStore
from pytriton.proxy.validators import TritonResultsValidator
from pytriton.triton import TRITONSERVER_DIST_DIR

LOGGER = logging.getLogger("tests.test_model_error_handling")
Expand Down Expand Up @@ -155,13 +156,14 @@ def test_model_throws_exception(tmp_path, mocker, infer_fn, decoupled):

backend_model = _get_proxy_backend(mocker, model_config, shared_memory_socket, data_store_socket)

validator = TritonResultsValidator(model_config, strict=False)
inference_handler = InferenceHandler(
infer_fn,
model_config,
shared_memory_socket=shared_memory_socket,
data_store_socket=data_store_socket,
zmq_context=zmq_context,
strict=False,
validator=validator,
)
inference_handler.start()

Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_proxy_inference_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pytriton.proxy.communication import InferenceHandlerRequests, MetaRequestResponse, TensorStore
from pytriton.proxy.inference_handler import InferenceHandler, _ResponsesIterator
from pytriton.proxy.types import Request
from pytriton.proxy.validators import TritonResultsValidator
from tests.unit.utils import verify_equalness_of_dicts_with_ndarray

LOGGER = logging.getLogger("tests.unit.test_proxy_inference_handler")
Expand Down Expand Up @@ -147,16 +148,18 @@ def test_proxy_throws_exception_when_validate_outputs_raise_an_error(tmp_path, m
try:
tensor_store.start() # start tensor store side process - this way InferenceHandler will create client for it
mocker.patch(
"pytriton.proxy.inference_handler.validate_outputs", side_effect=ValueError("Validate outputs error.")
"pytriton.proxy.validators.TritonResultsValidator.validate_responses",
side_effect=ValueError("Validate outputs error."),
)
zmq_context = zmq.Context()
validator = TritonResultsValidator(triton_model_config, strict=False)
inference_handler = InferenceHandler(
infer_fn,
triton_model_config,
shared_memory_socket=f"ipc://{tmp_path}/my",
data_store_socket=data_store_socket,
zmq_context=zmq_context,
strict=False,
validator=validator,
)

mock_recv = mocker.patch.object(inference_handler.zmq_context._socket_class, "recv")
Expand Down
42 changes: 21 additions & 21 deletions tests/unit/test_proxy_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig
from pytriton.proxy.validators import validate_output_data, validate_output_dtype_and_shape, validate_outputs
from pytriton.proxy.validators import _validate_output_data, _validate_output_dtype_and_shape, _validate_outputs

LOGGER = logging.getLogger("tests.unit.test_proxy_validators")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
Expand All @@ -45,7 +45,7 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list():
ValueError,
match=r"Outputs returned by `Foo` model callable must be list of response dicts with numpy arrays",
):
validate_outputs(
_validate_outputs(
model_config=MY_MODEL_CONFIG,
model_outputs=MY_MODEL_OUTPUTS,
outputs=outputs,
Expand All @@ -62,7 +62,7 @@ def test_validate_outputs_throws_exception_when_outputs_number_is_not_equal_to_r
match=r"Number of outputs returned by `Foo` inference callable "
r"\(1\) does not match number of requests \(2\) received from Triton\.",
):
validate_outputs(
_validate_outputs(
model_config=MY_MODEL_CONFIG,
model_outputs=MY_MODEL_OUTPUTS,
outputs=outputs,
Expand All @@ -78,7 +78,7 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list_of_dicts()
ValueError,
match=r"Outputs returned by `Foo` model callable must be list of response dicts with numpy arrays",
):
validate_outputs(
_validate_outputs(
model_config=MY_MODEL_CONFIG,
model_outputs=MY_MODEL_OUTPUTS,
outputs=outputs,
Expand All @@ -89,10 +89,10 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list_of_dicts()

def test_validate_outputs_call_validate_outputs_data_if_strict_is_false(mocker):
outputs = [{"output1": np.array([1, 2, 3]), "output2": np.array([1, 2, 3])}]
mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators.validate_output_data")
mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators.validate_output_dtype_and_shape")
mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators._validate_output_data")
mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators._validate_output_dtype_and_shape")

validate_outputs(
_validate_outputs(
model_config=MY_MODEL_CONFIG,
model_outputs=MY_MODEL_OUTPUTS,
outputs=outputs,
Expand All @@ -106,10 +106,10 @@ def test_validate_outputs_call_validate_outputs_data_if_strict_is_false(mocker):

def test_validate_outputs_call_validate_outputs_data_if_strict_is_true(mocker):
outputs = [{"output1": np.array([1, 2, 3]), "output2": np.array([1, 2, 3])}]
mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators.validate_output_data")
mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators.validate_output_dtype_and_shape")
mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators._validate_output_data")
mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators._validate_output_dtype_and_shape")

validate_outputs(
_validate_outputs(
model_config=MY_MODEL_CONFIG,
model_outputs=MY_MODEL_OUTPUTS,
outputs=outputs,
Expand All @@ -129,7 +129,7 @@ def test_validate_output_data_throws_exception_when_name_is_not_a_string():
ValueError,
match=r"Not all keys returned by `Foo` model callable are string",
):
validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)
_validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)


def test_validate_output_data_throws_exception_when_value_is_not_numpy_array():
Expand All @@ -140,7 +140,7 @@ def test_validate_output_data_throws_exception_when_value_is_not_numpy_array():
ValueError,
match=r"Not all values returned by `Foo` model callable are numpy arrays",
):
validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)
_validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)


def test_validate_output_data_throws_exception_when_value_is_not_supported_data_type():
Expand All @@ -154,7 +154,7 @@ def test_validate_output_data_throws_exception_when_value_is_not_supported_data_
"Returned `output1` for model `Foo` "
r"has `M` dtype\.kind\.",
):
validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)
_validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)


def test_validate_output_data_throws_exception_when_value_is_list_of_strings():
Expand All @@ -165,7 +165,7 @@ def test_validate_output_data_throws_exception_when_value_is_list_of_strings():
ValueError,
match=r"Use string/byte-string instead of object for passing string in NumPy array from model `Foo`\.",
):
validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)
_validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)


def test_validate_output_data_throws_exception_when_value_is_list_of_ints_defined_as_object():
Expand All @@ -178,7 +178,7 @@ def test_validate_output_data_throws_exception_when_value_is_list_of_ints_define
"Returned `output1` from `Foo` "
r"has `\<class 'int'\>` type\.",
):
validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)
_validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value)


def test_validate_output_dtype_and_shape_throws_exception_when_name_not_in_model_config():
Expand All @@ -189,7 +189,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_name_not_in_model
ValueError,
match=r"Returned output `output3` is not defined in model config for model `Foo`\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

Expand All @@ -202,7 +202,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre
ValueError,
match=r"Returned output `output1` for model `Foo` has invalid type\. Returned: float64 \(f\). Expected: \<class 'numpy\.int32'\>\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

Expand All @@ -215,7 +215,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre
ValueError,
match=r"Returned output `output1` for model `Foo` has invalid type\. Returned: \|S5 \(S\). Expected: \<class 'numpy\.int32'\>\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

Expand All @@ -228,7 +228,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre
ValueError,
match=r"Returned output `output1` for model `Foo` has invalid shapes\. Returned: \(2, 1\)\. Expected: \(3,\)\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

Expand All @@ -241,7 +241,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_contains_to
ValueError,
match=r"Returned output `output1` for model `Foo` has invalid shapes at one or more positions\. Returned: \(2,\)\. Expected: \(3,\)\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

Expand All @@ -254,6 +254,6 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_contains_to
ValueError,
match=r"Returned output `output2` for model `Foo` has invalid shapes at one or more positions\. Returned: \(4,\)\. Expected: \(3,\)\.",
):
validate_output_dtype_and_shape(
_validate_output_dtype_and_shape(
model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value
)

0 comments on commit 41133b1

Please sign in to comment.