From 41133b10a6a2d18a3f17fa2a641487711ca0a1dd Mon Sep 17 00:00:00 2001 From: Pawel Ziecina Date: Fri, 17 Nov 2023 02:37:37 -0800 Subject: [PATCH] Results validator extracted as class --- pytriton/models/model.py | 4 +- pytriton/proxy/inference_handler.py | 17 +++----- pytriton/proxy/types.py | 8 +++- pytriton/proxy/validators.py | 40 ++++++++++++++++--- tests/unit/test_model_proxy_communication.py | 4 +- tests/unit/test_proxy_inference_handler.py | 7 +++- tests/unit/test_proxy_validators.py | 42 ++++++++++---------- 7 files changed, 79 insertions(+), 43 deletions(-) diff --git a/pytriton/models/model.py b/pytriton/models/model.py index 365200d..0d40241 100644 --- a/pytriton/models/model.py +++ b/pytriton/models/model.py @@ -33,6 +33,7 @@ from pytriton.model_config.tensor import Tensor from pytriton.model_config.triton_model_config import DeviceKind, ResponseCache, TensorSpec, TritonModelConfig from pytriton.proxy.inference_handler import InferenceHandler, InferenceHandlerEvent +from pytriton.proxy.validators import TritonResultsValidator from pytriton.utils.workspace import Workspace LOGGER = logging.getLogger(__name__) @@ -192,6 +193,7 @@ def setup(self) -> None: with self._inference_handlers_lock: if not self._inference_handlers: triton_model_config = self._get_triton_model_config() + validator = TritonResultsValidator(triton_model_config, self._strict) for i, infer_function in enumerate(self.infer_functions): self.triton_context.model_configs[infer_function] = copy.deepcopy(triton_model_config) _inject_triton_context(self.triton_context, infer_function) @@ -201,7 +203,7 @@ def setup(self) -> None: shared_memory_socket=f"{self._shared_memory_socket}_{i}", data_store_socket=self._data_store_socket.as_posix(), zmq_context=self.zmq_context, - strict=self._strict, + validator=validator, ) inference_handler.on_proxy_backend_event(self._on_proxy_backend_event) inference_handler.start() diff --git a/pytriton/proxy/inference_handler.py b/pytriton/proxy/inference_handler.py index db60239..2af5b7a 100644 --- a/pytriton/proxy/inference_handler.py +++ b/pytriton/proxy/inference_handler.py @@ -47,7 +47,7 @@ TensorStore, ) from pytriton.proxy.types import Request -from pytriton.proxy.validators import validate_outputs +from pytriton.proxy.validators import TritonResultsValidator LOGGER = logging.getLogger(__name__) @@ -107,7 +107,7 @@ def __init__( shared_memory_socket: str, data_store_socket: str, zmq_context: zmq.Context, - strict: bool, + validator: TritonResultsValidator, ): """Create a PythonBackend object. @@ -117,13 +117,12 @@ def __init__( shared_memory_socket: Socket path for shared memory communication data_store_socket: Socket path for data store communication zmq_context: zero mq context - strict: Enable strict validation for model callable outputs + validator: Result validator instance """ super().__init__() self._model_config = model_config self._model_callable = model_callable - self._model_outputs = {output.name: output for output in model_config.outputs} - self._strict = strict + self._validator = validator self.stopped = False self._tensor_store = TensorStore(data_store_socket) @@ -167,13 +166,7 @@ def run(self) -> None: responses_iterator = _ResponsesIterator(responses, decoupled=self._model_config.decoupled) for responses in responses_iterator: LOGGER.debug(f"Validating outputs for {self._model_config.model_name}.") - validate_outputs( - model_config=self._model_config, - model_outputs=self._model_outputs, - outputs=responses, - strict=self._strict, - requests_number=len(requests), - ) + self._validator.validate_responses(inputs, responses) LOGGER.debug(f"Copying outputs to shared memory for {model_name}.") output_arrays_with_coords = [ (response_idx, output_name, tensor) diff --git a/pytriton/proxy/types.py b/pytriton/proxy/types.py index 0ec0cf2..f68449f 100644 --- a/pytriton/proxy/types.py +++ b/pytriton/proxy/types.py @@ -14,7 +14,7 @@ """Common data structures and type used by proxy model and inference handler.""" import dataclasses -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -59,6 +59,9 @@ def values(self): return self.data.values() +Requests = List[Request] + + @dataclasses.dataclass class Response: """Data class for response data including numpy array outputs.""" @@ -96,3 +99,6 @@ def keys(self): def values(self): """Iterate over output data.""" return self.data.values() + + +Responses = List[Response] diff --git a/pytriton/proxy/validators.py b/pytriton/proxy/validators.py index 18666d2..de63789 100644 --- a/pytriton/proxy/validators.py +++ b/pytriton/proxy/validators.py @@ -16,10 +16,40 @@ import numpy as np +from pytriton.proxy.types import Requests, Responses + LOGGER = logging.getLogger(__name__) -def validate_outputs(model_config, model_outputs, outputs, strict: bool, requests_number: int): +class TritonResultsValidator: + """Validate results returned by inference callable against PyTriton and Triton requirements.""" + + def __init__(self, model_config, strict: bool): + """Validate results returned by inference callable against PyTriton and Triton requirements. + + Args: + model_config: Model configuration on Triton side + strict: Enable/disable strict validation against model config + """ + self._model_config = model_config + self._model_outputs = {output.name: output for output in model_config.outputs} + self._strict = strict + + def validate_responses(self, requests: Requests, responses: Responses): + """Validate responses returned by inference callable against PyTriton and Triton requirements. + + Args: + requests: Requests received from Triton + responses: Responses returned by inference callable + + Raises: + ValueError if responses are incorrect + """ + requests_number = len(requests) + _validate_outputs(self._model_config, self._model_outputs, responses, self._strict, requests_number) + + +def _validate_outputs(model_config, model_outputs, outputs, strict: bool, requests_number: int): """Validate outputs of model. Args: @@ -53,12 +83,12 @@ def validate_outputs(model_config, model_outputs, outputs, strict: bool, request ) for name, value in response.items(): LOGGER.debug(f"{name}: {value}") - validate_output_data(model_config, name, value) + _validate_output_data(model_config, name, value) if strict: - validate_output_dtype_and_shape(model_config, model_outputs, name, value) + _validate_output_dtype_and_shape(model_config, model_outputs, name, value) -def validate_output_data(model_config, name, value): +def _validate_output_data(model_config, name, value): """Validate output with given name and value. Args: @@ -96,7 +126,7 @@ def validate_output_data(model_config, name, value): ) -def validate_output_dtype_and_shape(model_config, model_outputs, name, value): +def _validate_output_dtype_and_shape(model_config, model_outputs, name, value): """Validate output with given name and value against the model config. Args: diff --git a/tests/unit/test_model_proxy_communication.py b/tests/unit/test_model_proxy_communication.py index 31186ce..c99137c 100644 --- a/tests/unit/test_model_proxy_communication.py +++ b/tests/unit/test_model_proxy_communication.py @@ -29,6 +29,7 @@ from pytriton.model_config.generator import ModelConfigGenerator from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig from pytriton.proxy.communication import TensorStore +from pytriton.proxy.validators import TritonResultsValidator from pytriton.triton import TRITONSERVER_DIST_DIR LOGGER = logging.getLogger("tests.test_model_error_handling") @@ -155,13 +156,14 @@ def test_model_throws_exception(tmp_path, mocker, infer_fn, decoupled): backend_model = _get_proxy_backend(mocker, model_config, shared_memory_socket, data_store_socket) + validator = TritonResultsValidator(model_config, strict=False) inference_handler = InferenceHandler( infer_fn, model_config, shared_memory_socket=shared_memory_socket, data_store_socket=data_store_socket, zmq_context=zmq_context, - strict=False, + validator=validator, ) inference_handler.start() diff --git a/tests/unit/test_proxy_inference_handler.py b/tests/unit/test_proxy_inference_handler.py index 97966c8..fee81f8 100644 --- a/tests/unit/test_proxy_inference_handler.py +++ b/tests/unit/test_proxy_inference_handler.py @@ -26,6 +26,7 @@ from pytriton.proxy.communication import InferenceHandlerRequests, MetaRequestResponse, TensorStore from pytriton.proxy.inference_handler import InferenceHandler, _ResponsesIterator from pytriton.proxy.types import Request +from pytriton.proxy.validators import TritonResultsValidator from tests.unit.utils import verify_equalness_of_dicts_with_ndarray LOGGER = logging.getLogger("tests.unit.test_proxy_inference_handler") @@ -147,16 +148,18 @@ def test_proxy_throws_exception_when_validate_outputs_raise_an_error(tmp_path, m try: tensor_store.start() # start tensor store side process - this way InferenceHandler will create client for it mocker.patch( - "pytriton.proxy.inference_handler.validate_outputs", side_effect=ValueError("Validate outputs error.") + "pytriton.proxy.validators.TritonResultsValidator.validate_responses", + side_effect=ValueError("Validate outputs error."), ) zmq_context = zmq.Context() + validator = TritonResultsValidator(triton_model_config, strict=False) inference_handler = InferenceHandler( infer_fn, triton_model_config, shared_memory_socket=f"ipc://{tmp_path}/my", data_store_socket=data_store_socket, zmq_context=zmq_context, - strict=False, + validator=validator, ) mock_recv = mocker.patch.object(inference_handler.zmq_context._socket_class, "recv") diff --git a/tests/unit/test_proxy_validators.py b/tests/unit/test_proxy_validators.py index 9b7526a..6e4328d 100644 --- a/tests/unit/test_proxy_validators.py +++ b/tests/unit/test_proxy_validators.py @@ -18,7 +18,7 @@ import pytest from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig -from pytriton.proxy.validators import validate_output_data, validate_output_dtype_and_shape, validate_outputs +from pytriton.proxy.validators import _validate_output_data, _validate_output_dtype_and_shape, _validate_outputs LOGGER = logging.getLogger("tests.unit.test_proxy_validators") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s") @@ -45,7 +45,7 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list(): ValueError, match=r"Outputs returned by `Foo` model callable must be list of response dicts with numpy arrays", ): - validate_outputs( + _validate_outputs( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, outputs=outputs, @@ -62,7 +62,7 @@ def test_validate_outputs_throws_exception_when_outputs_number_is_not_equal_to_r match=r"Number of outputs returned by `Foo` inference callable " r"\(1\) does not match number of requests \(2\) received from Triton\.", ): - validate_outputs( + _validate_outputs( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, outputs=outputs, @@ -78,7 +78,7 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list_of_dicts() ValueError, match=r"Outputs returned by `Foo` model callable must be list of response dicts with numpy arrays", ): - validate_outputs( + _validate_outputs( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, outputs=outputs, @@ -89,10 +89,10 @@ def test_validate_outputs_throws_exception_when_outputs_is_not_a_list_of_dicts() def test_validate_outputs_call_validate_outputs_data_if_strict_is_false(mocker): outputs = [{"output1": np.array([1, 2, 3]), "output2": np.array([1, 2, 3])}] - mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators.validate_output_data") - mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators.validate_output_dtype_and_shape") + mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators._validate_output_data") + mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators._validate_output_dtype_and_shape") - validate_outputs( + _validate_outputs( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, outputs=outputs, @@ -106,10 +106,10 @@ def test_validate_outputs_call_validate_outputs_data_if_strict_is_false(mocker): def test_validate_outputs_call_validate_outputs_data_if_strict_is_true(mocker): outputs = [{"output1": np.array([1, 2, 3]), "output2": np.array([1, 2, 3])}] - mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators.validate_output_data") - mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators.validate_output_dtype_and_shape") + mock_validate_outputs_data = mocker.patch("pytriton.proxy.validators._validate_output_data") + mock_validate_output_dtype_and_shape = mocker.patch("pytriton.proxy.validators._validate_output_dtype_and_shape") - validate_outputs( + _validate_outputs( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, outputs=outputs, @@ -129,7 +129,7 @@ def test_validate_output_data_throws_exception_when_name_is_not_a_string(): ValueError, match=r"Not all keys returned by `Foo` model callable are string", ): - validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) + _validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) def test_validate_output_data_throws_exception_when_value_is_not_numpy_array(): @@ -140,7 +140,7 @@ def test_validate_output_data_throws_exception_when_value_is_not_numpy_array(): ValueError, match=r"Not all values returned by `Foo` model callable are numpy arrays", ): - validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) + _validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) def test_validate_output_data_throws_exception_when_value_is_not_supported_data_type(): @@ -154,7 +154,7 @@ def test_validate_output_data_throws_exception_when_value_is_not_supported_data_ "Returned `output1` for model `Foo` " r"has `M` dtype\.kind\.", ): - validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) + _validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) def test_validate_output_data_throws_exception_when_value_is_list_of_strings(): @@ -165,7 +165,7 @@ def test_validate_output_data_throws_exception_when_value_is_list_of_strings(): ValueError, match=r"Use string/byte-string instead of object for passing string in NumPy array from model `Foo`\.", ): - validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) + _validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) def test_validate_output_data_throws_exception_when_value_is_list_of_ints_defined_as_object(): @@ -178,7 +178,7 @@ def test_validate_output_data_throws_exception_when_value_is_list_of_ints_define "Returned `output1` from `Foo` " r"has `\` type\.", ): - validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) + _validate_output_data(model_config=MY_MODEL_CONFIG, name=name, value=value) def test_validate_output_dtype_and_shape_throws_exception_when_name_not_in_model_config(): @@ -189,7 +189,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_name_not_in_model ValueError, match=r"Returned output `output3` is not defined in model config for model `Foo`\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value ) @@ -202,7 +202,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre ValueError, match=r"Returned output `output1` for model `Foo` has invalid type\. Returned: float64 \(f\). Expected: \\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value ) @@ -215,7 +215,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre ValueError, match=r"Returned output `output1` for model `Foo` has invalid type\. Returned: \|S5 \(S\). Expected: \\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value ) @@ -228,7 +228,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_has_incorre ValueError, match=r"Returned output `output1` for model `Foo` has invalid shapes\. Returned: \(2, 1\)\. Expected: \(3,\)\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value ) @@ -241,7 +241,7 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_contains_to ValueError, match=r"Returned output `output1` for model `Foo` has invalid shapes at one or more positions\. Returned: \(2,\)\. Expected: \(3,\)\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value ) @@ -254,6 +254,6 @@ def test_validate_output_dtype_and_shape_throws_exception_when_value_contains_to ValueError, match=r"Returned output `output2` for model `Foo` has invalid shapes at one or more positions\. Returned: \(4,\)\. Expected: \(3,\)\.", ): - validate_output_dtype_and_shape( + _validate_output_dtype_and_shape( model_config=MY_MODEL_CONFIG, model_outputs=MY_MODEL_OUTPUTS, name=name, value=value )