diff --git a/docs/user_guide/metrics.md b/docs/user_guide/metrics.md
index b8fc0d8ee0..8a54565222 100644
--- a/docs/user_guide/metrics.md
+++ b/docs/user_guide/metrics.md
@@ -204,6 +204,43 @@ metrics are used for latencies:
 
 To disable these metrics specifically, you can set `--metrics-config counter_latencies=false`
 
+#### Histograms
+
+> **Note**
+>
+> The following histogram feature is experimental and may change based on
+> user feedback.
+
+When enabled, the following
+[Histogram](https://prometheus.io/docs/concepts/metric_types/#histogram)
+metrics are collected for latencies:
+
+|Category |Metric |Metric Name |Description |Granularity|Frequency |Model Type |
+|--------------|----------------|------------|---------------------------|-----------|-------------|-------------|
+|Latency |Request to First Response Time |`nv_inference_first_response_histogram_ms` |Histogram of end-to-end inference request to first-response time |Per model |Per request |Decoupled |
+
+These metrics are disabled by default. To enable them specifically, you can set `--metrics-config histogram_latencies=true`
+
+Each histogram above is composed of several sub-metrics. For each histogram
+metric, there is a set of `le` (less than or equal to) thresholds, one counter
+per bucket. Additionally, there are `_count` and `_sum` metrics that aggregate
+the number and the sum of observed values, respectively. For example, see the
+following series exposed by the "Request to First Response Time" histogram:
+```
+# HELP nv_inference_first_response_histogram_ms Duration from request to first response in milliseconds
+# TYPE nv_inference_first_response_histogram_ms histogram
+nv_inference_first_response_histogram_ms_count{model="my_model",version="1"} 37
+nv_inference_first_response_histogram_ms_sum{model="my_model",version="1"} 10771
+nv_inference_first_response_histogram_ms{model="my_model",version="1",le="100"} 8
+nv_inference_first_response_histogram_ms{model="my_model",version="1",le="500"} 30
+nv_inference_first_response_histogram_ms{model="my_model",version="1",le="2000"} 36
+nv_inference_first_response_histogram_ms{model="my_model",version="1",le="5000"} 37
+nv_inference_first_response_histogram_ms{model="my_model",version="1",le="+Inf"} 37
+```
+
+Triton initializes each histogram with the default buckets shown above. Customizing buckets per metric is currently unsupported.
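+
+Because each histogram exposes `_sum` and `_count` series, an average latency
+can be derived directly from a scrape. The snippet below is a minimal sketch,
+not part of Triton: it assumes the default metrics port (8002) on localhost
+and a hypothetical model named `my_model` that has already served requests.
+
+```python
+import re
+
+import requests
+
+FAMILY = "nv_inference_first_response_histogram_ms"
+text = requests.get("http://localhost:8002/metrics").text
+
+# Pull the _count and _sum series for my_model out of the exposition text.
+pattern = rf'^{FAMILY}_(count|sum)\{{model="my_model",version="1"\}} (.+)$'
+stats = {m[1]: float(m[2]) for m in re.finditer(pattern, text, re.MULTILINE)}
+
+if stats.get("count"):
+    print(f"Average time to first response: {stats['sum'] / stats['count']:.1f} ms")
+```
+
+In Prometheus itself, the same average over a time window is typically computed
+as `rate(..._sum[1m]) / rate(..._count[1m])`.
+
 #### Summaries
 
 > **Note**
diff --git a/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/1/model.py b/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/1/model.py
new file mode 100644
index 0000000000..6a73e12da4
--- /dev/null
+++ b/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/1/model.py
@@ -0,0 +1,59 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.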
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import time
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    async def execute(self, requests):
+        # The test harness sends one request at a time.
+        request = requests[0]
+        wait_secs = pb_utils.get_input_tensor_by_name(
+            request, "WAIT_SECONDS"
+        ).as_numpy()[0]
+        response_num = pb_utils.get_input_tensor_by_name(
+            request, "RESPONSE_NUM"
+        ).as_numpy()[0]
+        output_tensors = [
+            pb_utils.Tensor("WAIT_SECONDS", np.array([wait_secs], np.float32)),
+            pb_utils.Tensor("RESPONSE_NUM", np.array([1], np.uint8)),
+        ]
+
+        # Sleep so the first response arrives after a known, measurable delay.
+        time.sleep(wait_secs.item())
+        response_sender = request.get_response_sender()
+        for i in range(response_num):
+            response = pb_utils.InferenceResponse(output_tensors)
+            if i != response_num - 1:
+                response_sender.send(response)
+            else:
+                # Mark the last response with the FINAL flag to close the stream.
+                response_sender.send(
+                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                )
+
+        return None
diff --git a/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/config.pbtxt b/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/config.pbtxt
new file mode 100644
index 0000000000..cc3eda3324
--- /dev/null
+++ b/qa/L0_metrics/ensemble_decoupled/async_execute_decouple/config.pbtxt
@@ -0,0 +1,54 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
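+
+# Decoupled Python model used by the histogram metrics tests: it waits
+# WAIT_SECONDS and then streams RESPONSE_NUM responses (see 1/model.py).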
+ +backend: "python" +input [ + { + name: "WAIT_SECONDS" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "RESPONSE_NUM" + data_type: TYPE_UINT8 + dims: [ 1 ] + } +] +output [ + { + name: "WAIT_SECONDS" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "RESPONSE_NUM" + data_type: TYPE_UINT8 + dims: [ 1 ] + } +] + +instance_group [{ kind: KIND_CPU }] +model_transaction_policy { decoupled: True } diff --git a/qa/L0_metrics/ensemble_decoupled/ensemble/config.pbtxt b/qa/L0_metrics/ensemble_decoupled/ensemble/config.pbtxt new file mode 100644 index 0000000000..e4f09e8086 --- /dev/null +++ b/qa/L0_metrics/ensemble_decoupled/ensemble/config.pbtxt @@ -0,0 +1,89 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "INPUT1" + data_type: TYPE_UINT8 + dims: [ 1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] +ensemble_scheduling { + step [ + { + # decoupled model + model_name: "async_execute_decouple" + model_version: 1 + input_map { + key: "WAIT_SECONDS" + value: "INPUT0" + } + input_map { + key: "RESPONSE_NUM" + value: "INPUT1" + } + output_map { + key: "WAIT_SECONDS" + value: "temp_output0" + } + output_map { + key: "RESPONSE_NUM" + value: "temp_output1" + } + }, + { + # non-decoupled model + model_name: "async_execute" + model_version: 1 + input_map { + key: "WAIT_SECONDS" + value: "temp_output0" + } + input_map { + key: "RESPONSE_NUM" + value: "temp_output1" + } + output_map { + key: "WAIT_SECONDS" + value: "OUTPUT" + } + } + ] +} diff --git a/qa/L0_metrics/histogram_metrics_test.py b/qa/L0_metrics/histogram_metrics_test.py new file mode 100755 index 0000000000..7480e2048b --- /dev/null +++ b/qa/L0_metrics/histogram_metrics_test.py @@ -0,0 +1,178 @@ +#!/usr/bin/python +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import re
+import sys
+import unittest
+from functools import partial
+
+import numpy as np
+import requests
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import InferenceServerException
+
+sys.path.append("../common")
+import test_util as tu
+
+MILLIS_PER_SEC = 1000
+
+
+def get_histogram_metric_key(
+    metric_family, model_name, model_version, metric_type, le=""
+):
+    # Only the aggregate _count and _sum series are keyed here; bucket series
+    # (identified by an "le" label) are not used by these tests.
+    if metric_type in ["count", "sum"]:
+        return f'{metric_family}_{metric_type}{{model="{model_name}",version="{model_version}"}}'
+    else:
+        return None
+
+
+class TestHistogramMetrics(tu.TestResultCollector):
+    def setUp(self):
+        self.tritonserver_ipaddr = os.environ.get("TRITONSERVER_IPADDR", "localhost")
+
+    def get_histogram_metrics(self, metric_family: str):
+        r = requests.get(f"http://{self.tritonserver_ipaddr}:8002/metrics")
+        r.raise_for_status()
+
+        # Match every exposition line that belongs to this metric family.
+        pattern = f"^{metric_family}.*"
+        histogram_dict = {}
+
+        matches = re.findall(pattern, r.text, re.MULTILINE)
+
+        for match in matches:
+            # Each line has the form '<name>{<labels>} <value>'; split on the
+            # last space only, and parse as float since _sum values may be
+            # fractional.
+            key, value = match.rsplit(" ", 1)
+            histogram_dict[key] = float(value)
+
+        return histogram_dict
+
+    def async_stream_infer(self, model_name, inputs, outputs, responses_per_req):
+        with grpcclient.InferenceServerClient(url="localhost:8001") as triton_client:
+            # Define the callback function. Note that the last two parameters
+            # must be result and error. InferenceServerClient provides the
+            # result of an inference as grpcclient.InferResult. For a
+            # successful inference, error will be None; otherwise it will be a
+            # tritonclient.utils.InferenceServerException holding the error
+            # details.
+            def callback(user_data, result, error):
+                if error:
+                    user_data.append(error)
+                else:
+                    user_data.append(result)
+
+            # List to hold the results of inference.
+            user_data = []
+
+            # Inference call
+            triton_client.start_stream(callback=partial(callback, user_data))
+            triton_client.async_stream_infer(
+                model_name=model_name,
+                inputs=inputs,
+                outputs=outputs,
+            )
+
+            # Close the stream so all responses have been delivered before
+            # they are checked.
+            triton_client.stop_stream()
+
+            self.assertEqual(len(user_data), responses_per_req)
+            # Validate the results
+            for i in range(len(user_data)):
+                # Check for errors
+                self.assertNotIsInstance(
+                    user_data[i], InferenceServerException, user_data[i]
+                )
+
+    def test_ensemble_decoupled(self):
+        wait_secs = 1
+        responses_per_req = 3
+        total_reqs = 3
+        delta = 0.2
+
+        # Infer
+        inputs = []
+        outputs = []
+        inputs.append(grpcclient.InferInput("INPUT0", [1], "FP32"))
+        inputs.append(grpcclient.InferInput("INPUT1", [1], "UINT8"))
+        outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))
+
+        # Create the data for the input tensors.
+        input_data_0 = np.array([wait_secs], np.float32)
+        input_data_1 = np.array([responses_per_req], np.uint8)
+
+        # Initialize the data
+        inputs[0].set_data_from_numpy(input_data_0)
+        inputs[1].set_data_from_numpy(input_data_1)
+
+        # Send requests to the decoupled ensemble model one at a time.
+        for request_num in range(1, total_reqs + 1):
+            ensemble_model_name = "ensemble"
+            decoupled_model_name = "async_execute_decouple"
+            non_decoupled_model_name = "async_execute"
+            self.async_stream_infer(
+                ensemble_model_name, inputs, outputs, responses_per_req
+            )
+
+            # Check the metrics output
+            first_response_family = "nv_inference_first_response_histogram_ms"
+            histogram_dict = self.get_histogram_metrics(first_response_family)
+
+            def check_existing_metrics(model_name, wait_secs_per_req, delta):
+                metric_count = get_histogram_metric_key(
+                    first_response_family, model_name, "1", "count"
+                )
+                model_sum = get_histogram_metric_key(
+                    first_response_family, model_name, "1", "sum"
+                )
+                # Test histogram count
+                self.assertIn(metric_count, histogram_dict)
+                self.assertEqual(histogram_dict[metric_count], request_num)
+                # Test histogram sum
+                self.assertIn(model_sum, histogram_dict)
+                self.assertTrue(
+                    wait_secs_per_req * MILLIS_PER_SEC * request_num
+                    <= histogram_dict[model_sum]
+                    < (wait_secs_per_req + delta) * MILLIS_PER_SEC * request_num
+                )
+                # Prometheus histogram buckets are tested in metrics_api_test.cc::HistogramAPIHelper
+
+            # Test ensemble model metrics
+            check_existing_metrics(ensemble_model_name, 2 * wait_secs, 2 * delta)
+
+            # Test decoupled model metrics
+            check_existing_metrics(decoupled_model_name, wait_secs, delta)
+
+            # Test non-decoupled model metrics
+            non_decoupled_model_count = get_histogram_metric_key(
+                first_response_family, non_decoupled_model_name, "1", "count"
+            )
+            non_decoupled_model_sum = get_histogram_metric_key(
+                first_response_family, non_decoupled_model_name, "1", "sum"
+            )
+            self.assertNotIn(non_decoupled_model_count, histogram_dict)
+            self.assertNotIn(non_decoupled_model_sum, histogram_dict)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/qa/L0_metrics/metrics_config_test.py b/qa/L0_metrics/metrics_config_test.py
index 9153366c04..43e5a79ba1 100755
--- a/qa/L0_metrics/metrics_config_test.py
+++ b/qa/L0_metrics/metrics_config_test.py
@@ -44,6 +44,7 @@
     "nv_inference_compute_infer_duration",
     "nv_inference_compute_output_duration",
 ]
+INF_HISTOGRAM_DECOUPLED_PATTERNS = ["nv_inference_first_response_histogram_ms"]
 INF_SUMMARY_PATTERNS = [
     "nv_inference_request_summary",
     "nv_inference_queue_summary",
@@ -97,6 +98,18 @@ def test_cache_counters_missing(self):
         for metric in CACHE_COUNTER_PATTERNS:
             self.assertNotIn(metric, metrics)
 
+    # Histograms
+    def test_inf_histograms_decoupled_exist(self):
+        metrics = self._get_metrics()
+        for metric in INF_HISTOGRAM_DECOUPLED_PATTERNS:
+            for suffix in ["_count", "_sum", ""]:
+                self.assertIn(metric + suffix, metrics)
+
+    def test_inf_histograms_decoupled_missing(self):
+        metrics = self._get_metrics()
+        for metric in INF_HISTOGRAM_DECOUPLED_PATTERNS:
+            self.assertNotIn(metric, metrics)
+
     # Summaries
     def test_inf_summaries_exist(self):
         metrics = self._get_metrics()
diff --git a/qa/L0_metrics/test.sh b/qa/L0_metrics/test.sh
index 76e99e7c48..f6802622a3 100755
--- a/qa/L0_metrics/test.sh
+++ b/qa/L0_metrics/test.sh
@@ -270,11 +270,63 @@ mkdir -p "${MODELDIR}/identity_cache_on/1"
 mkdir -p "${MODELDIR}/identity_cache_off/1"
 BASE_SERVER_ARGS="--model-repository=${MODELDIR} --model-control-mode=explicit"
 
-# Check default settings: Counters should be enabled, summaries should be disabled
+# Check default settings: Counters should be enabled, histograms and summaries should be disabled
 SERVER_ARGS="${BASE_SERVER_ARGS} --load-model=identity_cache_off"
 run_and_check_server
 python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_counters_exist 2>&1 | tee ${CLIENT_LOG}
 check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_histograms_decoupled_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_counters_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+kill_server
+
+# Histograms should always be disabled for non-decoupled models, even when
+# histogram_latencies is explicitly enabled.
+SERVER_ARGS="${BASE_SERVER_ARGS} --load-model=identity_cache_off --metrics-config histogram_latencies=true"
+run_and_check_server
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_counters_exist 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_histograms_decoupled_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_counters_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+kill_server
+
+# Check default settings: Histograms should be disabled by default, even for a
+# decoupled model.
+decoupled_model_name="async_execute_decouple"
+mkdir -p "${MODELDIR}/${decoupled_model_name}/1/"
+cp ../python_models/${decoupled_model_name}/model.py ${MODELDIR}/${decoupled_model_name}/1/
+cp ../python_models/${decoupled_model_name}/config.pbtxt ${MODELDIR}/${decoupled_model_name}/
+
+SERVER_ARGS="${BASE_SERVER_ARGS} --load-model=${decoupled_model_name}"
+run_and_check_server
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_counters_exist 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_histograms_decoupled_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_counters_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_summaries_missing 2>&1 | tee ${CLIENT_LOG}
+check_unit_test
+kill_server
+
+# Enable histograms in decoupled model
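+# Histograms are reported only for decoupled models, so with the flag set and
+# the decoupled model loaded, the first-response histogram should be exposed.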
+SERVER_ARGS="${BASE_SERVER_ARGS} --load-model=${decoupled_model_name} --metrics-config histogram_latencies=true" +run_and_check_server +python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_counters_exist 2>&1 | tee ${CLIENT_LOG} +check_unit_test +python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_histograms_decoupled_exist 2>&1 | tee ${CLIENT_LOG} +check_unit_test python3 ${PYTHON_TEST} MetricsConfigTest.test_inf_summaries_missing 2>&1 | tee ${CLIENT_LOG} check_unit_test python3 ${PYTHON_TEST} MetricsConfigTest.test_cache_counters_missing 2>&1 | tee ${CLIENT_LOG} @@ -406,6 +458,19 @@ kill_server expected_tests=6 check_unit_test "${expected_tests}" +### Test histogram data in ensemble decoupled model ### +MODELDIR="${PWD}/ensemble_decoupled" +SERVER_ARGS="--model-repository=${MODELDIR} --metrics-config histogram_latencies=true --log-verbose=1" +PYTHON_TEST="histogram_metrics_test.py" +mkdir -p "${MODELDIR}"/ensemble/1 +cp -r "${MODELDIR}"/async_execute_decouple "${MODELDIR}"/async_execute +sed -i "s/model_transaction_policy { decoupled: True }//" "${MODELDIR}"/async_execute/config.pbtxt + +run_and_check_server +python3 ${PYTHON_TEST} 2>&1 | tee ${CLIENT_LOG} +kill_server +check_unit_test + if [ $RET -eq 0 ]; then echo -e "\n***\n*** Test Passed\n***" else @@ -413,4 +478,3 @@ else fi exit $RET -