diff --git a/src/c++/perf_analyzer/command_line_parser.cc b/src/c++/perf_analyzer/command_line_parser.cc
index bd3d72d73..8003be711 100644
--- a/src/c++/perf_analyzer/command_line_parser.cc
+++ b/src/c++/perf_analyzer/command_line_parser.cc
@@ -1715,7 +1715,8 @@ CLParser::ParseCommandLine(int argc, char** argv)
// Overriding the max_threads default for request_rate search
if (!params_->max_threads_specified && params_->targeting_concurrency()) {
- params_->max_threads = 16;
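+    // Ensure there are enough worker threads to drive the requested
+    // concurrency, but never fewer than DEFAULT_MAX_THREADS.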
+ params_->max_threads =
+ std::max(DEFAULT_MAX_THREADS, params_->concurrency_range.end);
}
if (params_->using_custom_intervals) {
diff --git a/src/c++/perf_analyzer/constants.h b/src/c++/perf_analyzer/constants.h
index 443806781..fbcd911b8 100644
--- a/src/c++/perf_analyzer/constants.h
+++ b/src/c++/perf_analyzer/constants.h
@@ -41,6 +41,7 @@ constexpr static const uint32_t STABILITY_ERROR = 2;
constexpr static const uint32_t OPTION_ERROR = 3;
constexpr static const uint32_t GENERIC_ERROR = 99;
+constexpr static const size_t DEFAULT_MAX_THREADS = 16;
const double DELAY_PCT_THRESHOLD{1.0};
diff --git a/src/c++/perf_analyzer/genai-perf/README.md b/src/c++/perf_analyzer/genai-perf/README.md
index 1d03b3dd0..53e510541 100644
--- a/src/c++/perf_analyzer/genai-perf/README.md
+++ b/src/c++/perf_analyzer/genai-perf/README.md
@@ -29,13 +29,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# GenAI-Perf
GenAI-Perf is a command line tool for measuring the throughput and latency of
-generative AI models as served through an inference server. For large language
-models (LLMs), GenAI-Perf provides metrics such as
+generative AI models as served through an inference server.
+For large language models (LLMs), GenAI-Perf provides metrics such as
[output token throughput](#output_token_throughput_metric),
[time to first token](#time_to_first_token_metric),
[inter token latency](#inter_token_latency_metric), and
-[request throughput](#request_throughput_metric). For a full list of metrics
-please see the [Metrics section](#metrics).
+[request throughput](#request_throughput_metric).
+For a full list of metrics please see the [Metrics section](#metrics).
Users specify a model name, an inference server URL, the type of inputs to use
(synthetic or from dataset), and the type of load to generate (number of
@@ -43,41 +43,56 @@ concurrent requests, request rate).
GenAI-Perf generates the specified load, measures the performance of the
inference server and reports the metrics in a simple table as console output.
-The tool also logs all results in a csv file that can be used to derive
+The tool also logs all results in CSV and JSON files that can be used to derive
additional metrics and visualizations. The inference server must already be
running when GenAI-Perf is run.
+You can use GenAI-Perf to run performance benchmarks on:
+- [Large Language Models](docs/tutorial.md)
+- [Vision Language Models](docs/multi_modal.md)
+- [Embedding Models](docs/embeddings.md)
+- [Ranking Models](docs/rankings.md)
+- [Multiple LoRA Adapters](docs/lora.md)
+
> [!Note]
> GenAI-Perf is currently in early release and under rapid development. While we
> will try to remain consistent, command line options and functionality are
> subject to change as the tool matures.
-# Installation
+
-## Triton SDK Container
+
-Available starting with the 24.03 release of the
-[Triton Server SDK container](https://ngc.nvidia.com/catalog/containers/nvidia:tritonserver).
+## Installation
-Run the Triton Inference Server SDK docker container:
+The easiest way to install GenAI-Perf is through the
+[Triton Server SDK container](https://ngc.nvidia.com/catalog/containers/nvidia:tritonserver).
+Install the latest release using the following command:
```bash
-export RELEASE="yy.mm" # e.g. export RELEASE="24.03"
+export RELEASE="yy.mm" # e.g. export RELEASE="24.06"
docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
+
+# Check out the genai-perf command inside the container:
+genai-perf --help
```
Alternatively, to install from source:
-## From Source
-
-GenAI-Perf depends on Perf Analyzer. Here is how to install Perf Analyzer:
+Since GenAI-Perf depends on Perf Analyzer,
+you'll need to install the Perf Analyzer binary:
### Install Perf Analyzer (Ubuntu, Python 3.8+)
-Note: you must already have CUDA 12 installed.
+**NOTE**: You must already have CUDA 12 installed
+(see the [CUDA installation guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)).
```bash
pip install tritonclient
@@ -85,83 +100,70 @@ pip install tritonclient
apt update && apt install -y --no-install-recommends libb64-0d libcurl4
```
-Alternatively, you can install Perf Analyzer
-[from source](../docs/install.md#build-from-source).
+You can also build Perf Analyzer [from source](../docs/install.md#build-from-source).
### Install GenAI-Perf from source
```bash
-export RELEASE="yy.mm" # e.g. export RELEASE="24.03"
+git clone https://github.com/triton-inference-server/client.git && cd client/src/c++/perf_analyzer/genai-perf
-pip install "git+https://github.com/triton-inference-server/client.git@r${RELEASE}#subdirectory=src/c++/perf_analyzer/genai-perf"
+pip install -e .
```
-
-
-Run GenAI-Perf:
-
-```bash
-genai-perf --help
-```
-
-# Quick Start
-
-## Measuring Throughput and Latency of GPT2 using Triton + TensorRT-LLM
-
-### Running GPT2 on Triton Inference Server using TensorRT-LLM
-
-
-See instructions
-1. Run Triton Inference Server with TensorRT-LLM backend container:
+
-```bash
-export RELEASE="yy.mm" # e.g. export RELEASE="24.03"
+
-docker run -it --net=host --rm --gpus=all --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/tritonserver:${RELEASE}-trtllm-python-py3
-```
+## Quick Start
-2. Install Triton CLI (~5 min):
+In this quick start, we will use GenAI-Perf to run performance benchmarking on
+the GPT-2 model running on Triton Inference Server with a TensorRT-LLM engine.
-```bash
-pip install \
- --extra-index-url https://pypi.nvidia.com \
- -U \
- psutil \
- "pynvml>=11.5.0" \
- torch==2.1.2 \
- tensorrt_llm==0.8.0 \
- "git+https://github.com/triton-inference-server/triton_cli@0.0.6"
-```
+### Serve GPT-2 TensorRT-LLM model using Triton CLI
-3. Download model:
+You can follow the [quickstart guide](https://github.com/triton-inference-server/triton_cli?tab=readme-ov-file#serving-a-trt-llm-model)
+in the Triton CLI GitHub repository to run the GPT-2 model locally.
+The full instructions are copied below for convenience:
```bash
+# This container comes with all of the dependencies for building TRT-LLM engines
+# and serving the engine with Triton Inference Server.
+docker run -ti \
+ --gpus all \
+ --network=host \
+ --shm-size=1g --ulimit memlock=-1 \
+ -v /tmp:/tmp \
+ -v ${HOME}/models:/root/models \
+ -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
+ nvcr.io/nvidia/tritonserver:24.05-trtllm-python-py3
+
+# Install the Triton CLI
+pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.8
+
+# Build TRT LLM engine and generate a Triton model repository pointing at it
+triton remove -m all
triton import -m gpt2 --backend tensorrtllm
-```
-4. Run server:
-
-```bash
+# Start Triton pointing at the default model repository
triton start
```
-
-
### Running GenAI-Perf
-1. Run Triton Inference Server SDK container:
+Now we can run GenAI-Perf from the Triton Inference Server SDK container:
```bash
-export RELEASE="yy.mm" # e.g. export RELEASE="24.03"
+export RELEASE="yy.mm" # e.g. export RELEASE="24.06"
docker run -it --net=host --rm --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
-```
-2. Run GenAI-Perf:
-
-```bash
+# Run GenAI-Perf in the container:
genai-perf profile \
-m gpt2 \
--service-kind triton \
@@ -184,25 +186,31 @@ genai-perf profile \
Example output:
```
- LLM Metrics
-┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
-┃ Statistic ┃ avg ┃ min ┃ max ┃ p99 ┃ p90 ┃ p75 ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
-│ Time to first token (ns) │ 13,266,974 │ 11,818,732 │ 18,351,779 │ 16,513,479 │ 13,741,986 │ 13,544,376 │
-│ Inter token latency (ns) │ 2,069,766 │ 42,023 │ 15,307,799 │ 3,256,375 │ 3,020,580 │ 2,090,930 │
-│ Request latency (ns) │ 223,532,625 │ 219,123,330 │ 241,004,192 │ 238,198,306 │ 229,676,183 │ 224,715,918 │
-│ Output sequence length │ 104 │ 100 │ 129 │ 128 │ 109 │ 105 │
-│ Input sequence length │ 199 │ 199 │ 199 │ 199 │ 199 │ 199 │
-└──────────────────────────┴─────────────┴─────────────┴─────────────┴─────────────┴─────────────┴─────────────┘
-Output token throughput (per sec): 460.42
-Request throughput (per sec): 4.44
+ LLM Metrics
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓
+┃ Statistic ┃ avg ┃ min ┃ max ┃ p99 ┃ p90 ┃ p75 ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩
+│ Time to first token (ms) │ 11.70 │ 9.88 │ 17.21 │ 14.35 │ 12.01 │ 11.87 │
+│ Inter token latency (ms) │ 1.46 │ 1.08 │ 1.89 │ 1.87 │ 1.62 │ 1.52 │
+│ Request latency (ms) │ 161.24 │ 153.45 │ 200.74 │ 200.66 │ 179.43 │ 162.23 │
+│ Output sequence length │ 103.39 │ 95.00 │ 134.00 │ 120.08 │ 107.30 │ 105.00 │
+│ Input sequence length │ 200.01 │ 200.00 │ 201.00 │ 200.13 │ 200.00 │ 200.00 │
+└──────────────────────────┴────────┴────────┴────────┴────────┴────────┴────────┘
+Output token throughput (per sec): 635.61
+Request throughput (per sec): 6.15
```
See [Tutorial](docs/tutorial.md) for additional examples.
-# Visualization
+
+
+## Visualization
GenAI-Perf can also generate various plots that visualize the performance of the
current profile run. This is disabled by default but users can easily enable it
@@ -226,12 +234,12 @@ This will generate a [set of default plots](docs/compare.md#example-plots) such
- Input sequence lengths vs Output sequence lengths
-## Using `compare` Subcommand to Visualize Multiple Runs
+### Using `compare` Subcommand to Visualize Multiple Runs
The `compare` subcommand in GenAI-Perf facilitates users in comparing multiple
profile runs and visualizing the differences through plots.
-### Usage
+#### Usage
Assuming the user possesses two profile export JSON files,
namely `profile1.json` and `profile2.json`,
they can execute the `compare` subcommand using the `--files` option:
@@ -258,7 +266,7 @@ compare
└── ...
```
-### Customization
+#### Customization
Users have the flexibility to iteratively modify the generated YAML configuration
file to suit their specific requirements.
They can make alterations to the plots according to their preferences and execute
@@ -277,7 +285,13 @@ See [Compare documentation](docs/compare.md) for more details.
-# Model Inputs
+
+
+## Model Inputs
GenAI-Perf supports model input prompts from either synthetically generated
inputs, or from the HuggingFace
@@ -323,7 +337,13 @@ You can optionally set additional model inputs with the following option:
-# Metrics
+
+
+## Metrics
GenAI-Perf collects a diverse set of metrics that captures the performance of
the inference server.
@@ -340,14 +360,20 @@ the inference server.
-# Command Line Options
+
+
+## Command Line Options
##### `-h`
##### `--help`
Show the help message and exit.
-## Endpoint Options:
+### Endpoint Options:
##### `-m `
##### `--model `
@@ -392,7 +418,7 @@ An option to enable the use of the streaming API. (default: `False`)
URL of the endpoint to target for benchmarking. (default: `None`)
-## Input Options
+### Input Options
##### `-b `
##### `--batch-size `
@@ -458,7 +484,7 @@ data. (default: `550`)
The standard deviation of number of tokens in the generated prompts when
using synthetic data. (default: `0`)
-## Profiling Options
+### Profiling Options
##### `--concurrency `
@@ -483,7 +509,7 @@ stable. The measurement is considered as stable if the ratio of max / min from
the recent 3 measurements is within (stability percentage) in terms of both
infer per second and latency. (default: `999`)
-## Output Options
+### Output Options
##### `--artifact-dir`
@@ -502,7 +528,7 @@ exported to `_genai_perf.csv`. For example, if the profile
export file is `profile_export.json`, the genai-perf file will be exported to
`profile_export_genai_perf.csv`. (default: `profile_export.json`)
-## Other Options
+### Other Options
##### `--tokenizer `
@@ -518,7 +544,15 @@ An option to enable verbose mode. (default: `False`)
An option to print the version and exit.
-# Known Issues
+
+
+
+
+## Known Issues
* GenAI-Perf can be slow to finish if a high request-rate is provided
* Token counts may not be exact
diff --git a/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md b/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md
new file mode 100644
index 000000000..bb9f33c60
--- /dev/null
+++ b/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md
@@ -0,0 +1,122 @@
+
+
+# Profile Vision-Language Models with GenAI-Perf
+
+GenAI-Perf allows you to profile Vision-Language Models (VLMs) running on an
+[OpenAI Chat Completions API](https://platform.openai.com/docs/guides/chat-completions)-compatible server
+by sending [multi-modal content](https://platform.openai.com/docs/guides/vision) to the server.
+Currently, you can send multi-modal content with GenAI-Perf using the following two approaches:
+1. The synthetic data generation approach, where GenAI-Perf generates the multi-modal data for you.
+2. The Bring Your Own Data (BYOD) approach, where you provide GenAI-Perf with the data to send.
+
+Before we dive into the two approaches,
+you can start an OpenAI API-compatible server with a VLM model using the following command:
+
+```bash
+docker run --runtime nvidia --gpus all \
+ -p 8000:8000 --ipc=host \
+ vllm/vllm-openai:latest \
+ --model llava-hf/llava-v1.6-mistral-7b-hf --dtype float16
+```
+
+
+## Approach 1: Synthetic Multi-Modal Data Generation
+
+GenAI-Perf can generate synthetic multi-modal data such as texts or images using
+the parameters provided by the user through the CLI.
+
+```bash
+genai-perf profile \
+ -m llava-hf/llava-v1.6-mistral-7b-hf \
+ --service-kind openai \
+ --endpoint-type vision \
+ --image-width-mean 512 \
+ --image-width-stddev 30 \
+ --image-height-mean 512 \
+ --image-height-stddev 30 \
+ --image-format png \
+ --synthetic-input-tokens-mean 100 \
+ --synthetic-input-tokens-stddev 0 \
+ --streaming
+```
+
+> [!Note]
+> Under the hood, GenAI-Perf generates synthetic images using a few source images
+> under the `llm_inputs/source_images` directory.
+> If you would like to add/remove/edit the source images,
+> you can do so by directly editing the source images under the directory.
+> GenAI-Perf will pick up the images under the directory automatically when
+> generating the synthetic images.
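+
+For illustration, here is a minimal sketch (not the actual implementation) of
+how GenAI-Perf produces a synthetic image. The names `make_synthetic_image` and
+`source_images_dir` are hypothetical and used only for this example:
+
+```python
+# Simplified approximation of GenAI-Perf's synthetic image generation:
+# sample a source image, resize it, and return it as a base64 data URI.
+import base64
+import glob
+import random
+from io import BytesIO
+from pathlib import Path
+
+from PIL import Image
+
+
+def make_synthetic_image(source_images_dir: Path, width: int, height: int) -> str:
+    filenames = glob.glob(str(source_images_dir / "*"))
+    image = Image.open(random.choice(filenames)).resize((width, height))
+    buffer = BytesIO()
+    image.save(buffer, format="PNG")
+    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
+    return f"data:image/png;base64,{encoded}"
+```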
+
+
+## Approach 2: Bring Your Own Data (BYOD)
+
+Instead of letting GenAI-Perf create the synthetic data,
+you can also provide GenAI-Perf with your own data using the
+[`--input-file`](../README.md#--input-file-path) CLI option.
+The file needs to be in JSONL format and should contain both the prompt and
+the filepath to the image to send.
+
+For instance, an input file would look like the following:
+```bash
+// input.jsonl
+{"text_input": "What is in this image?", "image": "path/to/image1.png"}
+{"text_input": "What is the color of the dog?", "image": "path/to/image2.jpeg"}
+{"text_input": "Describe the scene in the picture.", "image": "path/to/image3.png"}
+...
+```
+
+After you create the file, you can run GenAI-Perf using the following command:
+
+```bash
+genai-perf profile \
+ -m llava-hf/llava-v1.6-mistral-7b-hf \
+ --service-kind openai \
+ --endpoint-type vision \
+ --input-file input.jsonl \
+ --streaming
+```
+
+Running GenAI-Perf with either approach will produce an output that
+looks similar to the following:
+
+```bash
+ LLM Metrics
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓
+┃ Statistic ┃ avg ┃ min ┃ max ┃ p99 ┃ p90 ┃ p75 ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩
+│ Time to first token (ms) │ 321.05 │ 291.30 │ 537.07 │ 497.88 │ 318.46 │ 317.35 │
+│ Inter token latency (ms) │ 12.28 │ 11.44 │ 12.88 │ 12.87 │ 12.81 │ 12.53 │
+│ Request latency (ms) │ 1,866.23 │ 1,044.70 │ 2,832.22 │ 2,779.63 │ 2,534.64 │ 2,054.03 │
+│ Output sequence length │ 126.68 │ 59.00 │ 204.00 │ 200.58 │ 177.80 │ 147.50 │
+│ Input sequence length │ 100.00 │ 100.00 │ 100.00 │ 100.00 │ 100.00 │ 100.00 │
+└──────────────────────────┴──────────┴──────────┴──────────┴──────────┴──────────┴──────────┘
+Output token throughput (per sec): 67.40
+Request throughput (per sec): 0.53
+```
diff --git a/src/c++/perf_analyzer/genai-perf/docs/tutorial.md b/src/c++/perf_analyzer/genai-perf/docs/tutorial.md
index 1a37baf39..15cc53efe 100644
--- a/src/c++/perf_analyzer/genai-perf/docs/tutorial.md
+++ b/src/c++/perf_analyzer/genai-perf/docs/tutorial.md
@@ -71,7 +71,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06"
docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
# Run GenAI-Perf in the container:
-```bash
genai-perf profile \
-m gpt2 \
--service-kind triton \
@@ -145,7 +144,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06"
docker run -it --net=host --gpus=1 nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
# Run GenAI-Perf in the container:
-```bash
genai-perf profile \
-m gpt2 \
--service-kind triton \
@@ -207,7 +205,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06"
docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
# Run GenAI-Perf in the container:
-```bash
genai-perf profile \
-m gpt2 \
--service-kind openai \
@@ -270,7 +267,6 @@ docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-
# Run GenAI-Perf in the container:
-```bash
genai-perf profile \
-m gpt2 \
--service-kind openai \
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/__init__.py b/src/c++/perf_analyzer/genai-perf/genai_perf/__init__.py
index cb5c26999..d656fe629 100644
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/__init__.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/__init__.py
@@ -24,4 +24,4 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-__version__ = "0.0.4dev"
+__version__ = "0.0.5dev"
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py
index 39abc7ece..057c33562 100644
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py
@@ -20,11 +20,17 @@
from typing import Any, Dict, List, Optional, Tuple, cast
import requests
+from genai_perf import utils
from genai_perf.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA
from genai_perf.exceptions import GenAIPerfException
+from genai_perf.llm_inputs.synthetic_image_generator import (
+ ImageFormat,
+ SyntheticImageGenerator,
+)
from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator
from genai_perf.tokenizer import DEFAULT_TOKENIZER, Tokenizer, get_tokenizer
from genai_perf.utils import load_json_str
+from PIL import Image
from requests import Response
@@ -43,6 +49,7 @@ class OutputFormat(Enum):
OPENAI_CHAT_COMPLETIONS = auto()
OPENAI_COMPLETIONS = auto()
OPENAI_EMBEDDINGS = auto()
+ OPENAI_VISION = auto()
RANKINGS = auto()
TENSORRTLLM = auto()
VLLM = auto()
@@ -75,6 +82,11 @@ class LlmInputs:
DEFAULT_OUTPUT_TOKENS_STDDEV = 0
DEFAULT_NUM_PROMPTS = 100
+ DEFAULT_IMAGE_WIDTH_MEAN = 100
+ DEFAULT_IMAGE_WIDTH_STDDEV = 0
+ DEFAULT_IMAGE_HEIGHT_MEAN = 100
+ DEFAULT_IMAGE_HEIGHT_STDDEV = 0
+
EMPTY_JSON_IN_VLLM_PA_FORMAT: Dict = {"data": []}
EMPTY_JSON_IN_TENSORRTLLM_PA_FORMAT: Dict = {"data": []}
EMPTY_JSON_IN_OPENAI_PA_FORMAT: Dict = {"data": []}
@@ -97,6 +109,11 @@ def create_llm_inputs(
output_tokens_deterministic: bool = False,
prompt_tokens_mean: int = DEFAULT_PROMPT_TOKENS_MEAN,
prompt_tokens_stddev: int = DEFAULT_PROMPT_TOKENS_STDDEV,
+ image_width_mean: int = DEFAULT_IMAGE_WIDTH_MEAN,
+ image_width_stddev: int = DEFAULT_IMAGE_WIDTH_STDDEV,
+ image_height_mean: int = DEFAULT_IMAGE_HEIGHT_MEAN,
+ image_height_stddev: int = DEFAULT_IMAGE_HEIGHT_STDDEV,
+ image_format: ImageFormat = ImageFormat.PNG,
random_seed: int = DEFAULT_RANDOM_SEED,
num_of_output_prompts: int = DEFAULT_NUM_PROMPTS,
add_model_name: bool = False,
@@ -139,6 +156,16 @@ def create_llm_inputs(
The standard deviation of the length of the output to generate. This is only used if output_tokens_mean is provided.
output_tokens_deterministic:
If true, the output tokens will set the minimum and maximum tokens to be equivalent.
+ image_width_mean:
+ The mean width of images when generating synthetic image data.
+ image_width_stddev:
+ The standard deviation of width of images when generating synthetic image data.
+ image_height_mean:
+ The mean height of images when generating synthetic image data.
+ image_height_stddev:
+ The standard deviation of height of images when generating synthetic image data.
+ image_format:
+ The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
@@ -175,6 +202,11 @@ def create_llm_inputs(
prompt_tokens_mean,
prompt_tokens_stddev,
num_of_output_prompts,
+ image_width_mean,
+ image_width_stddev,
+ image_height_mean,
+ image_height_stddev,
+ image_format,
batch_size,
input_filename,
)
@@ -210,6 +242,11 @@ def get_generic_dataset_json(
prompt_tokens_mean: int,
prompt_tokens_stddev: int,
num_of_output_prompts: int,
+ image_width_mean: int,
+ image_width_stddev: int,
+ image_height_mean: int,
+ image_height_stddev: int,
+ image_format: ImageFormat,
batch_size: int,
input_filename: Optional[Path],
) -> Dict:
@@ -236,6 +273,16 @@ def get_generic_dataset_json(
The standard deviation of the length of the prompt to generate
num_of_output_prompts:
The number of synthetic output prompts to generate
+ image_width_mean:
+ The mean width of images when generating synthetic image data.
+ image_width_stddev:
+ The standard deviation of width of images when generating synthetic image data.
+ image_height_mean:
+ The mean height of images when generating synthetic image data.
+ image_height_stddev:
+ The standard deviation of height of images when generating synthetic image data.
+ image_format:
+ The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
input_filename:
@@ -280,6 +327,12 @@ def get_generic_dataset_json(
)
else:
if input_type == PromptSource.DATASET:
+ # (TMA-1990) support VLM input from public dataset
+ if output_format == OutputFormat.OPENAI_VISION:
+ raise GenAIPerfException(
+ f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
+ "does not support dataset as input."
+ )
dataset = cls._get_input_dataset_from_url(
dataset_name, starting_index, length
)
@@ -292,6 +345,12 @@ def get_generic_dataset_json(
prompt_tokens_mean,
prompt_tokens_stddev,
num_of_output_prompts,
+ image_width_mean,
+ image_width_stddev,
+ image_height_mean,
+ image_height_stddev,
+ image_format,
+ output_format,
)
generic_dataset_json = (
cls._convert_input_synthetic_or_file_dataset_to_generic_json(
@@ -301,6 +360,9 @@ def get_generic_dataset_json(
elif input_type == PromptSource.FILE:
input_filename = cast(Path, input_filename)
input_file_dataset = cls._get_input_dataset_from_file(input_filename)
+ input_file_dataset = cls._encode_images_in_input_dataset(
+ input_file_dataset
+ )
generic_dataset_json = (
cls._convert_input_synthetic_or_file_dataset_to_generic_json(
input_file_dataset
@@ -309,6 +371,14 @@ def get_generic_dataset_json(
else:
raise GenAIPerfException("Input source is not recognized.")
+ # When the generic_dataset_json contains multi-modal data (e.g. images),
+ # convert the format of the content to OpenAI multi-modal format:
+ # see https://platform.openai.com/docs/guides/vision
+ if output_format == OutputFormat.OPENAI_VISION:
+ generic_dataset_json = cls._convert_to_openai_multi_modal_content(
+ generic_dataset_json
+ )
+
return generic_dataset_json
@classmethod
@@ -405,17 +475,36 @@ def _get_input_dataset_from_synthetic(
prompt_tokens_mean: int,
prompt_tokens_stddev: int,
num_of_output_prompts: int,
+ image_width_mean: int,
+ image_width_stddev: int,
+ image_height_mean: int,
+ image_height_stddev: int,
+ image_format: ImageFormat,
+ output_format: OutputFormat,
) -> Dict[str, Any]:
dataset_json: Dict[str, Any] = {}
dataset_json["features"] = [{"name": "text_input"}]
dataset_json["rows"] = []
for _ in range(num_of_output_prompts):
+            row: Dict[str, Any] = {"row": {}}
synthetic_prompt = cls._create_synthetic_prompt(
tokenizer,
prompt_tokens_mean,
prompt_tokens_stddev,
)
- dataset_json["rows"].append({"row": {"text_input": synthetic_prompt}})
+ row["row"]["text_input"] = synthetic_prompt
+
+ if output_format == OutputFormat.OPENAI_VISION:
+ synthetic_image = cls._create_synthetic_image(
+ image_width_mean=image_width_mean,
+ image_width_stddev=image_width_stddev,
+ image_height_mean=image_height_mean,
+ image_height_stddev=image_height_stddev,
+ image_format=image_format,
+ )
+ row["row"]["image"] = synthetic_image
+
+ dataset_json["rows"].append(row)
return dataset_json
@@ -497,29 +586,37 @@ def _add_rows_to_generic_json(
@classmethod
def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
"""
- Reads the input prompts from a JSONL file and converts them into the required dataset format.
+ Reads the input prompts and images from a JSONL file and converts them
+ into the required dataset format.
Parameters
----------
input_filename : Path
- The path to the input file containing the prompts in JSONL format.
+ The path to the input file containing the prompts and/or images in
+ JSONL format.
Returns
-------
Dict
- The dataset in the required format with the prompts read from the file.
+ The dataset in the required format with the prompts and/or images
+ read from the file.
"""
cls.verify_file(input_filename)
- input_file_prompts = cls._get_prompts_from_input_file(input_filename)
+ prompts, images = cls._get_prompts_from_input_file(input_filename)
dataset_json: Dict[str, Any] = {}
dataset_json["features"] = [{"name": "text_input"}]
- dataset_json["rows"] = [
- {"row": {"text_input": prompt}} for prompt in input_file_prompts
- ]
+ dataset_json["rows"] = []
+ for prompt, image in zip(prompts, images):
+ content = {"text_input": prompt}
+ content.update({"image": image} if image else {})
+ dataset_json["rows"].append({"row": content})
+
return dataset_json
@classmethod
- def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:
+ def _get_prompts_from_input_file(
+ cls, input_filename: Path
+ ) -> Tuple[List[str], List[str]]:
"""
Reads the input prompts from a JSONL file and returns a list of prompts.
@@ -530,21 +627,63 @@ def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:
Returns
-------
- List[str]
- A list of prompts read from the file.
+ Tuple[List[str], List[str]]
+ A list of prompts and images read from the file.
"""
prompts = []
+ images = []
with open(input_filename, mode="r", newline=None) as file:
for line in file:
if line.strip():
prompts.append(load_json_str(line).get("text_input", "").strip())
- return prompts
+ images.append(load_json_str(line).get("image", "").strip())
+ return prompts, images
@classmethod
def verify_file(cls, input_filename: Path) -> None:
if not input_filename.exists():
raise FileNotFoundError(f"The file '{input_filename}' does not exist.")
+ @classmethod
+ def _convert_to_openai_multi_modal_content(
+ cls, generic_dataset_json: Dict[str, List[Dict]]
+ ) -> Dict[str, List[Dict]]:
+ """
+ Converts to multi-modal content format of OpenAI Chat Completions API.
+ """
+ for row in generic_dataset_json["rows"]:
+ if row["image"]:
+ row["text_input"] = [
+ {
+ "type": "text",
+ "text": row["text_input"],
+ },
+ {
+ "type": "image_url",
+ "image_url": {"url": row["image"]},
+ },
+ ]
+
+ return generic_dataset_json
+
+ @classmethod
+ def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict:
+ for row in input_file_dataset["rows"]:
+ filename = row["row"].get("image")
+ if filename:
+ img = Image.open(filename)
+ if img.format.lower() not in utils.get_enum_names(ImageFormat):
+ raise GenAIPerfException(
+ f"Unsupported image format '{img.format}' of "
+ f"the image '{filename}'."
+ )
+
+ img_base64 = utils.encode_image(img, img.format)
+ payload = f"data:image/{img.format.lower()};base64,{img_base64}"
+ row["row"]["image"] = payload
+
+ return input_file_dataset
+
@classmethod
def _convert_generic_json_to_output_format(
cls,
@@ -559,7 +698,10 @@ def _convert_generic_json_to_output_format(
model_name: list = [],
model_selection_strategy: ModelSelectionStrategy = ModelSelectionStrategy.ROUND_ROBIN,
) -> Dict:
- if output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS:
+ if (
+ output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS
+ or output_format == OutputFormat.OPENAI_VISION
+ ):
output_json = cls._convert_generic_json_to_openai_chat_completions_format(
generic_dataset,
add_model_name,
@@ -1424,3 +1566,20 @@ def _create_synthetic_prompt(
return SyntheticPromptGenerator.create_synthetic_prompt(
tokenizer, prompt_tokens_mean, prompt_tokens_stddev
)
+
+ @classmethod
+ def _create_synthetic_image(
+ cls,
+ image_width_mean: int,
+ image_width_stddev: int,
+ image_height_mean: int,
+ image_height_stddev: int,
+ image_format: ImageFormat,
+ ) -> str:
+ return SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=image_width_mean,
+ image_width_stddev=image_width_stddev,
+ image_height_mean=image_height_mean,
+ image_height_stddev=image_height_stddev,
+ image_format=image_format,
+ )
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png
new file mode 100644
index 000000000..cdba23dd3
Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png differ
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg
new file mode 100644
index 000000000..aee985fdc
Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg differ
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg
new file mode 100644
index 000000000..eb0633b27
Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg differ
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg
new file mode 100644
index 000000000..c9c831680
Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg differ
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py
new file mode 100644
index 000000000..a2df14d87
--- /dev/null
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py
@@ -0,0 +1,82 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import glob
+import random
+from enum import Enum, auto
+from pathlib import Path
+from typing import Optional
+
+from genai_perf import utils
+from PIL import Image
+
+
+class ImageFormat(Enum):
+ PNG = auto()
+ JPEG = auto()
+
+
+class SyntheticImageGenerator:
+ """A simple synthetic image generator that generates multiple synthetic
+ images from the source images.
+ """
+
+ @classmethod
+ def create_synthetic_image(
+ cls,
+ image_width_mean: int,
+ image_width_stddev: int,
+ image_height_mean: int,
+ image_height_stddev: int,
+ image_format: Optional[ImageFormat] = None,
+ ) -> str:
+ """Generate base64 encoded synthetic image using the source images."""
+ if image_format is None:
+ image_format = random.choice(list(ImageFormat))
+ width = cls._sample_random_positive_integer(
+ image_width_mean, image_width_stddev
+ )
+ height = cls._sample_random_positive_integer(
+ image_height_mean, image_height_stddev
+ )
+
+ image = cls._sample_source_image()
+ image = image.resize(size=(width, height))
+
+ img_base64 = utils.encode_image(image, image_format.name)
+ return f"data:image/{image_format.name.lower()};base64,{img_base64}"
+
+ @classmethod
+ def _sample_source_image(cls):
+ """Sample one image among the source images."""
+ filepath = Path(__file__).parent.resolve() / "source_images" / "*"
+ filenames = glob.glob(str(filepath))
+ return Image.open(random.choice(filenames))
+
+ @classmethod
+ def _sample_random_positive_integer(cls, mean: int, stddev: int) -> int:
+ n = int(abs(random.gauss(mean, stddev)))
+ return n if n != 0 else 1 # avoid zero
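+
+
+# Example usage (hypothetical values, illustration only):
+#
+#   data_uri = SyntheticImageGenerator.create_synthetic_image(
+#       image_width_mean=512,
+#       image_width_stddev=30,
+#       image_height_mean=512,
+#       image_height_stddev=30,
+#       image_format=ImageFormat.PNG,
+#   )
+#   # data_uri -> "data:image/png;base64,..."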
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py
index 912ee4725..9ff7b5b9a 100755
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py
@@ -76,6 +76,11 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None:
output_tokens_mean=args.output_tokens_mean,
output_tokens_stddev=args.output_tokens_stddev,
output_tokens_deterministic=args.output_tokens_mean_deterministic,
+ image_width_mean=args.image_width_mean,
+ image_width_stddev=args.image_width_stddev,
+ image_height_mean=args.image_height_mean,
+ image_height_stddev=args.image_height_stddev,
+ image_format=args.image_format,
random_seed=args.random_seed,
num_of_output_prompts=args.num_prompts,
add_model_name=add_model_name,
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
index 901cf6ca2..776535d15 100644
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
@@ -46,6 +46,7 @@
OutputFormat,
PromptSource,
)
+from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat
from genai_perf.plots.plot_config_parser import PlotConfigParser
from genai_perf.plots.plot_manager import PlotManager
from genai_perf.tokenizer import DEFAULT_TOKENIZER
@@ -76,6 +77,7 @@ def to_lowercase(self):
"completions": "v1/completions",
"embeddings": "v1/embeddings",
"rankings": "v1/ranking",
+ "vision": "v1/chat/completions",
}
@@ -115,6 +117,25 @@ def _check_compare_args(
return args
+def _check_image_input_args(
+ parser: argparse.ArgumentParser, args: argparse.Namespace
+) -> argparse.Namespace:
+ """
+ Sanity check the image input args
+ """
+ if args.image_width_mean <= 0 or args.image_height_mean <= 0:
+ parser.error(
+ "Both --image-width-mean and --image-height-mean values must be positive."
+ )
+ if args.image_width_stddev < 0 or args.image_height_stddev < 0:
+ parser.error(
+ "Both --image-width-stddev and --image-height-stddev values must be non-negative."
+ )
+
+ args = _convert_str_to_enum_entry(args, "image_format", ImageFormat)
+ return args
+
+
def _check_conditional_args(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
@@ -138,6 +159,11 @@ def _check_conditional_args(
elif args.endpoint_type == "rankings":
args.output_format = OutputFormat.RANKINGS
+ # (TMA-1986) deduce vision format from chat completions + image CLI
+ # because there's no openai vision endpoint.
+ elif args.endpoint_type == "vision":
+ args.output_format = OutputFormat.OPENAI_VISION
+
if args.endpoint is not None:
args.endpoint = args.endpoint.lstrip(" /")
else:
@@ -411,6 +437,51 @@ def _add_input_args(parser):
)
+def _add_image_input_args(parser):
+ input_group = parser.add_argument_group("Image Input")
+
+ input_group.add_argument(
+ "--image-width-mean",
+ type=int,
+ default=LlmInputs.DEFAULT_IMAGE_WIDTH_MEAN,
+ required=False,
+ help=f"The mean width of images when generating synthetic image data.",
+ )
+
+ input_group.add_argument(
+ "--image-width-stddev",
+ type=int,
+ default=LlmInputs.DEFAULT_IMAGE_WIDTH_STDDEV,
+ required=False,
+ help=f"The standard deviation of width of images when generating synthetic image data.",
+ )
+
+ input_group.add_argument(
+ "--image-height-mean",
+ type=int,
+ default=LlmInputs.DEFAULT_IMAGE_HEIGHT_MEAN,
+ required=False,
+ help=f"The mean height of images when generating synthetic image data.",
+ )
+
+ input_group.add_argument(
+ "--image-height-stddev",
+ type=int,
+ default=LlmInputs.DEFAULT_IMAGE_HEIGHT_STDDEV,
+ required=False,
+ help=f"The standard deviation of height of images when generating synthetic image data.",
+ )
+
+ input_group.add_argument(
+ "--image-format",
+ type=str,
+ choices=utils.get_enum_names(ImageFormat),
+ required=False,
+ help=f"The compression format of the images. "
+ "If format is not selected, format of generated image is selected at random",
+ )
+
+
def _add_profile_args(parser):
profile_group = parser.add_argument_group("Profiling")
load_management_group = profile_group.add_mutually_exclusive_group(required=False)
@@ -499,7 +570,7 @@ def _add_endpoint_args(parser):
endpoint_group.add_argument(
"--endpoint-type",
type=str,
- choices=["chat", "completions", "embeddings", "rankings"],
+ choices=["chat", "completions", "embeddings", "rankings", "vision"],
required=False,
help=f"The endpoint-type to send requests to on the "
'server. This is only used with the "openai" service-kind.',
@@ -658,6 +729,7 @@ def _parse_profile_args(subparsers) -> argparse.ArgumentParser:
)
_add_endpoint_args(profile)
_add_input_args(profile)
+ _add_image_input_args(profile)
_add_profile_args(profile)
_add_output_args(profile)
_add_other_args(profile)
@@ -737,6 +809,7 @@ def refine_args(
args = _infer_prompt_source(args)
args = _check_model_args(parser, args)
args = _check_conditional_args(parser, args)
+ args = _check_image_input_args(parser, args)
args = _check_load_manager_args(args)
args = _set_artifact_paths(args)
elif args.subcommand == Subcommand.COMPARE.to_lowercase():
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
index 4ec1bec62..183f21fd2 100755
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
@@ -218,6 +218,9 @@ def _get_openai_input_text(self, req_inputs: dict) -> str:
return payload["messages"][0]["content"]
elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS:
return payload["prompt"]
+ elif self._response_format == ResponseFormat.OPENAI_VISION:
+ content = payload["messages"][0]["content"]
+ return " ".join(c["text"] for c in content if c["type"] == "text")
else:
raise ValueError(
"Failed to parse OpenAI request input in profile export file."
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
index d18d8f6fb..74eb48a23 100755
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
@@ -39,6 +39,7 @@ class ResponseFormat(Enum):
OPENAI_CHAT_COMPLETIONS = auto()
OPENAI_COMPLETIONS = auto()
OPENAI_EMBEDDINGS = auto()
+ OPENAI_VISION = auto()
RANKINGS = auto()
TRITON = auto()
@@ -59,7 +60,15 @@ def _get_profile_metadata(self, data: dict) -> None:
if data["endpoint"] == "rerank":
self._response_format = ResponseFormat.HUGGINGFACE_RANKINGS
elif data["endpoint"] == "v1/chat/completions":
- self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
+ # (TPA-66) add PA metadata to deduce the response format instead
+ # of parsing the request input payload in profile export json
+ # file.
+ request = data["experiments"][0]["requests"][0]
+ request_input = request["request_inputs"]["payload"]
+ if "image_url" in request_input:
+ self._response_format = ResponseFormat.OPENAI_VISION
+ else:
+ self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
elif data["endpoint"] == "v1/completions":
self._response_format = ResponseFormat.OPENAI_COMPLETIONS
elif data["endpoint"] == "v1/embeddings":
@@ -67,13 +76,17 @@ def _get_profile_metadata(self, data: dict) -> None:
elif data["endpoint"] == "v1/ranking":
self._response_format = ResponseFormat.RANKINGS
else:
- # TPA-66: add PA metadata to handle this case
+ # (TPA-66) add PA metadata to handle this case
# When endpoint field is either empty or custom endpoint, fall
# back to parsing the response to extract the response format.
request = data["experiments"][0]["requests"][0]
+ request_input = request["request_inputs"]["payload"]
response = request["response_outputs"][0]["response"]
if "chat.completion" in response:
- self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
+ if "image_url" in request_input:
+ self._response_format = ResponseFormat.OPENAI_VISION
+ else:
+ self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS
elif "text_completion" in response:
self._response_format = ResponseFormat.OPENAI_COMPLETIONS
elif "embedding" in response:
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py b/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py
deleted file mode 100644
index a44304348..000000000
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import itertools
-import os
-import subprocess
-import sys
-
-# How to run:
-# test_end_to_end.py
-# Where target is "nim_chat" or "nim_completions" or "vllm_openai" or "triton_tensorrtllm"
-#
-# For all cases but vllm_openai, it assumes that the server will be on port 9999
-#
-# This script will run a sweep of all combinations of values in the testing matrix
-# by appending those options on to the genai-perf base command
-#
-
-
-testing_matrix = [
- ["--concurrency 1", "--concurrency 32", "--request-rate 1", "--request-rate 32"],
- ["--streaming", ""],
-]
-
-base_commands = {
- "nim_chat": "genai-perf profile -s 999 -p 20000 -m llama-2-7b-chat -u http://localhost:9999 --service-kind openai --endpoint-type chat",
- "nim_completions": "genai-perf profile -s 999 -p 20000 -m llama-2-7b -u http://localhost:9999 --service-kind openai --endpoint-type completions",
- "vllm_openai": "genai-perf profile -s 999 -p 20000 -m mistralai/Mistral-7B-v0.1 --service-kind openai --endpoint-type chat",
- "triton_tensorrtllm": "genai-perf profile -s 999 -p 20000 -m llama-2-7b -u 0.0.0.0:9999 --service-kind triton --backend tensorrtllm",
- "triton_vllm": "genai-perf profile -s 999 -p 20000 -m gpt2_vllm --service-kind triton --backend vllm",
-}
-testname = ""
-
-if len(sys.argv) == 2:
- # The second element in sys.argv is the input string
- testname = sys.argv[1]
-else:
- options = " ".join(base_commands.keys())
- print(f"This script requires exactly one argument. It must be one of {options}")
- exit(1)
-
-base_command = base_commands[testname]
-
-
-def rename_files(files: list, substr: str) -> None:
- for f in files:
- name, ext = f.rsplit(".", 1)
- # Insert the substring and reassemble the filename
- new_filename = f"{testname}__{name}__{substr}.{ext}"
- try:
- os.rename(f, new_filename)
- except FileNotFoundError:
- # Just ignore the error, since if PA failed these files may not exist
- pass
-
-
-def print_summary():
- # FIXME -- print out a few basic metrics. Maybe from the csv?
- pass
-
-
-def sanity_check():
- # FIXME -- add in some sanity checking? Throughput isn't 0?
- pass
-
-
-# Loop through all combinations
-for combination in itertools.product(*testing_matrix):
- options_string = " ".join(combination)
- command_with_options = f"{base_command} {options_string}"
- command_array = command_with_options.split()
-
- file_options_string = "__".join(combination)
- file_options_string = file_options_string.replace(" ", "")
- file_options_string = file_options_string.replace("-", "")
- output_file = testname + "__" + file_options_string + ".log"
-
- with open(output_file, "w") as outfile:
- print(f"\nCMD: {command_with_options}")
- print(f" Output log is {output_file}")
- proc = subprocess.run(command_array, stdout=outfile, stderr=subprocess.STDOUT)
-
- if proc.returncode != 0:
- print(f" Command failed with return code: {proc.returncode}")
- else:
- print(f" Command executed successfully!")
- print_summary()
- sanity_check()
-
- files = [
- "profile_export.json",
- "profile_export_genai_pa.csv",
- "llm_inputs.json",
- ]
- rename_files(files, file_options_string)
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py
index 6f66230c4..4b625352a 100644
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py
@@ -34,10 +34,27 @@
# Skip type checking to avoid mypy error
# Issue: https://github.com/python/mypy/issues/10632
import yaml # type: ignore
+from PIL import Image
logger = logging.getLogger(__name__)
+def encode_image(img: Image, format: str):
+ """Encodes an image into base64 encoding."""
+ # Lazy import for vision related endpoints
+ import base64
+ from io import BytesIO
+
+ # JPEG does not support P or RGBA mode (commonly used for PNG) so it needs
+ # to be converted to RGB before an image can be saved as JPEG format.
+ if format == "JPEG" and img.mode != "RGB":
+ img = img.convert("RGB")
+
+ buffered = BytesIO()
+ img.save(buffered, format=format)
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
def remove_sse_prefix(msg: str) -> str:
prefix = "data: "
if msg.startswith(prefix):
diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py
index dbaacc32b..76ef3e321 100644
--- a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py
+++ b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py
@@ -93,6 +93,11 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
"synthetic_input_tokens_stddev",
"subcommand",
"tokenizer",
+ "image_width_mean",
+ "image_width_stddev",
+ "image_height_mean",
+ "image_height_stddev",
+ "image_format",
]
utils.remove_file(args.profile_export_file)
diff --git a/src/c++/perf_analyzer/genai-perf/pyproject.toml b/src/c++/perf_analyzer/genai-perf/pyproject.toml
index 982ee24b7..f1f78a7e2 100644
--- a/src/c++/perf_analyzer/genai-perf/pyproject.toml
+++ b/src/c++/perf_analyzer/genai-perf/pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
"pytest-mock",
"pyyaml",
"responses",
+ "pillow",
]
# CLI Entrypoint
@@ -66,8 +67,8 @@ dependencies = [
genai-perf = "genai_perf.main:main"
[project.urls]
-"Homepage" = "https://github.com/triton-inference-server/"
-"Bug Tracker" = "https://github.com/triton-inference-server/server/issues"
+"Homepage" = "https://github.com/triton-inference-server/client"
+"Bug Tracker" = "https://github.com/triton-inference-server/client/issues"
# Build
[build-system]
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py
index eb891fd02..2ef5d52ba 100644
--- a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py
@@ -31,16 +31,18 @@
import pytest
from genai_perf import __version__, parser
from genai_perf.llm_inputs.llm_inputs import (
ModelSelectionStrategy,
OutputFormat,
PromptSource,
)
+from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat
from genai_perf.parser import PathType
class TestCLIArguments:
# ================================================
- # GENAI-PERF COMMAND
+ # PROFILE COMMAND
# ================================================
expected_help_output = (
"CLI to profile LLMs and Generative AI models with Perf Analyzer"
@@ -215,6 +217,23 @@ def test_help_version_arguments_output_and_exit(
["--synthetic-input-tokens-stddev", "7"],
{"synthetic_input_tokens_stddev": 7},
),
+ (
+ ["--image-width-mean", "123"],
+ {"image_width_mean": 123},
+ ),
+ (
+ ["--image-width-stddev", "123"],
+ {"image_width_stddev": 123},
+ ),
+ (
+ ["--image-height-mean", "456"],
+ {"image_height_mean": 456},
+ ),
+ (
+ ["--image-height-stddev", "456"],
+ {"image_height_stddev": 456},
+ ),
+ (["--image-format", "png"], {"image_format": ImageFormat.PNG}),
(["-v"], {"verbose": True}),
(["--verbose"], {"verbose": True}),
(["-u", "test_url"], {"u": "test_url"}),
@@ -732,6 +751,26 @@ def test_prompt_source_assertions(self, monkeypatch, mocker, capsys):
captured = capsys.readouterr()
assert expected_output in captured.err
+ @pytest.mark.parametrize(
+ "args",
+ [
+ # negative numbers
+ ["--image-width-mean", "-123"],
+ ["--image-width-stddev", "-34"],
+ ["--image-height-mean", "-123"],
+ ["--image-height-stddev", "-34"],
+ # zeros
+ ["--image-width-mean", "0"],
+ ["--image-height-mean", "0"],
+ ],
+ )
+ def test_positive_image_input_args(self, monkeypatch, args):
+ combined_args = ["genai-perf", "profile", "-m", "test_model"] + args
+ monkeypatch.setattr("sys.argv", combined_args)
+
+ with pytest.raises(SystemExit) as excinfo:
+ parser.parse_args()
+
# ================================================
# COMPARE SUBCOMMAND
# ================================================
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py
index e4a29267d..f82e59312 100644
--- a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py
@@ -249,6 +249,11 @@ def test_generate_json(self, monkeypatch) -> None:
"random_seed": 0,
"synthetic_input_tokens_mean": 550,
"synthetic_input_tokens_stddev": 0,
+ "image_width_mean": 100,
+ "image_width_stddev": 0,
+ "image_height_mean": 100,
+ "image_height_stddev": 0,
+ "image_format": null,
"concurrency": 1,
"measurement_interval": 10000,
"request_rate": null,
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
index c6351918e..028e72849 100644
--- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
@@ -16,6 +16,7 @@
import os
import random
import statistics
+from collections import namedtuple
from pathlib import Path
from unittest.mock import mock_open, patch
@@ -30,7 +31,9 @@
OutputFormat,
PromptSource,
)
-from genai_perf.tokenizer import Tokenizer
+from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat
+from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer
+from PIL import Image
mocked_openorca_data = {
"features": [
@@ -78,6 +81,7 @@ class TestLlmInputs:
("triton", "tensorrtllm", OutputFormat.TENSORRTLLM),
("openai", "v1/completions", OutputFormat.OPENAI_COMPLETIONS),
("openai", "v1/chat/completions", OutputFormat.OPENAI_CHAT_COMPLETIONS),
+ ("openai", "v1/chat/completions", OutputFormat.OPENAI_VISION),
]
@pytest.fixture
@@ -550,6 +554,94 @@ def test_llm_inputs_with_defaults(self, default_configured_url):
# else:
# assert False, f"Unsupported output format: {output_format}"
+ def test_add_image_inputs_openai_vision(self) -> None:
+ generic_json = {
+ "rows": [
+ {"text_input": "test input one", "image": "test_image1"},
+ {"text_input": "test input two", "image": "test_image2"},
+ ]
+ }
+
+ generic_json = LlmInputs._convert_to_openai_multi_modal_content(generic_json)
+
+ row1 = generic_json["rows"][0]["text_input"]
+ assert row1 == [
+ {
+ "type": "text",
+ "text": "test input one",
+ },
+ {
+ "type": "image_url",
+ "image_url": {"url": "test_image1"},
+ },
+ ]
+
+ row2 = generic_json["rows"][1]["text_input"]
+ assert row2 == [
+ {
+ "type": "text",
+ "text": "test input two",
+ },
+ {
+ "type": "image_url",
+ "image_url": {"url": "test_image2"},
+ },
+ ]
+
+ @patch(
+ "genai_perf.llm_inputs.llm_inputs.LlmInputs._create_synthetic_prompt",
+ return_value="This is test prompt",
+ )
+ @patch(
+ "genai_perf.llm_inputs.llm_inputs.LlmInputs._create_synthetic_image",
+ return_value="test_image_base64",
+ )
+ @pytest.mark.parametrize(
+ "output_format",
+ [
+ OutputFormat.OPENAI_CHAT_COMPLETIONS,
+ OutputFormat.OPENAI_COMPLETIONS,
+ OutputFormat.OPENAI_EMBEDDINGS,
+ OutputFormat.RANKINGS,
+ OutputFormat.OPENAI_VISION,
+ OutputFormat.VLLM,
+ OutputFormat.TENSORRTLLM,
+ ],
+ )
+ def test_get_input_dataset_from_synthetic(
+        self, mock_image, mock_prompt, output_format
+ ) -> None:
+ _placeholder = 123 # dummy value
+ num_prompts = 3
+
+ dataset_json = LlmInputs._get_input_dataset_from_synthetic(
+ tokenizer=get_tokenizer(DEFAULT_TOKENIZER),
+ prompt_tokens_mean=_placeholder,
+ prompt_tokens_stddev=_placeholder,
+ num_of_output_prompts=num_prompts,
+ image_width_mean=_placeholder,
+ image_width_stddev=_placeholder,
+ image_height_mean=_placeholder,
+ image_height_stddev=_placeholder,
+ image_format=ImageFormat.PNG,
+ output_format=output_format,
+ )
+
+ assert len(dataset_json["rows"]) == num_prompts
+
+ for i in range(num_prompts):
+ row = dataset_json["rows"][i]["row"]
+
+ if output_format == OutputFormat.OPENAI_VISION:
+ assert row == {
+ "text_input": "This is test prompt",
+ "image": "test_image_base64",
+ }
+ else:
+ assert row == {
+ "text_input": "This is test prompt",
+ }
+
# def test_trtllm_default_max_tokens(self, default_tokenizer: Tokenizer) -> None:
# input_name = "max_tokens"
# input_value = 256
@@ -687,6 +779,34 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists):
for i, prompt in enumerate(expected_prompts):
assert dataset["rows"][i]["row"]["text_input"] == prompt
+ @patch("pathlib.Path.exists", return_value=True)
+ @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10)))
+ @patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data=(
+ '{"text_input": "prompt1", "image": "image1.png"}\n'
+ '{"text_input": "prompt2", "image": "image2.png"}\n'
+ '{"text_input": "prompt3", "image": "image3.png"}\n'
+ ),
+ )
+ def test_get_input_file_with_multi_modal_data(
+        self, mock_file, mock_image, mock_exists
+ ):
+ Data = namedtuple("Data", ["text_input", "image"])
+ expected_data = [
+ Data(text_input="prompt1", image="image1.png"),
+ Data(text_input="prompt2", image="image2.png"),
+ Data(text_input="prompt3", image="image3.png"),
+ ]
+ dataset = LlmInputs._get_input_dataset_from_file(Path("somefile.txt"))
+
+ assert dataset is not None
+ assert len(dataset["rows"]) == len(expected_data)
+ for i, data in enumerate(expected_data):
+ assert dataset["rows"][i]["row"]["text_input"] == data.text_input
+ assert dataset["rows"][i]["row"]["image"] == data.image
+
@pytest.mark.parametrize(
"seed, model_name_list, index,model_selection_strategy,expected_model",
[
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py
index 05de5b122..689e366cd 100644
--- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py
@@ -69,6 +69,7 @@ def test_llm_metric_system_metrics(self) -> None:
output_sequence_lengths=[3, 4],
input_sequence_lengths=[12, 34],
)
+
sys_metrics = m.system_metrics
assert len(sys_metrics) == 2
assert sys_metrics[0].name == "output_token_throughput"
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py
index 75976189d..d776a6a85 100644
--- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py
@@ -71,6 +71,9 @@ def write(self: Any, content: str) -> int:
elif filename == "openai_profile_export.json":
tmp_file = StringIO(json.dumps(self.openai_profile_data))
return tmp_file
+ elif filename == "openai_vlm_profile_export.json":
+ tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data))
+ return tmp_file
elif filename == "empty_profile_export.json":
tmp_file = StringIO(json.dumps(self.empty_profile_data))
return tmp_file
@@ -322,6 +325,91 @@ def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N
with pytest.raises(KeyError):
pd.get_statistics(infer_mode="concurrency", load_level="40")
+ def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None:
+ """Collect LLM metrics from profile export data and check values.
+
+ Metrics
+ * time to first tokens
+ - experiment 1: [5 - 1, 7 - 2] = [4, 5]
+ * inter token latencies
+ - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)]
+ : [3.5, 1.6]
+ : [4, 2] # rounded
+ * output token throughputs per request
+ - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13]
+ * output token throughputs
+ - experiment 1: [(3 + 6)/(15 - 1)] = [9/14]
+ * output sequence lengths
+ - experiment 1: [3, 6]
+ * input sequence lengths
+ - experiment 1: [3, 4]
+ """
+ tokenizer = get_tokenizer(DEFAULT_TOKENIZER)
+ pd = LLMProfileDataParser(
+ filename=Path("openai_vlm_profile_export.json"),
+ tokenizer=tokenizer,
+ )
+
+ # experiment 1 statistics
+ stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10")
+ metrics = stat_obj.metrics
+ stat = stat_obj.stats_dict
+ assert isinstance(metrics, LLMMetrics)
+
+ assert metrics.time_to_first_tokens == [4, 5]
+ assert metrics.inter_token_latencies == [4, 2]
+ ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)]
+ assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr)
+ ott = [9 / ns_to_sec(14)]
+ assert metrics.output_token_throughputs == pytest.approx(ott)
+ assert metrics.output_sequence_lengths == [3, 6]
+ assert metrics.input_sequence_lengths == [3, 4]
+
+ assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore
+ assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore
+ assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore
+ np.mean(ottpr)
+ )
+ assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore
+ assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore
+
+ assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore
+ assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore
+ assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore
+ np.percentile(ottpr, 50)
+ )
+ assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore
+ assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore
+
+ assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore
+ assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore
+ min_ottpr = 3 / ns_to_sec(11)
+ assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore
+ assert stat["output_sequence_length"]["min"] == 3 # type: ignore
+ assert stat["input_sequence_length"]["min"] == 3 # type: ignore
+
+ assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore
+ assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore
+ max_ottpr = 6 / ns_to_sec(13)
+ assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore
+ assert stat["output_sequence_length"]["max"] == 6 # type: ignore
+ assert stat["input_sequence_length"]["max"] == 4 # type: ignore
+
+ assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore
+ assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore
+ assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore
+ np.std(ottpr)
+ )
+ assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore
+ assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore
+
+ oott = 9 / ns_to_sec(14)
+ assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore
+
+ # check non-existing profile data
+ with pytest.raises(KeyError):
+ pd.get_statistics(infer_mode="concurrency", load_level="40")
+
def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None:
"""Test merging the multiple sse response."""
res_timestamps = [0, 1, 2, 3]
@@ -522,6 +610,73 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None:
],
}
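+ # Mirrors openai_profile_data above, but each request payload carries
+ # multi-modal (text + image_url) content for the v1/chat/completions endpoint.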
+ openai_vlm_profile_data = {
+ "service_kind": "openai",
+ "endpoint": "v1/chat/completions",
+ "experiments": [
+ {
+ "experiment": {
+ "mode": "concurrency",
+ "value": 10,
+ },
+ "requests": [
+ {
+ "timestamp": 1,
+ "request_inputs": {
+ "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}',
+ },
+ # the first response and the last two responses are ignored because they have no "content"
+ "response_timestamps": [3, 5, 8, 12, 13, 14],
+ "response_outputs": [
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" like"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" dogs"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n'
+ },
+ {"response": "data: [DONE]\n\n"},
+ ],
+ },
+ {
+ "timestamp": 2,
+ "request_inputs": {
+ "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test too"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}',
+ },
+ # the first response and the last two responses are ignored because they have no "content"
+ "response_timestamps": [4, 7, 11, 15, 18, 19],
+ "response_outputs": [
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"don\'t"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"cook food"},"finish_reason":null}]}\n\n'
+ },
+ {
+ "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n'
+ },
+ {"response": "data: [DONE]\n\n"},
+ ],
+ },
+ ],
+ },
+ ],
+ }
+
triton_profile_data = {
"service_kind": "triton",
"endpoint": "",
diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py
new file mode 100644
index 000000000..5a79794bb
--- /dev/null
+++ b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py
@@ -0,0 +1,123 @@
+import base64
+import random
+from io import BytesIO
+
+import pytest
+from genai_perf.llm_inputs.synthetic_image_generator import (
+ ImageFormat,
+ SyntheticImageGenerator,
+)
+from PIL import Image
+
+
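+# Helper for the tests below: strips the "data:image/<format>;base64," prefix
+# the generator is expected to emit and decodes the payload into a PIL image.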
+def decode_image(base64_string):
+ _, data = base64_string.split(",")
+ decoded_data = base64.b64decode(data)
+ return Image.open(BytesIO(decoded_data))
+
+
+@pytest.mark.parametrize(
+ "expected_image_size",
+ [
+ (100, 100),
+ (200, 200),
+ ],
+)
+def test_different_image_size(expected_image_size):
+ expected_width, expected_height = expected_image_size
+ base64_string = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=expected_width,
+ image_width_stddev=0,
+ image_height_mean=expected_height,
+ image_height_stddev=0,
+ image_format=ImageFormat.PNG,
+ )
+
+ image = decode_image(base64_string)
+ assert image.size == expected_image_size, "image not resized to the target size"
+
+
+def test_negative_size_is_not_selected():
+ # PIL.Image.resize raises an exception when given negative dimensions, so this
+ # call must never sample a negative width or height, even with a negative mean
+ _ = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=-1,
+ image_width_stddev=10,
+ image_height_mean=-1,
+ image_height_stddev=10,
+ image_format=ImageFormat.PNG,
+ )
+
+
+@pytest.mark.parametrize(
+ "width_mean, width_stddev, height_mean, height_stddev",
+ [
+ (100, 15, 100, 15),
+ (123, 10, 456, 7),
+ ],
+)
+def test_generator_deterministic(width_mean, width_stddev, height_mean, height_stddev):
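+ # Assumes the generator draws its randomness from Python's global `random`
+ # state, so reseeding with the same value should reproduce the same image.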
+ random.seed(123)
+ img1 = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=width_mean,
+ image_width_stddev=width_stddev,
+ image_height_mean=height_mean,
+ image_height_stddev=height_stddev,
+ image_format=ImageFormat.PNG,
+ )
+
+ random.seed(123)
+ img2 = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=width_mean,
+ image_width_stddev=width_stddev,
+ image_height_mean=height_mean,
+ image_height_stddev=height_stddev,
+ image_format=ImageFormat.PNG,
+ )
+
+ assert img1 == img2, "generator is nondeterministic"
+
+
+@pytest.mark.parametrize("image_format", [ImageFormat.PNG, ImageFormat.JPEG])
+def test_base64_encoding_with_different_formats(image_format):
+ img_base64 = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=100,
+ image_width_stddev=100,
+ image_height_mean=100,
+ image_height_stddev=100,
+ image_format=image_format,
+ )
+
+ # check prefix
+ expected_prefix = f"data:image/{image_format.name.lower()};base64,"
+ assert img_base64.startswith(expected_prefix), "unexpected prefix"
+
+ # check image format
+ data = img_base64[len(expected_prefix) :]
+ img_data = base64.b64decode(data)
+ img_bytes = BytesIO(img_data)
+ image = Image.open(img_bytes)
+ assert image.format == image_format.name
+
+
+def test_random_image_format():
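+ # With image_format=None the generator is expected to pick the format at
+ # random; these particular seeds are assumed to yield PNG and then JPEG.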
+ random.seed(123)
+ img1 = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=100,
+ image_width_stddev=100,
+ image_height_mean=100,
+ image_height_stddev=100,
+ image_format=None,
+ )
+
+ random.seed(456)
+ img2 = SyntheticImageGenerator.create_synthetic_image(
+ image_width_mean=100,
+ image_width_stddev=100,
+ image_height_mean=100,
+ image_height_stddev=100,
+ image_format=None,
+ )
+
+ # check prefix
+ assert img1.startswith("data:image/png")
+ assert img2.startswith("data:image/jpeg")
diff --git a/src/c++/perf_analyzer/perf_utils.h b/src/c++/perf_analyzer/perf_utils.h
index 7166936a9..6975d694b 100644
--- a/src/c++/perf_analyzer/perf_utils.h
+++ b/src/c++/perf_analyzer/perf_utils.h
@@ -56,7 +56,7 @@ constexpr uint64_t NANOS_PER_MILLIS = 1000000;
// Will use the characters specified here to construct random strings
std::string const character_set =
- "abcdefghijklmnaoqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 .?!";
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 .?!";
// A boolean flag to mark an interrupt and commencement of early exit
extern volatile bool early_exit;
diff --git a/src/c++/perf_analyzer/test_command_line_parser.cc b/src/c++/perf_analyzer/test_command_line_parser.cc
index 765def112..2d17bbc24 100644
--- a/src/c++/perf_analyzer/test_command_line_parser.cc
+++ b/src/c++/perf_analyzer/test_command_line_parser.cc
@@ -371,10 +371,12 @@ class TestCLParser : public CLParser {
void
CheckValidRange(
std::vector<char*>& args, char* option_name, TestCLParser& parser,
- PAParamsPtr& act, bool& using_range, Range<uint64_t>& range)
+ PAParamsPtr& act, bool& using_range, Range<uint64_t>& range,
+ size_t* max_threads)
{
SUBCASE("start:end provided")
{
+ *max_threads = 400;
args.push_back(option_name);
args.push_back("100:400"); // start:end
@@ -392,6 +394,7 @@ CheckValidRange(
SUBCASE("start:end:step provided")
{
+ *max_threads = 400;
args.push_back(option_name);
args.push_back("100:400:10"); // start:end:step
@@ -525,7 +528,7 @@ TEST_CASE("Testing Command Line Parser")
// Most common defaults
exp->model_name = model_name; // model_name;
- exp->max_threads = 16;
+ exp->max_threads = DEFAULT_MAX_THREADS;
SUBCASE("with no parameters")
{
@@ -1111,11 +1114,16 @@ TEST_CASE("Testing Command Line Parser")
SUBCASE("Option : --concurrency-range")
{
char* option_name = "--concurrency-range";
+ uint64_t concurrency_range_start;
+ uint64_t concurrency_range_end;
SUBCASE("start provided")
{
+ concurrency_range_start = 100;
+ std::string concurrency_range_str =
+ std::to_string(concurrency_range_start);
args.push_back(option_name);
- args.push_back("100"); // start
+ args.push_back(concurrency_range_str.data()); // start
int argc = args.size();
char* argv[argc];
@@ -1125,13 +1133,13 @@ TEST_CASE("Testing Command Line Parser")
CHECK(!parser.UsageCalled());
exp->using_concurrency_range = true;
- exp->concurrency_range.start = 100;
+ exp->concurrency_range.start = concurrency_range_start;
+ exp->max_threads = DEFAULT_MAX_THREADS;
}
CheckValidRange(
args, option_name, parser, act, exp->using_concurrency_range,
- exp->concurrency_range);
-
+ exp->concurrency_range, &(exp->max_threads));
CheckInvalidRange(args, option_name, parser, act, check_params);
SUBCASE("wrong separator")
@@ -1173,6 +1181,75 @@ TEST_CASE("Testing Command Line Parser")
check_params = false;
}
+
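+ // The three subcases below pin down the expected interaction between
+ // --concurrency-range and max_threads: max_threads stays at
+ // DEFAULT_MAX_THREADS while the range end is at or below 16, and matches
+ // concurrency_range.end once the end exceeds 16.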
+ concurrency_range_start = 10;
+ SUBCASE("Max threads set to default when concurrency-range.end < 16")
+ {
+ concurrency_range_end = 10;
+ std::string concurrency_range_str =
+ std::to_string(concurrency_range_start) + ":" +
+ std::to_string(concurrency_range_end);
+ args.push_back(option_name);
+ args.push_back(concurrency_range_str.data());
+
+ int argc = args.size();
+ char* argv[argc];
+ std::copy(args.begin(), args.end(), argv);
+
+ REQUIRE_NOTHROW(act = parser.Parse(argc, argv));
+ CHECK(!parser.UsageCalled());
+
+ exp->using_concurrency_range = true;
+ exp->concurrency_range.start = concurrency_range_start;
+ exp->concurrency_range.end = concurrency_range_end;
+ exp->max_threads = DEFAULT_MAX_THREADS;
+ }
+
+ SUBCASE("Max_threads set to default when concurrency-range.end = 16")
+ {
+ concurrency_range_end = 16;
+ std::string concurrency_range_str =
+ std::to_string(concurrency_range_start) + ":" +
+ std::to_string(concurrency_range_end);
+ args.push_back(option_name);
+ args.push_back(concurrency_range_str.data());
+
+ int argc = args.size();
+ char* argv[argc];
+ std::copy(args.begin(), args.end(), argv);
+
+ REQUIRE_NOTHROW(act = parser.Parse(argc, argv));
+ CHECK(!parser.UsageCalled());
+
+ exp->using_concurrency_range = true;
+ exp->concurrency_range.start = concurrency_range_start;
+ exp->concurrency_range.end = concurrency_range_end;
+ exp->max_threads = DEFAULT_MAX_THREADS;
+ }
+
+ SUBCASE(
+ "Max_threads set to concurrency-range.end when concurrency-range.end > "
+ "16")
+ {
+ concurrency_range_end = 40;
+ std::string concurrency_range_str =
+ std::to_string(concurrency_range_start) + ":" +
+ std::to_string(concurrency_range_end);
+ args.push_back(option_name);
+ args.push_back(concurrency_range_str.data());
+
+ int argc = args.size();
+ char* argv[argc];
+ std::copy(args.begin(), args.end(), argv);
+
+ REQUIRE_NOTHROW(act = parser.Parse(argc, argv));
+ CHECK(!parser.UsageCalled());
+
+ exp->using_concurrency_range = true;
+ exp->concurrency_range.start = concurrency_range_start;
+ exp->concurrency_range.end = concurrency_range_end;
+ exp->max_threads = exp->concurrency_range.end;
+ }
}
SUBCASE("Option : --periodic-concurrency-range")
@@ -1210,7 +1287,7 @@ TEST_CASE("Testing Command Line Parser")
CheckValidRange(
args, option_name, parser, act, exp->is_using_periodic_concurrency_mode,
- exp->periodic_concurrency_range);
+ exp->periodic_concurrency_range, &(exp->max_threads));
CheckInvalidRange(args, option_name, parser, act, check_params);
diff --git a/src/c++/perf_analyzer/test_dataloader.cc b/src/c++/perf_analyzer/test_dataloader.cc
index 656571cb9..c8db7df66 100644
--- a/src/c++/perf_analyzer/test_dataloader.cc
+++ b/src/c++/perf_analyzer/test_dataloader.cc
@@ -28,6 +28,7 @@
#include "doctest.h"
#include "mock_data_loader.h"
+
namespace triton { namespace perfanalyzer {
/// Helper class for testing the DataLoader
@@ -104,6 +105,199 @@ TEST_CASE("dataloader: GetTotalSteps")
CHECK_EQ(dataloader.GetTotalSteps(2), 0);
}
+TEST_CASE("dataloader: ValidateIOExistsInModel")
+{
+ MockDataLoader dataloader;
+ std::shared_ptr<ModelTensorMap> inputs = std::make_shared<ModelTensorMap>();
+ std::shared_ptr<ModelTensorMap> outputs = std::make_shared<ModelTensorMap>();
+ ModelTensor input1 = TestDataLoader::CreateTensor("INPUT1");
+ ModelTensor output1 = TestDataLoader::CreateTensor("OUTPUT1");
+ inputs->insert(std::make_pair(input1.name_, input1));
+ outputs->insert(std::make_pair(output1.name_, output1));
+
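+ // Each subcase checks that ValidateIOExistsInModel only accepts an existing
+ // directory whose files correspond to declared input/output tensor names.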
+ SUBCASE("Directory does not exist")
+ {
+ std::string data_directory = "non_existent_directory";
+ cb::Error status =
+ dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory);
+ CHECK(
+ status.Message() ==
+ "Error: Directory does not exist or is not a directory: "
+ "non_existent_directory");
+ CHECK(status.Err() == pa::GENERIC_ERROR);
+ }
+
+ SUBCASE("Directory is not a directory")
+ {
+ std::string data_directory = "tmp/test.txt";
+ std::ofstream file(data_directory);
+ cb::Error status =
+ dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory);
+ CHECK(
+ status.Message() ==
+ "Error: Directory does not exist or is not a directory: tmp/test.txt");
+ CHECK(status.Err() == pa::GENERIC_ERROR);
+ std::remove(data_directory.c_str());
+ }
+
+ SUBCASE("Valid directory but no corresponding files")
+ {
+ std::string data_directory = "valid_directory";
+ std::filesystem::create_directory(data_directory);
+ std::ofstream(data_directory + "/invalid_file").close();
+ cb::Error status =
+ dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory);
+ std::filesystem::remove_all(data_directory);
+ CHECK(
+ status.Message() ==
+ "Provided data file 'invalid_file' does not correspond to a valid "
+ "model input or output.");
+ CHECK(status.Err() == pa::GENERIC_ERROR);
+ }
+
+ SUBCASE("Valid directory with corresponding files")
+ {
+ std::string data_directory = "valid_directory";
+ std::filesystem::create_directory(data_directory);
+ std::ofstream(data_directory + "/INPUT1").close();
+ std::ofstream(data_directory + "/OUTPUT1").close();
+ cb::Error status =
+ dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory);
+ std::filesystem::remove_all(data_directory);
+ CHECK(status.Message().empty());
+ CHECK(status.IsOk());
+ }
+
+ SUBCASE("Valid directory with multiple input and output tensors")
+ {
+ ModelTensor input2 = TestDataLoader::CreateTensor("INPUT2");
+ ModelTensor output2 = TestDataLoader::CreateTensor("OUTPUT2");
+
+ inputs->insert(std::make_pair(input2.name_, input2));
+ outputs->insert(std::make_pair(output2.name_, output2));
+
+ std::string data_directory = "valid_directory_multiple";
+ std::filesystem::create_directory(data_directory);
+ std::ofstream(data_directory + "/INPUT1").close();
+ std::ofstream(data_directory + "/INPUT2").close();
+ std::ofstream(data_directory + "/OUTPUT1").close();
+ std::ofstream(data_directory + "/OUTPUT2").close();
+
+ cb::Error status =
+ dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory);
+ std::filesystem::remove_all(data_directory);
+ CHECK(status.Message().empty());
+ CHECK(status.IsOk());
+ }
+}
+
+TEST_CASE("dataloader: ReadDataFromJSON")
+{
+ DataLoader dataloader;
+ std::shared_ptr<ModelTensorMap> inputs = std::make_shared<ModelTensorMap>();
+ std::shared_ptr<ModelTensorMap> outputs = std::make_shared<ModelTensorMap>();
+ ModelTensor input1 = TestDataLoader::CreateTensor("INPUT1");
+ ModelTensor output1 = TestDataLoader::CreateTensor("OUTPUT1");
+
+ inputs->insert(std::make_pair(input1.name_, input1));
+ outputs->insert(std::make_pair(output1.name_, output1));
+
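+ // ReadDataFromJSON is exercised with a top-level "data" array (one entry of
+ // input tensors per step) and an optional "validation_data" array of
+ // expected outputs.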
+ SUBCASE("File does not exist")
+ {
+ std::string json_file = "non_existent_file.json";
+ cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file);
+ CHECK(status.Message() == "failed to open file for reading provided data");
+ CHECK(status.Err() == pa::GENERIC_ERROR);
+ }
+
+ SUBCASE("Valid JSON file")
+ {
+ std::string json_file = "valid_file.json";
+ std::ofstream out(json_file);
+ out << R"({
+ "data": [
+ { "INPUT1": [1] },
+ { "INPUT1": [2] },
+ { "INPUT1": [3] }
+ ],
+ "validation_data": [
+ { "OUTPUT1": [4] },
+ { "OUTPUT1": [5] },
+ { "OUTPUT1": [6] }
+ ]})";
+ out.close();
+
+ cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file);
+ std::filesystem::remove(json_file);
+ CHECK(status.Message().empty());
+ CHECK(status.IsOk());
+ }
+
+ SUBCASE("Invalid JSON file")
+ {
+ std::string json_file = "invalid_file.json";
+ std::ofstream out(json_file);
+ out << R"({invalid_json: 1,)";
+ out.close();
+
+ cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file);
+ std::filesystem::remove(json_file);
+
+ CHECK(
+ status.Message() ==
+ "failed to parse the specified json file for reading provided data");
+ CHECK(status.Err() == pa::GENERIC_ERROR);
+ }
+
+ SUBCASE("Multiple input and output tensors")
+ {
+ ModelTensor input2 = TestDataLoader::CreateTensor("INPUT2");
+ ModelTensor output2 = TestDataLoader::CreateTensor("OUTPUT2");
+
+ inputs->insert(std::make_pair(input2.name_, input2));
+ outputs->insert(std::make_pair(output2.name_, output2));
+
+ std::string json_file = "valid_file_multiple_input_output.json";
+ std::ofstream out(json_file);
+ out << R"({
+ "data": [
+ {
+ "INPUT1": [1],
+ "INPUT2": [4]
+ },
+ {
+ "INPUT1": [2],
+ "INPUT2": [5]
+ },
+ {
+ "INPUT1": [3],
+ "INPUT2": [6]
+ }
+ ],
+ "validation_data": [
+ {
+ "OUTPUT1": [4],
+ "OUTPUT2": [7]
+ },
+ {
+ "OUTPUT1": [5],
+ "OUTPUT2": [8]
+ },
+ {
+ "OUTPUT1": [6],
+ "OUTPUT2": [9]
+ }
+ ]
+ })";
+ out.close();
+
+ cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file);
+ std::filesystem::remove(json_file);
+ CHECK(status.Message().empty());
+ CHECK(status.IsOk());
+ }
+}
+
TEST_CASE("dataloader: GetInputData missing data")
{
MockDataLoader dataloader;
diff --git a/src/c++/perf_analyzer/test_perf_utils.cc b/src/c++/perf_analyzer/test_perf_utils.cc
index 34a08a108..74bf6afb4 100644
--- a/src/c++/perf_analyzer/test_perf_utils.cc
+++ b/src/c++/perf_analyzer/test_perf_utils.cc
@@ -144,6 +144,7 @@ TEST_CASE("perf_utils: ConvertDTypeFromTFS")
std::make_pair("DT_DOUBLE", "FP64"),
std::make_pair("DT_INT32", "INT32"),
std::make_pair("DT_INT16", "INT16"),
+ std::make_pair("DT_UINT16", "UINT16"),
std::make_pair("DT_INT8", "INT8"),
std::make_pair("DT_UINT8", "UINT8"),
std::make_pair("DT_STRING", "BYTES"),