From 8c0c09b79e613c809c5d6aded7ea14544e449de8 Mon Sep 17 00:00:00 2001 From: David Yastremsky Date: Wed, 16 Oct 2024 14:14:20 -0700 Subject: [PATCH] Address feedback --- .../openai_chat_completions_converter.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/genai-perf/genai_perf/inputs/converters/openai_chat_completions_converter.py b/genai-perf/genai_perf/inputs/converters/openai_chat_completions_converter.py index e46b18fa..70188c9e 100644 --- a/genai-perf/genai_perf/inputs/converters/openai_chat_completions_converter.py +++ b/genai-perf/genai_perf/inputs/converters/openai_chat_completions_converter.py @@ -59,14 +59,7 @@ def convert(self, generic_dataset: GenericDataset, config: InputsConfig) -> Dict def _create_payload(self, index: int, row: DataRow, config: InputsConfig) -> Dict[Any, Any]: model_name = self._select_model_name(config, index) - - content: Union[str, List[Dict[Any, Any]]] = "" - if config.output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS: - content = row.texts[0] - elif config.output_format == OutputFormat.OPENAI_VISION or config.output_format == OutputFormat.IMAGE_RETRIEVAL: - content = self._add_multi_modal_content(row) - else: - raise GenAIPerfException(f"Output format {config.output_format} is not supported") + content = self._retrieve_content(row, config) payload = { "model": model_name, @@ -81,6 +74,15 @@ def _create_payload(self, index: int, row: DataRow, config: InputsConfig) -> Dic self._add_request_params(payload, config) return payload + def _retrieve_content(self, row: DataRow, config: InputsConfig) -> Union[str, List[Dict[Any, Any]]]: + content: Union[str, List[Dict[Any, Any]]] = "" + if config.output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS: + content = row.texts[0] + elif config.output_format == OutputFormat.OPENAI_VISION or config.output_format == OutputFormat.IMAGE_RETRIEVAL: + content = self._add_multi_modal_content(row) + else: + raise GenAIPerfException(f"Output format {config.output_format} is not supported") + return content def _add_multi_modal_content(self, entry: DataRow) -> List[Dict[Any, Any]]: content = []