From f071bc5b3defc8e3fe7f9b6b2f0731ea80c2fab9 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Wed, 30 Aug 2023 19:57:25 +0900
Subject: [PATCH] Support `parameters` on config.pbtxt

---
 model_repository/sample/1/model.py            |  5 ++-
 model_repository/sample/config.pbtxt          |  8 +++++
 .../sample_autobatching/config.pbtxt          |  8 +++++
 model_repository/sample_multiple/config.pbtxt |  8 +++++
 tests/test_model_call.py                      | 30 ++++++++++------
 tritony/tools.py                              | 34 +++++++++++++++----
 tritony/version.py                            |  2 +-
 7 files changed, 75 insertions(+), 20 deletions(-)

diff --git a/model_repository/sample/1/model.py b/model_repository/sample/1/model.py
index ff8803a..1618bb2 100644
--- a/model_repository/sample/1/model.py
+++ b/model_repository/sample/1/model.py
@@ -13,10 +13,13 @@ def initialize(self, args):
             pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
         ]
+        parameters = self.model_config["parameters"]
+
     def execute(self, requests):
         responses = [None for _ in requests]
         for idx, request in enumerate(requests):
-            in_tensor = [item.as_numpy() for item in request.inputs()]
+            current_add_value = int(json.loads(request.parameters()).get("add", 0))
+            in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()]
             out_tensor = [
                 pb_utils.Tensor(output_name, x.astype(output_dtype))
                 for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
             ]
diff --git a/model_repository/sample/config.pbtxt b/model_repository/sample/config.pbtxt
index cb6537d..60b403d 100644
--- a/model_repository/sample/config.pbtxt
+++ b/model_repository/sample/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample"
 backend: "python"
 max_batch_size: 0
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in"
diff --git a/model_repository/sample_autobatching/config.pbtxt b/model_repository/sample_autobatching/config.pbtxt
index 14393b2..0d67899 100644
--- a/model_repository/sample_autobatching/config.pbtxt
+++ b/model_repository/sample_autobatching/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample_autobatching"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in"
diff --git a/model_repository/sample_multiple/config.pbtxt b/model_repository/sample_multiple/config.pbtxt
index 7a880cf..8a6b357 100644
--- a/model_repository/sample_multiple/config.pbtxt
+++ b/model_repository/sample_multiple/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample_multiple"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
+
 input [
   {
     name: "model_in0"
diff --git a/tests/test_model_call.py b/tests/test_model_call.py
index 314b9d1..99b8f30 100644
--- a/tests/test_model_call.py
+++ b/tests/test_model_call.py
@@ -16,28 +16,36 @@ def protocol_and_port(request):
     return request.param
 
 
-def test_swithcing(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
+def get_client(protocol, port):
+    print(f"Testing {protocol}", flush=True)
+    return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+
 
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+def test_swithcing(protocol_and_port):
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(1, 100).astype(np.float32)
     result = client(sample)
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result, sample).all()
 
     sample_batched = np.random.rand(100, 100).astype(np.float32)
     client(sample_batched, model_name="sample_autobatching")
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result, sample).all()
 
 
 def test_with_input_name(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
-
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(100, 100).astype(np.float32)
     result = client({client.default_model_spec.model_input[0].name: sample})
+    assert np.isclose(result, sample).all()
+
+
+def test_with_parameters(protocol_and_port):
+    client = get_client(*protocol_and_port)
+
+    sample = np.random.rand(1, 100).astype(np.float32)
+    ADD_VALUE = 1
+    result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"})
 
-    print(f"Result: {np.isclose(result, sample).all()}")
+    assert np.isclose(result[0], sample[0] + ADD_VALUE).all()
diff --git a/tritony/tools.py b/tritony/tools.py
index 135a358..1581c58 100644
--- a/tritony/tools.py
+++ b/tritony/tools.py
@@ -72,6 +72,7 @@ async def send_request_async(
     done_event,
     triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
     model_spec: TritonModelSpec,
+    parameters: dict | None = None,
 ):
     ret = []
     while True:
@@ -86,7 +87,7 @@ async def send_request_async(
         try:
             a_pred = await request_async(
                 inference_client.flag.protocol,
-                inference_client.build_triton_input(batch_data, model_spec),
+                inference_client.build_triton_input(batch_data, model_spec, parameters=parameters),
                 triton_client,
                 timeout=inference_client.client_timeout,
                 compression=inference_client.flag.compression_algorithm,
@@ -232,6 +233,7 @@ def _get_request_id(self):
     def __call__(
         self,
         sequences_or_dict: Union[List[Any], Dict[str, List[Any]]],
+        parameters: dict | None = None,
         model_name: str | None = None,
         model_version: str | None = None,
     ):
@@ -254,9 +256,14 @@ def __call__(
             or (model_input.optional is True and model_input.name in sequences_or_dict)  # check optional
         ]
 
-        return self._call_async(sequences_list, model_spec=model_spec)
+        return self._call_async(sequences_list, model_spec=model_spec, parameters=parameters)
 
-    def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec):
+    def build_triton_input(
+        self,
+        _input_list: List[np.array],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         if self.flag.protocol is TritonProtocol.grpc:
             client = grpcclient
         else:
@@ -278,19 +285,30 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode
             request_id=str(request_id),
             model_version=model_spec.model_version,
             outputs=infer_requested_output,
+            parameters=parameters,
         )
 
         return request_input
 
-    def _call_async(self, data: List[np.ndarray], model_spec: TritonModelSpec) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec))
+    def _call_async(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ) -> Optional[np.ndarray]:
+        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
 
         if isinstance(async_result, Exception):
             raise async_result
 
         return async_result
 
-    async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec):
+    async def _call_async_item(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         current_grpc_async_tasks = []
 
         try:
@@ -301,7 +319,9 @@ async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModel
             current_grpc_async_tasks.append(generator)
 
             predict_tasks = [
-                asyncio.create_task(send_request_async(self, data_queue, done_event, self.triton_client, model_spec))
+                asyncio.create_task(
+                    send_request_async(self, data_queue, done_event, self.triton_client, model_spec, parameters)
+                )
                 for idx in range(ASYNC_TASKS)
             ]
             current_grpc_async_tasks.extend(predict_tasks)
diff --git a/tritony/version.py b/tritony/version.py
index b2f0155..72eb129 100644
--- a/tritony/version.py
+++ b/tritony/version.py
@@ -1 +1 @@
-__version__ = "0.0.11"
+__version__ = "0.0.12rc0"
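
Usage sketch (illustrative, not part of the patch): with this change, per-request
parameters can be passed through `InferenceClient.__call__`, are attached to the Triton
inference request in `build_triton_input`, and are read inside the Python backend model
via `request.parameters()`. The server address and protocol below are assumptions for
the example; the call shape mirrors `test_with_parameters` above.

    import numpy as np
    from tritony import InferenceClient

    # Assumed local Triton server exposing gRPC on port 8001.
    client = InferenceClient.create_with("sample", "localhost:8001", protocol="grpc")

    sample = np.random.rand(1, 100).astype(np.float32)
    input_name = client.default_model_spec.model_input[0].name

    # "add" matches the key declared under `parameters` in config.pbtxt;
    # Triton parameter values travel as strings, so the amount is sent as "1".
    result = client({input_name: sample}, parameters={"add": "1"})
    assert np.isclose(result[0], sample[0] + 1).all()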