Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-hwoo committed Jul 30, 2024
1 parent 1f8c926 commit ed90c85
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def from_url(url: str, starting_index: int, length: int) -> List[Dict[str, Any]]
]
return formatted_rows

+ # (TMA-2018) decouple output_format from this method
@staticmethod
def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, str]]:
contents = DatasetRetriever._load_file_content(file_path)
Expand All @@ -69,7 +70,7 @@ def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, st

if output_format == OutputFormat.OPENAI_VISION:
img_filename = content.get("image", "")
- encoded_img = DatasetRetriever._encode_image_to_base64(img_filename)
+ encoded_img = DatasetRetriever._read_image_content(img_filename)
data["image"] = encoded_img

dataset.append(data)
Expand All @@ -94,7 +95,7 @@ def _load_file_content(file_path: Path) -> List[Dict[str, str]]:
return contents

@staticmethod
- def _encode_image_to_base64(filename: str) -> str:
+ def _read_image_content(filename: str) -> str:
try:
img = Image.open(filename)
except:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ def create_llm_inputs(
random.seed(random_seed)

if input_type == PromptSource.DATASET:
- # (TMA-1990) support VLM input from public dataset
- if output_format == OutputFormat.OPENAI_VISION:
- raise GenAIPerfException(
- f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
- "does not support dataset as input."
- )
dataset = DatasetRetriever.from_url(
cls.dataset_url_map[dataset_name], starting_index, length
)
Expand Down Expand Up @@ -169,6 +163,7 @@ def validate_args(
PromptSource.DATASET,
],
OutputFormat.RANKINGS: [PromptSource.DATASET, PromptSource.SYNTHETIC],
+ OutputFormat.OPENAI_VISION: [PromptSource.DATASET],
}

if input_type in unsupported_combinations.get(output_format, []):
Expand Down
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists):

@patch("pathlib.Path.exists", return_value=True)
@patch(
- "genai_perf.llm_inputs.dataset_retriever.DatasetRetriever._encode_image_to_base64",
+ "genai_perf.llm_inputs.dataset_retriever.DatasetRetriever._read_image_content",
return_value="data:image/png;base64,...",
)
@patch(
Expand Down

0 comments on commit ed90c85

Please sign in to comment.