Port VLM functionalities from main branch to the refactor branch (#769)
* port VLM input generation

* add tests

* add new tests

* port output parsing and metrics

* add one more test

* Fix PR run

* address feedback
nv-hwoo authored Jul 30, 2024
1 parent 5a55a7e commit 93b2f5a
Showing 23 changed files with 794 additions and 27 deletions.
@@ -28,10 +28,13 @@
from typing import Any, Dict, List

import requests
from genai_perf import utils
from genai_perf.exceptions import GenAIPerfException
from genai_perf.llm_inputs.inputs_utils import ImageFormat, OutputFormat
from genai_perf.llm_inputs.synthetic_image_generator import SyntheticImageGenerator
from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator
from genai_perf.tokenizer import Tokenizer
from genai_perf.utils import load_json_str
from PIL import Image


class DatasetRetriever:
@@ -56,26 +59,58 @@ def from_url(url: str, starting_index: int, length: int) -> List[Dict[str, Any]]
]
return formatted_rows

# (TMA-2018) decouple output_format from this method
@staticmethod
def from_file(file_path: Path) -> List[Dict[str, str]]:
with open(file_path, "r") as file:
data = [load_json_str(line) for line in file]
def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, str]]:
contents = DatasetRetriever._load_file_content(file_path)

dataset = []
for content in contents:
data = {"text_input": content.get("text_input", "")}

if output_format == OutputFormat.OPENAI_VISION:
img_filename = content.get("image", "")
encoded_img = DatasetRetriever._read_image_content(img_filename)
data["image"] = encoded_img

dataset.append(data)
return dataset

for item in data:
if not isinstance(item, dict):
@staticmethod
def _load_file_content(file_path: Path) -> List[Dict[str, str]]:
contents = []
with open(file_path, "r") as file:
for line in file:
content = utils.load_json_str(line)
if not isinstance(content, dict):
raise GenAIPerfException(
"File content is not in the expected format."
)
if "text_input" not in item:
raise GenAIPerfException(
f"Missing 'text_input' field in file item: {item}"
)
if len(item) != 1:
if "text_input" not in content:
raise GenAIPerfException(
f"Field other than 'text_input' field found in file item: {item}"
f"Missing 'text_input' field in file content: {content}"
)
contents.append(content)

return [{"text_input": item["text_input"]} for item in data]
return contents

@staticmethod
def _read_image_content(filename: str) -> str:
try:
img = Image.open(filename)
except:
raise GenAIPerfException(
f"Error occurred while opening an image file: {filename}"
)

if img.format.lower() not in utils.get_enum_names(ImageFormat):
raise GenAIPerfException(
f"Unsupported image format '{img.format}' of "
f"the image '{filename}'."
)

img_base64 = utils.encode_image(img, img.format)
return f"data:image/{img.format.lower()};base64,{img_base64}"

@staticmethod
def from_directory(directory_path: Path) -> Dict:
@@ -89,7 +124,7 @@ def from_directory(directory_path: Path) -> Dict:
# Get the file name without suffix
key = file_path.stem
with open(file_path, "r") as file:
data[key] = [load_json_str(line) for line in file]
data[key] = [utils.load_json_str(line) for line in file]

# Create rows with keys based on file names without suffix
num_entries = len(next(iter(data.values())))
@@ -105,11 +140,29 @@ def from_synthetic(
prompt_tokens_mean: int,
prompt_tokens_stddev: int,
num_of_output_prompts: int,
image_width_mean: int,
image_width_stddev: int,
image_height_mean: int,
image_height_stddev: int,
image_format: ImageFormat,
output_format: OutputFormat,
) -> List[Dict[str, str]]:
synthetic_prompts = []
synthetic_dataset = []
for _ in range(num_of_output_prompts):
synthetic_prompt = SyntheticPromptGenerator.create_synthetic_prompt(
prompt = SyntheticPromptGenerator.create_synthetic_prompt(
tokenizer, prompt_tokens_mean, prompt_tokens_stddev
)
synthetic_prompts.append({"text_input": synthetic_prompt})
return synthetic_prompts
data = {"text_input": prompt}

if output_format == OutputFormat.OPENAI_VISION:
image = SyntheticImageGenerator.create_synthetic_image(
image_width_mean=image_width_mean,
image_width_stddev=image_width_stddev,
image_height_mean=image_height_mean,
image_height_stddev=image_height_stddev,
image_format=image_format,
)
data["image"] = image

synthetic_dataset.append(data)
return synthetic_dataset
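
A quick way to exercise the new synthetic path is to call DatasetRetriever.from_synthetic directly with the image parameters added here. The sketch below is illustrative only: the tokenizer helper and the concrete values are assumptions, not part of this diff.

# Hedged usage sketch for the image-aware from_synthetic() shown above.
# get_tokenizer and the parameter values are assumptions for illustration.
from genai_perf.llm_inputs.dataset_retriever import DatasetRetriever
from genai_perf.llm_inputs.inputs_utils import ImageFormat, OutputFormat
from genai_perf.tokenizer import get_tokenizer  # assumed helper

tokenizer = get_tokenizer("gpt2")
dataset = DatasetRetriever.from_synthetic(
    tokenizer,
    prompt_tokens_mean=550,
    prompt_tokens_stddev=0,
    num_of_output_prompts=3,
    image_width_mean=100,
    image_width_stddev=0,
    image_height_mean=100,
    image_height_stddev=0,
    image_format=ImageFormat.PNG,
    output_format=OutputFormat.OPENAI_VISION,
)
# Each entry carries both fields:
#   {"text_input": "<synthetic prompt>", "image": "data:image/png;base64,..."}
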
@@ -45,6 +45,7 @@ class OutputFormat(Enum):
OPENAI_CHAT_COMPLETIONS = auto()
OPENAI_COMPLETIONS = auto()
OPENAI_EMBEDDINGS = auto()
OPENAI_VISION = auto()
RANKINGS = auto()
TENSORRTLLM = auto()
VLLM = auto()
@@ -53,6 +54,11 @@ def to_lowercase(self):
return self.name.lower()


class ImageFormat(Enum):
PNG = auto()
JPEG = auto()


DEFAULT_STARTING_INDEX = 0
DEFAULT_LENGTH = 100
DEFAULT_TENSORRTLLM_MAX_TOKENS = 256
@@ -63,3 +69,9 @@ def to_lowercase(self):
DEFAULT_OUTPUT_TOKENS_MEAN = -1
DEFAULT_OUTPUT_TOKENS_STDDEV = 0
DEFAULT_NUM_PROMPTS = 100

# Images
DEFAULT_IMAGE_WIDTH_MEAN = 100
DEFAULT_IMAGE_WIDTH_STDDEV = 0
DEFAULT_IMAGE_HEIGHT_MEAN = 100
DEFAULT_IMAGE_HEIGHT_STDDEV = 0
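
To see how OutputFormat.OPENAI_VISION changes file-based input, the hedged sketch below feeds from_file a JSONL file whose lines carry the "text_input" and "image" fields it now reads; the file name and image paths are illustrative assumptions.

# Hedged sketch: a JSONL input consumed by DatasetRetriever.from_file() when
# the output format is OPENAI_VISION. File and image paths are assumptions.
from pathlib import Path

from genai_perf.llm_inputs.dataset_retriever import DatasetRetriever
from genai_perf.llm_inputs.inputs_utils import OutputFormat

# inputs.jsonl (one JSON object per line):
#   {"text_input": "What is in this image?", "image": "cat.png"}
#   {"text_input": "Describe the scene.", "image": "street.jpeg"}
dataset = DatasetRetriever.from_file(Path("inputs.jsonl"), OutputFormat.OPENAI_VISION)
# Each returned dict keeps "text_input" and rewrites "image" as a base64 data
# URL such as "data:image/png;base64,...".
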
@@ -41,6 +41,7 @@ def to_generic(dataset: List[Dict[str, Any]]) -> Dict:
for item in dataset:
row_data = {
"text_input": item.get("text_input", ""),
"image": item.get("image", ""),
"system_prompt": item.get("system_prompt", ""),
"response": item.get("response", ""),
}
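
Combined with the converter change further below, a generic dataset row would plausibly look like the following; the concrete values are illustrative.

# Hedged sketch of one generic dataset row after to_generic(); values are illustrative.
generic_dataset = {
    "rows": [
        {
            "row": {
                "text_input": "What is in this image?",
                "image": "data:image/png;base64,iVBORw0KG...",  # empty string when no image
                "system_prompt": "",
                "response": "",
            }
        }
    ]
}
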
@@ -32,6 +32,10 @@
from genai_perf.exceptions import GenAIPerfException
from genai_perf.llm_inputs.dataset_retriever import DatasetRetriever
from genai_perf.llm_inputs.inputs_utils import (
DEFAULT_IMAGE_HEIGHT_MEAN,
DEFAULT_IMAGE_HEIGHT_STDDEV,
DEFAULT_IMAGE_WIDTH_MEAN,
DEFAULT_IMAGE_WIDTH_STDDEV,
DEFAULT_LENGTH,
DEFAULT_NUM_PROMPTS,
DEFAULT_OUTPUT_TOKENS_MEAN,
@@ -40,6 +44,7 @@
DEFAULT_PROMPT_TOKENS_STDDEV,
DEFAULT_RANDOM_SEED,
DEFAULT_STARTING_INDEX,
ImageFormat,
ModelSelectionStrategy,
OutputFormat,
PromptSource,
@@ -76,6 +81,11 @@ def create_llm_inputs(
output_tokens_deterministic: bool = False,
prompt_tokens_mean: int = DEFAULT_PROMPT_TOKENS_MEAN,
prompt_tokens_stddev: int = DEFAULT_PROMPT_TOKENS_STDDEV,
image_width_mean: int = DEFAULT_IMAGE_WIDTH_MEAN,
image_width_stddev: int = DEFAULT_IMAGE_WIDTH_STDDEV,
image_height_mean: int = DEFAULT_IMAGE_HEIGHT_MEAN,
image_height_stddev: int = DEFAULT_IMAGE_HEIGHT_STDDEV,
image_format: ImageFormat = ImageFormat.PNG,
random_seed: int = DEFAULT_RANDOM_SEED,
num_of_output_prompts: int = DEFAULT_NUM_PROMPTS,
add_model_name: bool = False,
@@ -101,14 +111,20 @@ def create_llm_inputs(
prompt_tokens_mean,
prompt_tokens_stddev,
num_of_output_prompts,
image_width_mean,
image_width_stddev,
image_height_mean,
image_height_stddev,
image_format,
output_format,
)
elif input_type == PromptSource.FILE:
input_filename = cast(Path, input_filename)
# TODO: Follow-up ticket to add support for rankings
# if output_format == OutputFormat.RANKINGS:
# dataset = DatasetRetriever.from_directory(input_filename)
# else:
dataset = DatasetRetriever.from_file(input_filename)
dataset = DatasetRetriever.from_file(input_filename, output_format)
else:
raise GenAIPerfException("Input source is not recognized.")

@@ -147,6 +163,7 @@ def validate_args(
PromptSource.DATASET,
],
OutputFormat.RANKINGS: [PromptSource.DATASET, PromptSource.SYNTHETIC],
OutputFormat.OPENAI_VISION: [PromptSource.DATASET],
}

if input_type in unsupported_combinations.get(output_format, []):
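
End to end, the new vision path can be driven through create_llm_inputs roughly as sketched below. The enclosing class name (LlmInputs) and the tokenizer helper are assumptions, since neither is shown in this excerpt; the keyword arguments match the signature above.

# Hedged sketch only: LlmInputs and get_tokenizer are assumed names, and other
# required arguments (e.g. the model name) are omitted for brevity.
from genai_perf.llm_inputs.inputs_utils import ImageFormat, OutputFormat, PromptSource
from genai_perf.llm_inputs.llm_inputs import LlmInputs  # assumed module/class
from genai_perf.tokenizer import get_tokenizer  # assumed helper

pa_json = LlmInputs.create_llm_inputs(
    input_type=PromptSource.SYNTHETIC,
    output_format=OutputFormat.OPENAI_VISION,
    tokenizer=get_tokenizer("gpt2"),
    num_of_output_prompts=10,
    image_width_mean=512,
    image_width_stddev=32,
    image_height_mean=512,
    image_height_stddev=32,
    image_format=ImageFormat.JPEG,
)
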
@@ -25,7 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import random
from typing import Dict, List
from typing import Any, Dict, List, Union

from genai_perf.exceptions import GenAIPerfException
from genai_perf.llm_inputs.inputs_utils import (
@@ -46,6 +46,7 @@ def create(output_format: OutputFormat):
converters = {
OutputFormat.OPENAI_CHAT_COMPLETIONS: OpenAIChatCompletionsConverter,
OutputFormat.OPENAI_COMPLETIONS: OpenAICompletionsConverter,
OutputFormat.OPENAI_VISION: OpenAIChatCompletionsConverter,
OutputFormat.OPENAI_EMBEDDINGS: OpenAIEmbeddingsConverter,
OutputFormat.RANKINGS: RankingsConverter,
OutputFormat.VLLM: VLLMConverter,
@@ -105,8 +106,8 @@ def convert(

for index, row in enumerate(generic_dataset["rows"]):
model = self._select_model_name(model_name, index, model_selection_strategy)
text_content = row["row"]["text_input"]
messages = [{"role": "user", "content": text_content}]
content = self._generate_content(data=row["row"])
messages = [{"role": "user", "content": content}]
payload: Dict = {"messages": messages}

if add_model_name:
@@ -123,6 +124,28 @@

return pa_json

def _generate_content(
self, data: Dict[str, str]
) -> Union[str, List[Dict[str, Any]]]:
"""
Generate either text only or multi-modal content for OpenAI Chat Completions API.
"""
content: str | List[Dict[str, Any]] = data["text_input"]

# convert into multi-modal content format when image exists
if data["image"]:
content = [
{
"type": "text",
"text": data["text_input"],
},
{
"type": "image_url",
"image_url": {"url": data["image"]},
},
]
return content


class OpenAICompletionsConverter(BaseConverter):
def convert(
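
Putting the converter change together, one request produced by OpenAIChatCompletionsConverter for a row with an image would plausibly look like this; the model name, prompt, and base64 payload are illustrative.

# Hedged sketch of a single OPENAI_VISION payload; all values are illustrative.
payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,iVBORw0KG..."},
                },
            ],
        }
    ],
    "model": "llava-1.5-7b",  # only added when add_model_name is True
}
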
4 of the changed files could not be displayed in the diff view.
@@ -0,0 +1,77 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import glob
import random
from pathlib import Path
from typing import Optional

from genai_perf import utils
from genai_perf.llm_inputs.inputs_utils import ImageFormat
from PIL import Image


class SyntheticImageGenerator:
"""A simple synthetic image generator that generates multiple synthetic
images from the source images.
"""

@classmethod
def create_synthetic_image(
cls,
image_width_mean: int,
image_width_stddev: int,
image_height_mean: int,
image_height_stddev: int,
image_format: Optional[ImageFormat] = None,
) -> str:
"""Generate base64 encoded synthetic image using the source images."""
if image_format is None:
image_format = random.choice(list(ImageFormat))
width = cls._sample_random_positive_integer(
image_width_mean, image_width_stddev
)
height = cls._sample_random_positive_integer(
image_height_mean, image_height_stddev
)

image = cls._sample_source_image()
image = image.resize(size=(width, height))

img_base64 = utils.encode_image(image, image_format.name)
return f"data:image/{image_format.name.lower()};base64,{img_base64}"

@classmethod
def _sample_source_image(cls):
"""Sample one image among the source images."""
filepath = Path(__file__).parent.resolve() / "source_images" / "*"
filenames = glob.glob(str(filepath))
return Image.open(random.choice(filenames))

@classmethod
def _sample_random_positive_integer(cls, mean: int, stddev: int) -> int:
n = int(abs(random.gauss(mean, stddev)))
return n if n != 0 else 1 # avoid zero
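
A minimal usage sketch of the generator defined above; the values are illustrative, and with a stddev of 0 every sampled dimension equals its mean (clamped to at least 1).

# Hedged sketch: produce one synthetic image as a base64 data URL.
from genai_perf.llm_inputs.inputs_utils import ImageFormat
from genai_perf.llm_inputs.synthetic_image_generator import SyntheticImageGenerator

data_url = SyntheticImageGenerator.create_synthetic_image(
    image_width_mean=100,
    image_width_stddev=0,
    image_height_mean=100,
    image_height_stddev=0,
    image_format=ImageFormat.JPEG,  # omit to pick PNG or JPEG at random
)
# -> "data:image/jpeg;base64,..." backed by a source image resized to 100x100
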
5 changes: 5 additions & 0 deletions src/c++/perf_analyzer/genai-perf/genai_perf/main.py
@@ -77,6 +77,11 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None:
output_tokens_mean=args.output_tokens_mean,
output_tokens_stddev=args.output_tokens_stddev,
output_tokens_deterministic=args.output_tokens_mean_deterministic,
image_width_mean=args.image_width_mean,
image_width_stddev=args.image_width_stddev,
image_height_mean=args.image_height_mean,
image_height_stddev=args.image_height_stddev,
image_format=args.image_format,
random_seed=args.random_seed,
num_of_output_prompts=args.num_prompts,
add_model_name=add_model_name,
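
The CLI flags that populate these args are added elsewhere in this commit; the parser diff is not part of this excerpt. A plausible argparse sketch, with every flag name and default an assumption, would be:

# Hedged sketch only: flag names, defaults, and the format parsing are
# assumptions, since the actual parser changes are not shown here.
import argparse

from genai_perf.llm_inputs.inputs_utils import (
    DEFAULT_IMAGE_HEIGHT_MEAN,
    DEFAULT_IMAGE_HEIGHT_STDDEV,
    DEFAULT_IMAGE_WIDTH_MEAN,
    DEFAULT_IMAGE_WIDTH_STDDEV,
    ImageFormat,
)

parser = argparse.ArgumentParser()
parser.add_argument("--image-width-mean", type=int, default=DEFAULT_IMAGE_WIDTH_MEAN)
parser.add_argument("--image-width-stddev", type=int, default=DEFAULT_IMAGE_WIDTH_STDDEV)
parser.add_argument("--image-height-mean", type=int, default=DEFAULT_IMAGE_HEIGHT_MEAN)
parser.add_argument("--image-height-stddev", type=int, default=DEFAULT_IMAGE_HEIGHT_STDDEV)
parser.add_argument(
    "--image-format",
    type=lambda name: ImageFormat[name.upper()],
    default=None,  # None lets the generator pick a format at random
)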