Skip to content

Commit

Permalink
Simplify test dataset creation for image retrieval, remove unused imp…
Browse files Browse the repository at this point in the history
…orts in image retrieval converter
  • Loading branch information
djwhatle committed Nov 14, 2024
1 parent 811408a commit e89c137
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Dict, List, Union
from typing import Any, Dict

from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.converters.base_converter import BaseConverter
from genai_perf.inputs.input_constants import OutputFormat
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.generic_dataset import DataRow, GenericDataset
from genai_perf.inputs.retrievers.generic_dataset import GenericDataset


class ImageRetrievalConverter(BaseConverter):
Expand Down
31 changes: 5 additions & 26 deletions genai-perf/tests/test_image_retrieval_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,43 +40,22 @@
class TestImageRetrievalConverter:

@staticmethod
def create_generic_dataset(rows: List[Dict[str, Any]]) -> GenericDataset:
def clean_text(row):
text = row.get("text", [])
if isinstance(text, list):
return [t for t in text if t]
elif text:
return [text]
return []

def clean_image(row):
image = row.get("image", [])
if isinstance(image, list):
return [i for i in image if i]
elif image:
return [image]
return []

def create_generic_dataset() -> GenericDataset:
return GenericDataset(
files_data={
"file1": FileData(
rows=[
DataRow(texts=clean_text(row), images=clean_image(row))
for row in rows
DataRow(images=["test_image_1", "test_image_2"]),
],
)
}
)

def test_convert_multi_modal_batched(self) -> None:
def test_convert_default(self) -> None:
"""
Test batched Image Retrieval request payload
Test Image Retrieval request payload
"""
generic_dataset = self.create_generic_dataset(
[
{"image": ["test_image_1", "test_image_2"]},
]
)
generic_dataset = self.create_generic_dataset()

config = InputsConfig(
extra_inputs={},
Expand Down
56 changes: 0 additions & 56 deletions genai-perf/tests/test_openai_chat_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,59 +295,3 @@ def test_convert_multi_modal(
}

assert result == expected_result

def test_convert_multi_modal_batched(self) -> None:
"""
Test batched multi-modal format of OpenAI Chat API for Image Retrieval
"""
generic_dataset = self.create_generic_dataset(
[
{"image": ["test_image_1", "test_image_2"]},
]
)

config = InputsConfig(
extra_inputs={},
model_name=["test_model"],
model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN,
output_format=OutputFormat.IMAGE_RETRIEVAL,
add_stream=True,
tokenizer=get_empty_tokenizer(),
)

chat_converter = OpenAIChatCompletionsConverter()
result = chat_converter.convert(generic_dataset, config)

expected_result = {
"data": [
{
"payload": [
{
"model": "test_model",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "test_image_1",
},
},
{
"type": "image_url",
"image_url": {
"url": "test_image_2",
},
},
],
}
],
"stream": True,
}
]
},
]
}

assert result == expected_result

0 comments on commit e89c137

Please sign in to comment.