Skip to content

Commit

Permalink
Add CLI argument for server-metrics-url (#39)
Browse files Browse the repository at this point in the history
* Add CLI argument for server-metrics-url

* Remove print statement

* Fix pre-commit error

* Graceful handling of exception thrown by genai-perf if metrics url is unreachable  (#47)

* Fix exception thrown by genai-perf if metrics url is unreachable

* Fix comments

* Fix comments

* Fix comments

* Add space in error message

* Add CLI argument for server-metrics-url

* Fix pre-commit error and remove duplicates

* Remove unnecessary print

* Fix comments

* Fix pytest error

* Fix comments

* Refactor test function names

* Fix comments

* Fix comment in test_profile_handler

* Cleanup test_profile_handler.py test

* Move MockArgs outside TestProfileHandler

* Move test triton metrics url to a variable
  • Loading branch information
lkomali authored Aug 28, 2024
1 parent 42b8591 commit 16e6d33
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 10 deletions.
59 changes: 57 additions & 2 deletions genai-perf/genai_perf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from enum import Enum, auto
from pathlib import Path
from typing import Tuple
from urllib.parse import urlparse

import genai_perf.logging as logging
import genai_perf.utils as utils
Expand Down Expand Up @@ -262,6 +263,49 @@ def _check_goodput_args(args):
return args


def _is_valid_url(parser: argparse.ArgumentParser, url: str) -> None:
"""
Validates a URL to ensure it meets the following criteria:
- The scheme must be 'http' or 'https'.
- The netloc (domain) must be present.
- The path must contain '/metrics'.
Raises:
`parser.error()` if the URL is invalid.
The URL structure is expected to follow:
<scheme>://<netloc>/<path>;<params>?<query>#<fragment>
"""
parsed_url = urlparse(url)

if (
parsed_url.scheme not in ["http", "https"]
or not parsed_url.netloc
or "/metrics" not in parsed_url.path
or parsed_url.port is None
):
parser.error(
"The URL passed for --server-metrics-url is invalid. "
"It must use 'http' or 'https', have a valid domain and port, "
"and contain '/metrics' in the path. The expected structure is: "
"<scheme>://<netloc>/<path>;<params>?<query>#<fragment>"
)


def _check_server_metrics_url(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
"""
Checks if the server metrics URL passed is valid
"""

# Check if the URL is valid and contains the expected path
if args.service_kind == "triton" and args.server_metrics_url:
_is_valid_url(parser, args.server_metrics_url)

return args


def _set_artifact_paths(args: argparse.Namespace) -> argparse.Namespace:
"""
Set paths for all the artifacts.
Expand Down Expand Up @@ -643,6 +687,16 @@ def _add_endpoint_args(parser):
"you must specify an api via --endpoint-type.",
)

endpoint_group.add_argument(
"--server-metrics-url",
type=str,
default=None,
required=False,
help="The full URL to access the server metrics endpoint. "
"This argument is required if the metrics are available on "
"a different machine than localhost (where GenAI-Perf is running).",
)

endpoint_group.add_argument(
"--streaming",
action="store_true",
Expand Down Expand Up @@ -824,9 +878,9 @@ def profile_handler(args, extra_args):

telemetry_data_collector = None
if args.service_kind == "triton":
# TPA-275: pass server url as a CLI option in non-default case
server_metrics_url = args.server_metrics_url or DEFAULT_TRITON_METRICS_URL
telemetry_data_collector = TritonTelemetryDataCollector(
server_metrics_url=DEFAULT_TRITON_METRICS_URL
server_metrics_url=server_metrics_url
)

Profiler.run(
Expand Down Expand Up @@ -882,6 +936,7 @@ def refine_args(
args = _check_conditional_args(parser, args)
args = _check_image_input_args(parser, args)
args = _check_load_manager_args(args)
args = _check_server_metrics_url(parser, args)
args = _set_artifact_paths(args)
args = _check_goodput_args(args)
elif args.subcommand == Subcommand.COMPARE.to_lowercase():
Expand Down
17 changes: 9 additions & 8 deletions genai-perf/genai_perf/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,34 +75,35 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
"formatted_model_name",
"func",
"generate_plots",
"goodput",
"image_format",
"image_height_mean",
"image_height_stddev",
"image_width_mean",
"image_width_stddev",
"input_dataset",
"input_file",
"input_format",
"model",
"model_selection_strategy",
"num_prompts",
"output_format",
"output_tokens_mean_deterministic",
"output_tokens_mean",
"output_tokens_mean_deterministic",
"output_tokens_stddev",
"prompt_source",
"random_seed",
"request_rate",
"server_metrics_url",
# The 'streaming' passed in to this script is to determine if the
# LLM response should be streaming. That is different than the
# 'streaming' that PA takes, which means something else (and is
# required for decoupled models into triton).
"streaming",
"subcommand",
"synthetic_input_tokens_mean",
"synthetic_input_tokens_stddev",
"subcommand",
"tokenizer",
"image_width_mean",
"image_width_stddev",
"image_height_mean",
"image_height_stddev",
"image_format",
"goodput",
]

utils.remove_file(args.profile_export_file)
Expand Down
54 changes: 54 additions & 0 deletions genai-perf/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,22 @@ def test_unrecognized_arg(self, monkeypatch, capsys):
],
"The --generate-plots option is not currently supported with the image_retrieval endpoint type",
),
(
[
"genai-perf",
"profile",
"--model",
"test_model",
"--service-kind",
"triton",
"--server-metrics-url",
"invalid_url",
],
"The URL passed for --server-metrics-url is invalid. "
"It must use 'http' or 'https', have a valid domain and port, "
"and contain '/metrics' in the path. The expected structure is: "
"<scheme>://<netloc>/<path>;<params>?<query>#<fragment>",
),
],
)
def test_conditional_errors(self, args, expected_output, monkeypatch, capsys):
Expand Down Expand Up @@ -916,3 +932,41 @@ def test_get_extra_inputs_as_dict(self, extra_inputs_list, expected_dict):
namespace.extra_inputs = extra_inputs_list
actual_dict = parser.get_extra_inputs_as_dict(namespace)
assert actual_dict == expected_dict

test_triton_metrics_url = "http://tritonmetrics.com:8002/metrics"

@pytest.mark.parametrize(
"args_list, expected_url",
[
# server-metrics-url is specified
(
[
"genai-perf",
"profile",
"--model",
"test_model",
"--service-kind",
"triton",
"--server-metrics-url",
test_triton_metrics_url,
],
test_triton_metrics_url,
),
# server-metrics-url is not specified
(
[
"genai-perf",
"profile",
"--model",
"test_model",
"--service-kind",
"triton",
],
None,
),
],
)
def test_server_metrics_url_arg_valid(self, args_list, expected_url, monkeypatch):
monkeypatch.setattr("sys.argv", args_list)
args, _ = parser.parse_args()
assert args.server_metrics_url == expected_url
1 change: 1 addition & 0 deletions genai-perf/tests/test_json_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def test_generate_json_custom_export(
"endpoint": null,
"endpoint_type": null,
"service_kind": "triton",
"server_metrics_url": null,
"streaming": true,
"u": null,
"input_dataset": null,
Expand Down
71 changes: 71 additions & 0 deletions genai-perf/tests/test_profile_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
#
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from unittest.mock import patch

import pytest
from genai_perf.constants import DEFAULT_TRITON_METRICS_URL
from genai_perf.parser import profile_handler
from genai_perf.telemetry_data.triton_telemetry_data_collector import (
TritonTelemetryDataCollector,
)


class MockArgs:
def __init__(self, service_kind, server_metrics_url):
self.service_kind = service_kind
self.server_metrics_url = server_metrics_url


class TestProfileHandler:
test_triton_metrics_url = "http://tritonmetrics.com:8080/metrics"

@pytest.mark.parametrize(
"server_metrics_url, expected_url",
[
(test_triton_metrics_url, test_triton_metrics_url),
(None, DEFAULT_TRITON_METRICS_URL),
],
)
@patch("genai_perf.wrapper.Profiler.run")
def test_profile_handler_creates_telemetry_collector(
self, mock_profiler_run, server_metrics_url, expected_url
):
mock_args = MockArgs(
service_kind="triton", server_metrics_url=server_metrics_url
)
profile_handler(mock_args, extra_args={})
mock_profiler_run.assert_called_once()

args, kwargs = mock_profiler_run.call_args

assert "telemetry_data_collector" in kwargs

telemetry_data_collector = kwargs["telemetry_data_collector"]
assert isinstance(telemetry_data_collector, TritonTelemetryDataCollector)
assert telemetry_data_collector.metrics_url == expected_url

0 comments on commit 16e6d33

Please sign in to comment.