Skip to content

Commit

Permalink
Refactor mmmu_doc_to_visual function to handle validation_Psychology_21 case
Browse files: browse the repository at this point in the history
  • Loading branch information
pufanyi committed Jan 16, 2025
1 parent 4728e43 commit aa10a0c
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 18 deletions.
66 changes: 48 additions & 18 deletions lmms_eval/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
from io import BytesIO
from re import I
from typing import List, Optional, Tuple, Union

import decord
Expand Down Expand Up @@ -40,9 +41,9 @@ def __init__(
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
use_flash_attention_2: Optional[bool] = True,
max_pixels: int = 1605632 // (2**6),
max_pixels: int = 1605632,
min_pixels: int = 3136,
max_num_frames: int = 20,
max_num_frames: int = 1,
use_custom_video_loader: Optional[bool] = True,
fps: Optional[float] = None, # Only applicable if use_custom_video_loader is True
max_image_size: Optional[int] = 1024, # Only applicable if use_custom_video_loader is True
Expand Down Expand Up @@ -221,8 +222,7 @@ def _collate(x):
res.extend(content)
pbar.update(1)
continue
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
visuals = [doc_to_visual[i](self.task_dict[task][split][ids]) for i, ids in enumerate(doc_id)]

gen_kwargs = all_gen_kwargs[0]

Expand All @@ -240,33 +240,53 @@ def _collate(x):
if isinstance(contexts, tuple):
contexts = list(contexts)

for i in range(len(contexts)):
if "<image 1>" in contexts[i]:
contexts[i] = contexts[i].replace("<image 1>", "<image>")
if "\\<image 1\\>" in contexts[i]:
contexts[i] = contexts[i].replace("\\<image 1\\>", "<image>")
if "<image>" in contexts[i]:
contexts[i] = contexts[i].replace("<image>", "")
# print(contexts[i])
# for i in range(len(contexts)):
# for j in range(32):
# if f"<image {j}>" in contexts[i]:
# contexts[i] = contexts[i].replace(f"<image {j}>", "<image>")
# if f"\\<image {j}\\>" in contexts[i]:
# contexts[i] = contexts[i].replace(f"\\<image {j}\\>", "<image>")
# if "<image>" in contexts[i]:
# contexts[i] = contexts[i].replace("<image>", "")
# print(contexts[i])

messages = []
processed_visuals = []
for i, context in enumerate(contexts):
context += "\nPlease think step by step."
# context += "\nPlease think step by step."

# if "<image>" in context:
# context = context.split("<image>")
# assert len(context) == 2, f"Expected 2 parts in context but got {len(context)}"
print("context", context)

if "<image>" in context:
context = context.split("<image>")
# print(json.dumps(context, indent=4))
else:
context = [context]

message = [{"role": "system", "content": "You are a helpful assistant."}]

if len(visuals) > 0:
visual = visuals[i] if i < len(visuals) else None
print("visuals", visual)
if isinstance(visual, Image.Image):
visual = [visual]
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
if self.use_custom_video_loader:
visual = read_video_pyav_base64(visual, num_frm=self.max_num_frames, fps=self.fps, img_format="JPEG", max_image_size=self.max_image_size)
image_contents = list(map(lambda x: f"data:image/jpeg;base64,{x}", visual))
message.append({"role": "user", "content": [{"type": "video", "video": image_contents}, {"type": "text", "text": context}]})
if len(context) == 2:
print("image_contents", image_contents)
if len(image_contents) == 1:
message.append(
{
"role": "user",
"content": [{"type": "video", "video": image_contents[:-1]}, {"type": "text", "text": context[0]}, {"type": "image", "image": image_contents[-1]}, {"type": "text", "text": context[1]}],
}
)
else:
message.append({"role": "user", "content": [{"type": "text", "text": context[0]}, {"type": "image", "image": image_contents[-1]}, {"type": "text", "text": context[1]}]})
else:
message.append({"role": "user", "content": [{"type": "video", "video": image_contents}, {"type": "text", "text": context}]})
else:
vr = decord.VideoReader(visual)
first_frame = vr[0].asnumpy()
Expand All @@ -282,14 +302,24 @@ def _collate(x):
message.append({"role": "user", "content": [{"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"}, {"type": "text", "text": context}]})
elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual): # Multiple images
image_content = []
i = 0
for v in visual:
base64_image = v.convert("RGB")
buffer = BytesIO()
base64_image.save(buffer, format="JPEG")
base64_bytes = base64.b64encode(buffer.getvalue())
base64_string = base64_bytes.decode("utf-8")
image_content.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}"})
message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
v.save(f"test_{i}.jpg")
i += 1
# message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
assert len(image_content) + 1 == len(context), f"Number of images and context do not match, {len(image_content)} images and {len(context)} context\n{json.dumps(context)}"
content = []
for i in range(len(image_content)):
content.append({"type": "text", "text": context[i]})
content.append(image_content[i])
content.append({"type": "text", "text": context[-1]})
# print("content", content)
else:
message.append({"role": "user", "content": [{"type": "text", "text": context}]})
else:
Expand Down
3 changes: 3 additions & 0 deletions lmms_eval/tasks/mmmu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def mmmu_doc_to_visual(doc):
# Remove <> and swap space as _
image_tokens = sorted(list(set([image_token.strip("<>").replace(" ", "_") for image_token in image_tokens])))
visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
if doc["id"] == "validation_Psychology_21":
print(image_tokens)
print(visual)
return visual


Expand Down
Binary file added test_0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test_1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit aa10a0c

Please sign in to comment.