diff --git a/.gitignore b/.gitignore
index 2557ab1b..922b9f0b 100755
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,7 @@ temp
 __pycache__
 .ipynb_checkpoints
 temp
+.DS_STORE
 # IPython
 profile_default/
 ipython_config.py
diff --git a/README.md b/README.md
index 2bf7a26e..7be1c6c7 100755
--- a/README.md
+++ b/README.md
@@ -12,7 +12,9 @@
 ## Annoucement
 
-- [2024-06] 🎬🎬 The `lmms-eval/v0.2` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details
+- [2024-07] 👨‍💻👨‍💻 The `lmms-eval/v0.2.1` has been upgraded to support more models, including [LongVA](https://github.com/EvolvingLMMs-Lab/LongVA), [InterVL-2](https://github.com/OpenGVLab/InternVL), [VILA](https://github.com/NVlabs/VILA), and many more evaluation tasks, e.g. [Details Captions](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/136), [MLVU](https://arxiv.org/abs/2406.04264), [WildVision-Bench](https://huggingface.co/datasets/WildVision/wildvision-arena-data), [VITATECS](https://github.com/lscpku/VITATECS) and [LLaVA-Interleave-Bench](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/).
+
+- [2024-06] 🎬🎬 The `lmms-eval/v0.2.0` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details
 
 - [2024-03] 📝📝 We have released the first version of `lmms-eval`, please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.1/) for more details
diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index 159ac5a6..ef0e2f1c 100755
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -165,7 +165,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     # reset logger
     eval_logger.remove()
     eval_logger.add(sys.stdout, colorize=True, level=args.verbosity)
-    eval_logger.add(sys.stderr, level=args.verbosity)
     eval_logger.info(f"Verbosity set to {args.verbosity}")
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py
index f77065e8..2cecfe22 100755
--- a/lmms_eval/api/samplers.py
+++ b/lmms_eval/api/samplers.py
@@ -37,7 +37,9 @@ def get_context(self, doc, num_fewshot):
                         + (
                             str(self.doc_to_target(doc)[0])
                             if type(self.doc_to_target(doc)) is list
-                            else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
+                            else self.doc_to_target(doc)
+                            if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str)
+                            else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
                         )
                         for doc in selected_docs
                     ]
diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index 041eec31..aec3bcc2 100755
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -327,12 +327,7 @@ def evaluate(
         # hack: remove image columns to speed avoid loading images and speed up postprocessing
         # reason: doc_iterator will actually load image if it's in the doc.
         docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
-        if "d170" not in task_name \
-            and "dc100" not in task_name \
-            and "dc200" not in task_name \
-            and "llava_wilder" not in task_name \
-            and "livebench" not in task_name \
-            and "wildvision" not in task_name:
+        if "d170" not in task_name and "dc100" not in task_name and "dc200" not in task_name and "llava_wilder" not in task_name and "livebench" not in task_name and "wildvision" not in task_name:
             remove_cols = []
             features = docs.features
             # If it is an Image instance or a Sequence of Image instance. Remove it
diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py
index 2ed909e4..0ca7692c 100755
--- a/lmms_eval/models/__init__.py
+++ b/lmms_eval/models/__init__.py
@@ -4,6 +4,10 @@
 from loguru import logger
 import sys
 
+import hf_transfer
+
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
 logger.remove()
 logger.add(sys.stdout, level="WARNING")
 
@@ -25,6 +29,7 @@
     "llava_sglang": "LlavaSglang",
     "idefics2": "Idefics2",
     "internvl": "InternVLChat",
+    "internvl2": "InternVL2",
     "gemini_api": "GeminiAPI",
     "reka": "Reka",
     "from_log": "FromLog",
@@ -33,14 +38,16 @@
     "tinyllava": "TinyLlava",
     "llava_hf": "LlavaHf",
     "longva": "LongVA",
+    "llava_hf": "LlavaHf",
+    "longva": "LongVA",
+    "vila": "VILA",
 }
 
 for model_name, model_class in AVAILABLE_MODELS.items():
     try:
         exec(f"from .{model_name} import {model_class}")
     except ImportError as e:
-        # logger.warning(f"Failed to import {model_class} from {model_name}: {e}")
-        pass
+        logger.warning(f"Failed to import {model_class} from {model_name}: {e}")
 
 if os.environ.get("LMMS_EVAL_PLUGINS", None):
     # Allow specifying other packages to import models from
@@ -50,8 +57,4 @@
         try:
             exec(f"from {plugin}.models.{model_name} import {model_class}")
         except ImportError:
-            pass
-
-import hf_transfer
-
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+            logger.warning(f"Failed to import {model_class} from {model_name}")
diff --git a/lmms_eval/models/batch_gpt4.py b/lmms_eval/models/batch_gpt4.py
index 8f4c2220..7541b709 100755
--- a/lmms_eval/models/batch_gpt4.py
+++ b/lmms_eval/models/batch_gpt4.py
@@ -59,7 +59,7 @@ def __init__(
         api_key: str = API_KEY,
         api_url: str = API_URL,
         modality: str = "image",
-        max_frames_for_video: int = 10,
+        max_frames_num: int = 10,
         timeout: int = 120,
         **kwargs,
     ) -> None:
@@ -69,7 +69,7 @@ def __init__(
         # Here we just use the same token as llava for convenient
         self.model_version = model_version
         self.modality = modality
-        self.max_frames_for_video = max_frames_for_video
+        self.max_frames_num = max_frames_num
         self.image_token = "<image>"
         self.timeout = timeout
 
@@ -128,7 +128,7 @@ def generate_until(self, requests):
                     img = self.encode_image(visual)
                     imgs.append(img)
                 elif self.modality == "video":
-                    frames = self.encode_video(visual, self.max_frames_for_video)
+                    frames = self.encode_video(visual, self.max_frames_num)
                     imgs.extend(frames)
 
             messages = []
diff --git a/lmms_eval/models/claude.py b/lmms_eval/models/claude.py
index 4c967e88..5829fbed 100644
--- a/lmms_eval/models/claude.py
+++ b/lmms_eval/models/claude.py
@@ -40,6 +40,7 @@ def __init__(
         image_token: str = "<image>",  # Use to separate interleaved image and text
         system_prompt: str = "",  # Whether you want some special system prompt here
         modality: str = "image",
+        max_frames_num: int = 10,
         continual_mode: bool = False,
         response_persistent_folder: str = None,
         **kwargs,
@@ -49,20 +50,24 @@ def __init__(
         self.image_token = image_token
         self.system_prompt = system_prompt
         self.modality = modality
+        self.max_frames_num = max_frames_num
         self.continual_mode = continual_mode
-        if self.continual_mode and response_persistent_folder is None:
-            raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
-        self.response_persistent_folder = response_persistent_folder
-        self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
-
-        if os.path.exists(self.response_persistent_file):
-            with open(self.response_persistent_file, "r") as f:
-                self.response_cache = json.load(f)
-            self.cache_mode = "resume"
-        else:
-            self.response_cache = {}
-            self.cache_mode = "start"
+        if self.continual_mode:
+            if response_persistent_folder is None:
+                raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
+
+            os.makedirs(response_persistent_folder, exist_ok=True)
+            self.response_persistent_folder = response_persistent_folder
+            self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
+
+            if os.path.exists(self.response_persistent_file):
+                with open(self.response_persistent_file, "r") as f:
+                    self.response_cache = json.load(f)
+                self.cache_mode = "resume"
+            else:
+                self.response_cache = {}
+                self.cache_mode = "start"
 
         accelerator = Accelerator()
         if accelerator.num_processes > 1:
@@ -81,7 +86,7 @@ def __init__(
 
     def encode_image(self, image):
         output_buffer = BytesIO()
-        image.save(output_buffer, format="PNG")
+        image.save(output_buffer, format="JPEG")
         byte_data = output_buffer.getvalue()
         base64_str = base64.b64encode(byte_data).decode("utf-8")
         return base64_str
@@ -129,7 +134,7 @@ def shrink_image_to_file_size(self, img: Image, max_file_size=4838990) -> Image:
     def encode_video(self, video_path):
         vr = VideoReader(video_path, ctx=cpu(0))
         total_frame_num = len(vr)
-        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_for_video, dtype=int)
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_num, dtype=int)
         frame_idx = uniform_sampled_frames.tolist()
         frames = vr.get_batch(frame_idx).asnumpy()
 
@@ -137,10 +142,10 @@ def encode_video(self, video_path):
         for frame in frames:
             img = Image.fromarray(frame)
             output_buffer = BytesIO()
-            img.save(output_buffer, format="PNG")
+            img.save(output_buffer, format="JPEG")
             byte_data = output_buffer.getvalue()
             base64_str = base64.b64encode(byte_data).decode("utf-8")
-            base64_frames.append(f"data:image/jpeg;base64,{base64_str}")
+            base64_frames.append(f"{base64_str}")
 
         return base64_frames
 
@@ -154,7 +159,7 @@ def generate_until(self, requests) -> List[str]:
             "type": "image",
             "source": {
                 "type": "base64",
-                "media_type": "image/png",
+                "media_type": "image/jpeg",
             },
         }
         empty_text_block = {"type": "text"}
@@ -218,10 +223,12 @@ def generate_until(self, requests) -> List[str]:
 
             if "max_new_tokens" not in gen_kwargs:
                 gen_kwargs["max_new_tokens"] = 1024
+            if gen_kwargs["max_new_tokens"] > 4096:
+                gen_kwargs["max_new_tokens"] = 4096
             if "temperature" not in gen_kwargs:
                 gen_kwargs["temperature"] = 0
-            if "top_p" not in gen_kwargs:
-                gen_kwargs["top_p"] = None
+            if "top_p" not in gen_kwargs or gen_kwargs["top_p"] is None:
+                gen_kwargs["top_p"] = 1
             if "num_beams" not in gen_kwargs:
                 gen_kwargs["num_beams"] = 1
 
@@ -238,11 +245,13 @@ def generate_until(self, requests) -> List[str]:
                         pbar.update(1)
                         continue
 
+            response_text = message.content[0].text
             res.append(message.content[0].text)
             pbar.update(1)
 
             ###################### CONTINUAL MODE ######################
             if self.continual_mode is True:  # Cache the response
+                response_text = message.content[0].text
                 doc_uuid = f"{task}___{split}___{doc_id}"
                 self.response_cache[doc_uuid] = response_text
                 with open(self.response_persistent_file, "w") as f:
diff --git a/lmms_eval/models/gemini_api.py b/lmms_eval/models/gemini_api.py
index 4a43c9af..4dbc25bd 100644
--- a/lmms_eval/models/gemini_api.py
+++ b/lmms_eval/models/gemini_api.py
@@ -31,7 +31,7 @@ class GeminiAPI(lmms):
     def __init__(
         self,
-        model_version: str = "gemini-1.5-flash-latest",
+        model_version: str = "gemini-1.5-pro",
         modality: str = "image",
         timeout: int = 120,
         continual_mode: bool = False,
@@ -46,6 +46,8 @@ def __init__(
         if self.continual_mode and response_persistent_folder is None:
             raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
         self.response_persistent_folder = response_persistent_folder
+        if not os.path.exists(self.response_persistent_folder):
+            os.makedirs(self.response_persistent_folder)
         self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
 
         if os.path.exists(self.response_persistent_file):
diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py
index 729e73f7..7d9c5850 100755
--- a/lmms_eval/models/gpt4v.py
+++ b/lmms_eval/models/gpt4v.py
@@ -7,15 +7,13 @@
 from tqdm import tqdm
 import requests as url_requests
 import time
-
+import json
 
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
-from lmms_eval import utils
 
-from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
-from accelerate.state import AcceleratorState
+from accelerate import Accelerator, DistributedType
 
 try:
     from decord import VideoReader, cpu
@@ -50,8 +48,10 @@ def __init__(
         self,
         model_version: str = "gpt-4-vision-preview",
         modality: str = "video",
-        max_frames_for_video: int = 10,
+        max_frames_num: int = 10,
         timeout: int = 120,
+        continual_mode: bool = False,
+        response_persistent_folder: str = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -60,9 +60,25 @@ def __init__(
         # Here we just use the same token as llava for convenient
         self.model_version = model_version
         self.modality = modality
-        self.max_frames_for_video = max_frames_for_video
+        self.max_frames_num = max_frames_num
         self.image_token = "<image>"
         self.timeout = timeout
+        self.continual_mode = continual_mode
+        if self.continual_mode:
+            if response_persistent_folder is None:
+                raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
+
+            os.makedirs(response_persistent_folder, exist_ok=True)
+            self.response_persistent_folder = response_persistent_folder
+            self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
+
+            if os.path.exists(self.response_persistent_file):
+                with open(self.response_persistent_file, "r") as f:
+                    self.response_cache = json.load(f)
+                self.cache_mode = "resume"
+            else:
+                self.response_cache = {}
+                self.cache_mode = "start"
 
         accelerator = Accelerator()
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
@@ -119,9 +135,16 @@ def generate_until(self, requests) -> List[str]:
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
         for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
-            # encode, pad, and truncate contexts for this batch
-            # visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
-            visuals = [doc_to_visual(self.task_dict[task][split][0])]
+            if self.continual_mode is True and self.cache_mode == "resume":
+                doc_uuid = f"{task}___{split}___{doc_id}"
+                if doc_uuid in self.response_cache:
+                    response_text = self.response_cache[doc_uuid]
+                    if response_text:
+                        res.append(response_text)
+                        pbar.update(1)
+                        continue
+
+            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             visuals = self.flatten(visuals)
             imgs = []  # multiple images or frames for video
             for visual in visuals:
@@ -129,23 +152,26 @@ def generate_until(self, requests) -> List[str]:
                     img = self.encode_image(visual)
                     imgs.append(img)
                 elif self.modality == "video":
-                    frames = self.encode_video(visual, self.max_frames_for_video)
+                    frames = self.encode_video(visual, self.max_frames_num)
                     imgs.extend(frames)
 
-            payload = {"model": self.model_version, "messages": []}
+            payload = {"messages": []}
+            if API_TYPE == "openai":
+                payload["model"] = self.model_version
+
             response_json = {"role": "user", "content": []}
             # When there is no image token in the context, append the image to the text
             if self.image_token not in contexts:
                 payload["messages"].append(deepcopy(response_json))
                 payload["messages"][0]["content"].append({"type": "text", "text": contexts})
                 for img in imgs:
-                    payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
+                    payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
             else:
                 contexts = contexts.split(self.image_token)
                 for idx, img in enumerate(imgs):
                     payload["messages"].append(deepcopy(response_json))
                     payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]})
-                    payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
+                    payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
 
                 # If n image tokens are in the contexts
                 # contexts will be splitted into n+1 chunks
@@ -155,6 +181,8 @@ def generate_until(self, requests) -> List[str]:
 
             if "max_new_tokens" not in gen_kwargs:
                 gen_kwargs["max_new_tokens"] = 1024
+            if gen_kwargs["max_new_tokens"] > 4096:
+                gen_kwargs["max_new_tokens"] = 4096
             if "temperature" not in gen_kwargs:
                 gen_kwargs["temperature"] = 0
             if "top_p" not in gen_kwargs:
@@ -170,19 +198,30 @@ def generate_until(self, requests) -> List[str]:
                 response = url_requests.post(API_URL, headers=headers, json=payload, timeout=self.timeout)
                 response_data = response.json()
 
-                content = response_data["choices"][0]["message"]["content"].strip()
+                response_text = response_data["choices"][0]["message"]["content"].strip()
                 break  # If successful, break out of the loop
 
             except Exception as e:
-                eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
-                if attempt < 5 - 1:  # If we have retries left, sleep and then continue to next attempt
+                try:
+                    error_msg = response.json()
+                except:
+                    error_msg = ""
+
+                eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.\nReponse: {error_msg}")
+                if attempt <= 5:
                     time.sleep(NUM_SECONDS_TO_SLEEP)
-                else:  # If this was the last attempt, log and return empty
-                    eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
-                    eval_logger.error(f"Response: {response}")
-                    content = ""
-            res.append(content)
+                else:  # If this was the last attempt, log and return empty string
+                    eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.\nResponse: {response.json()}")
+                    response_text = ""
+            res.append(response_text)
             pbar.update(1)
+
+            if self.continual_mode is True:  # Cache the response
+                doc_uuid = f"{task}___{split}___{doc_id}"
+                self.response_cache[doc_uuid] = response_text
+                with open(self.response_persistent_file, "w") as f:
+                    json.dump(self.response_cache, f)
+
         pbar.close()
         return res
diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py
new file mode 100644
index 00000000..5f4365d0
--- /dev/null
+++ b/lmms_eval/models/internvl2.py
@@ -0,0 +1,238 @@
+from typing import List, Tuple
+from lmms_eval.api.instance import Instance
+from decord import VideoReader, cpu
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+import numpy as np
+from transformers import AutoModel, AutoTokenizer
+from lmms_eval.api.registry import register_model
+from accelerate import Accelerator, DistributedType
+from lmms_eval.api.model import lmms
+from tqdm import tqdm
+import logging
+
+eval_logger = logging.getLogger("eval_logger")
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+DEFAULT_GEN_KWARGS = dict(
+    num_beams=1,
+    max_new_tokens=1024,
+    do_sample=False,
+)
+
+
+def build_transform(input_size):
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose([T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])
+    return transform
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def load_image(image, input_size=448, max_num=6):
+    transform = build_transform(input_size=input_size)
+    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
+    return frame_indices
+
+
+def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = build_transform(input_size=input_size)
+    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
+from datetime import timedelta
+from accelerate.state import AcceleratorState
+from accelerate.utils import InitProcessGroupKwargs
+
+
+@register_model("internvl2")
+class InternVL2(lmms):
+    def __init__(
+        self,
+        pretrained: str = "OpenGVLab/InternVL2-2B",
+        modality: str = "image",
+        device: str = "cuda:0",
+        device_map: str = "cuda:0",
+        batch_size: str = "1",
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.path = pretrained
+        self.model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True).eval().cuda()
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+
+        batch_size = int(batch_size)
+        assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
+        self.batch_size_per_gpu = batch_size
+
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            self._device = torch.device(device)
+            self.device_map = device_map
+        else:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
+            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
+            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
+            if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs = {
+                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
+                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
+                }
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
+                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
+
+            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
+                self._model = accelerator.prepare(self.model)
+            else:
+                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+            self.accelerator = accelerator
+            if self.accelerator.is_local_main_process:
+                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+            self._rank = self.accelerator.local_process_index
+            self._world_size = self.accelerator.num_processes
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
+            self._rank = 0
+            self._word_size = 1
+        else:
+            eval_logger.info(f"Using single device: {self._device}")
+            self.model.to(self._device)
+            self._rank = 0
+            self._world_size = 1
+
+        self.device = self._device
+        self.modality = modality
+
+    def flatten(self, input):
+        new_list = []
+        for i in input:
+            for j in i:
+                new_list.append(j)
+        return new_list
+
+    def generate_until(self, requests) -> List[str]:
+        res = []
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+
+        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+            if "until" in gen_kwargs:
+                gen_kwargs.pop("until")
+
+            for k, v in DEFAULT_GEN_KWARGS.items():
+                if k not in gen_kwargs:
+                    gen_kwargs[k] = v
+
+            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+            visuals = self.flatten(visuals)
+            if self.modality == "image":
+                visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals]
+                pixel_values = torch.cat(visuals, dim=0)
+                num_patches_list = [visual.size(0) for visual in visuals]
+                if visuals:
+                    image_tokens = ["<image>"] * len(visuals)
+                    image_tokens = " ".join(image_tokens)
+                    contexts = image_tokens + "\n" + contexts
+                response, history = self.model.chat(self.tokenizer, pixel_values, contexts, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
+
+            elif self.modality == "video":
+                assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
+                video_path = visuals[0]
+                pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+                pixel_values = pixel_values.to(torch.bfloat16).cuda()
+                video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
+                question = video_prefix + contexts
+                response, history = self.model.chat(self.tokenizer, pixel_values, question, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
+            res.append(response)
+            pbar.update(1)
+        pbar.close()
+        return res
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        assert False, "Not implemented yet."
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 6de4c8f8..7d6420ba 100755
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -58,6 +58,7 @@ def __init__(
         device_map="cuda:0",
         conv_template="vicuna_v1",
         use_cache=True,
+        tie_weights: bool = True,
         truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
         customized_config=None,  # ends in json
         **kwargs,
@@ -97,7 +98,9 @@ def __init__(
             self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
         self._config = self._model.config
         self.model.eval()
-        self.model.tie_weights()
+        if tie_weights:
+            self.model.tie_weights()
+
         self.truncation = truncation
         self.batch_size_per_gpu = int(batch_size)
         self.conv_template = conv_template
diff --git a/lmms_eval/models/llava_vid.py b/lmms_eval/models/llava_vid.py
index 6188dd95..14cd7e61 100755
--- a/lmms_eval/models/llava_vid.py
+++ b/lmms_eval/models/llava_vid.py
@@ -59,6 +59,8 @@ def __init__(
         mm_spatial_pool_mode: str = "average",
         overwrite: bool = True,
         video_decode_backend: str = "pyav",
+        delay_load: bool = False,
+        tie_weights: bool = True,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -86,15 +88,19 @@ def __init__(
         self.mm_spatial_pool_out_channels = int(mm_spatial_pool_out_channels)
         self.mm_spatial_pool_mode = mm_spatial_pool_mode
         self.max_frames_num = int(max_frames_num)
+        self.mm_resampler_location = mm_resampler_location
+        self.delay_load = delay_load
        if self.overwrite == True:
             overwrite_config = {}
             overwrite_config["mm_resampler_type"] = self.mm_resampler_type
             overwrite_config["mm_spatial_pool_stride"] = self.mm_spatial_pool_stride
             overwrite_config["mm_spatial_pool_out_channels"] = self.mm_spatial_pool_out_channels
             overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
-            overwrite_config["mm_resampler_location"] = "before"
-            overwrite_config["patchify_video_feature"] = False
-            overwrite_config["attn_implementation"] = attn_implementation
+            overwrite_config["mm_pooling_position"] = self.mm_resampler_location
+            overwrite_config["mm_newline_position"] = mm_newline_position
+            overwrite_config["add_faster_video"] = False
+            overwrite_config["delay_load"] = self.delay_load
+            # overwrite_config["attn_implementation"] = attn_implementation
 
             cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)
 
@@ -145,7 +151,8 @@ def __init__(
 
         self._config = self._model.config
         self.model.eval()
-        self.model.tie_weights()
+        if tie_weights:
+            self.model.tie_weights()
         self.truncation = truncation
         self.batch_size_per_gpu = int(batch_size)
         self.conv_template = conv_template
diff --git a/lmms_eval/models/longva.py b/lmms_eval/models/longva.py
index 7202c49c..c5bf6861 100644
--- a/lmms_eval/models/longva.py
+++ b/lmms_eval/models/longva.py
@@ -50,7 +50,6 @@
 
 @register_model("longva")
 class LongVA(lmms):
-
     def __init__(
         self,
         pretrained: str = "lmms-lab/LongVA-7B",
@@ -442,7 +441,7 @@ def _collate(x):
             # These steps are not in LLaVA's original code, but are necessary for generation to work
             # TODO: attention to this major generation step...
             if "image_aspect_ratio" in gen_kwargs.keys():
-                gen_kwargs.pop("image_aspect_ratio")
+                gen_kwargs.pop("image_aspect_ratio")
             try:
                 with torch.inference_mode():
                     cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
@@ -459,4 +458,4 @@ def _collate(x):
         res = re_ords.get_original(res)
 
         pbar.close()
-        return res
\ No newline at end of file
+        return res
diff --git a/lmms_eval/models/model_utils/load_video.py b/lmms_eval/models/model_utils/load_video.py
index 789039e7..dbb3cf6f 100644
--- a/lmms_eval/models/model_utils/load_video.py
+++ b/lmms_eval/models/model_utils/load_video.py
@@ -29,7 +29,8 @@ def record_video_length_packet(container):
 
 
 def read_video_pyav(video_path, num_frm=8):
-
+    container = av.open(video_path)
+
     if "webm" not in video_path and "mkv" not in video_path:
         # For mp4, we try loading with stream first
         try:
diff --git a/lmms_eval/models/phi3v.py b/lmms_eval/models/phi3v.py
index c30a7081..ab1e838d 100644
--- a/lmms_eval/models/phi3v.py
+++ b/lmms_eval/models/phi3v.py
@@ -1,6 +1,5 @@
 import torch
-
 from accelerate import Accelerator, DistributedType
 from lmms_eval import utils
 from lmms_eval.api.instance import Instance
diff --git a/lmms_eval/models/reka.py b/lmms_eval/models/reka.py
index bc461cad..ee1b9c67 100644
--- a/lmms_eval/models/reka.py
+++ b/lmms_eval/models/reka.py
@@ -36,7 +36,7 @@ def __init__(
         self,
         model_version: str = "reka-edge",
         modality: str = "image",
-        max_frames_for_video: int = 10,
+        max_frames_num: int = 5,
         timeout: int = 120,
         continual_mode: bool = False,
         response_persistent_folder: str = None,  # We will cache the Gemini API response in this path and use it for future requests
@@ -45,21 +45,24 @@ def __init__(
         super().__init__()
         self.model_version = model_version
         self.modality = modality
-        self.max_frames_for_video = max_frames_for_video
+        self.max_frames_num = max_frames_num
         self.timeout = timeout
         self.continual_mode = continual_mode
-        if self.continual_mode and response_persistent_folder is None:
-            raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
-        self.response_persistent_folder = response_persistent_folder
-        self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
-
-        if os.path.exists(self.response_persistent_file):
-            with open(self.response_persistent_file, "r") as f:
-                self.response_cache = json.load(f)
-            self.cache_mode = "resume"
-        else:
-            self.response_cache = {}
-            self.cache_mode = "start"
+        if self.continual_mode:
+            if response_persistent_folder is None:
+                raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
Please provide a valid path.") + + os.makedirs(response_persistent_folder, exist_ok=True) + self.response_persistent_folder = response_persistent_folder + self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json") + + if os.path.exists(self.response_persistent_file): + with open(self.response_persistent_file, "r") as f: + self.response_cache = json.load(f) + self.cache_mode = "resume" + else: + self.response_cache = {} + self.cache_mode = "start" self.reka = RekaClient(api_key=os.getenv("REKA_API_KEY", "YOUR_API_KEY")) @@ -99,7 +102,7 @@ def encode_image(self, image): def encode_video(self, video_path): vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) - uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_for_video, dtype=int) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_num, dtype=int) frame_idx = uniform_sampled_frames.tolist() frames = vr.get_batch(frame_idx).asnumpy() @@ -141,7 +144,7 @@ def generate_until(self, requests) -> List[str]: message_content.append({"type": "text", "text": context}) assert len(visual) == 1, "Reka only supports one video per request" media_urls = self.encode_video(visual[0]) - assert len(media_urls) == self.max_frames_for_video, f"Reka only supports {self.max_frames_for_video} frames per request" + assert len(media_urls) == self.max_frames_num, f"Reka only supports {self.max_frames_num} frames per request" for media_url in media_urls: message_content.append({"type": "image_url", "image_url": media_url}) diff --git a/lmms_eval/models/tinyllava.py b/lmms_eval/models/tinyllava.py index e07c47b8..a4335f05 100755 --- a/lmms_eval/models/tinyllava.py +++ b/lmms_eval/models/tinyllava.py @@ -2,7 +2,6 @@ torch.backends.cuda.matmul.allow_tf32 = True - import copy from tqdm import tqdm from datetime import timedelta diff --git a/lmms_eval/models/video_chatgpt/model/video_chatgpt.py b/lmms_eval/models/video_chatgpt/model/video_chatgpt.py index df6fee4f..bded27e7 100644 --- a/lmms_eval/models/video_chatgpt/model/video_chatgpt.py +++ b/lmms_eval/models/video_chatgpt/model/video_chatgpt.py @@ -76,7 +76,6 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if (input_ids.shape[1] != 1 or self.training) and video_spatio_temporal_features is not None: - video_features = self.mm_projector(video_spatio_temporal_features) dummy_video_features = torch.zeros(video_features.shape[1], 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) dummy_video_features = self.mm_projector(dummy_video_features) diff --git a/lmms_eval/models/vila.py b/lmms_eval/models/vila.py new file mode 100755 index 00000000..2adff620 --- /dev/null +++ b/lmms_eval/models/vila.py @@ -0,0 +1,376 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import logging +from typing import List, Optional, Union, Tuple +from PIL import Image +import math +import numpy as np +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from datetime import timedelta +from decord import VideoReader, cpu + + +from torchvision.transforms import Resize + +import signal + +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +eval_logger = logging.getLogger("lmms-eval") +# import sys;sys.path.append("llava-video") +try: + from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, 
+    from llava.conversation import conv_templates, SeparatorStyle
+    from llava.model.builder import load_pretrained_model
+    from llava.data.dataset import LazySupervisedDataset
+    from llava.utils import disable_torch_init
+    from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+    from llava.mm_utils import process_images
+except ImportError as e:
+    print(e)
+
+    eval_logger.debug("VILA is not installed. Please install VILA to use this model. Error: {e}")
+
+
+@register_model("vila")
+class VILA(lmms):
+    """
+    VILA Model
+    """
+
+    def __init__(
+        self,
+        pretrained: str = "Efficient-Large-Model/VILA1.5-40b",
+        max_frames_num: Optional[int] = 100,
+        truncation: Optional[bool] = True,
+        device: Optional[str] = "cuda:0",
+        batch_size: Optional[Union[int, str]] = 1,
+        attn_implementation=(
+            "sdpa" if torch.__version__ >= "2.1.2" else "eager"
+        ),  # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2". Seems FA2 is not effective during inference: https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
+        device_map="cuda:0",
+        conv_template="hermes-2",
+        use_cache=True,
+        truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
+        video_decode_backend="decord",
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            self._device = torch.device(device)
+            self.device_map = device_map
+        else:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
+
+        self.pretrained = pretrained
+        self.model_name = get_model_name_from_path(pretrained)
+        self.max_frames_num = max_frames_num
+        # self._config = AutoConfig.from_pretrained(self.pretrained)
+
+        # import pdb; pdb.set_trace()
+        self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, self.model_name, device_map=self.device_map, attn_implementation=attn_implementation)
+
+        self.model.image_processor = self._image_processor
+
+        self._config = self._model.config
+
+        if self._tokenizer.pad_token_id is None:
+            if "qwen" in self._tokenizer.name_or_path.lower():
+                print("Setting pad token to bos token for qwen model.")
+                self._tokenizer.pad_token_id = 151643
+
+        self.video_decode_backend = video_decode_backend
+        self.model.eval()
+        # self.model.tie_weights()
+        self.truncation = truncation
+        self.batch_size_per_gpu = int(batch_size)
+        self.conv_template = conv_template
+        self.use_cache = use_cache
+        self.truncate_context = truncate_context
+        # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
+            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
+            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
+            if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs = {
+                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
+                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
+                }
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
+                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
+            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
+                self._model = accelerator.prepare(self.model)
+            else:
+                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+            self.accelerator = accelerator
+            if self.accelerator.is_local_main_process:
+                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+            self._rank = self.accelerator.local_process_index
+            self._world_size = self.accelerator.num_processes
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
+            self._rank = 0
+            self._word_size = 1
+        else:
+            eval_logger.info(f"Using single device: {self._device}")
+            self.model.to(self._device)
+            self._rank = 0
+            self._world_size = 1
+
+    @property
+    def config(self):
+        # return the associated transformers.AutoConfig for the given pretrained model.
+        return self._config
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self._max_length
+
+    def pad_sequence(self, input_ids, batch_first, padding_value):
+        if self.tokenizer.padding_side == "left":
+            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
+        if self.tokenizer.padding_side == "left":
+            input_ids = torch.flip(input_ids, [1])
+        return input_ids
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
+        """ """
+        add_special_tokens = False if add_special_tokens is None else add_special_tokens
+        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
+        if left_truncate_len:
+            encoding = encoding[-left_truncate_len:]
+        return encoding
+
+    def load_video(self, video_path, max_frames_num):
+        try:
+            vr = VideoReader(video_path, ctx=cpu(0))
+            total_frame_num = len(vr)
+            fps = round(vr.get_avg_fps())
+            frame_idx = np.linspace(0, total_frame_num - 2, max_frames_num, dtype=int)
+            spare_frames = vr.get_batch(frame_idx).asnumpy()
+            return [Image.fromarray(img) for img in spare_frames]
+        except Exception as e:
+            eval_logger.error(f"Failed to load video {video_path} with error: {e}")
+
+            return [Image.new("RGB", (448, 448), (0, 0, 0))] * max_frames_num
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        res = []
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+
+        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+            # encode, pad, and truncate contexts for this batch
+            if type(doc_to_target) == str:
+                continuation = doc_to_target
+            else:
+                continuation = doc_to_target(self.task_dict[task][split][doc_id])
+            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+            visuals = self.flatten(visuals)
+            videos = []
+            for visual in visuals:
+                video = self.load_video(visual, self.max_frames_num)
+                video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
+                videos.append(video)
+
+            qs = contexts
+            if self.model.config.mm_use_im_start_end:
+                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
+            else:
+                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+            conv = conv_templates[self.conv_template].copy()
+            conv.append_message(conv.roles[0], qs)
+            conv.append_message(conv.roles[1], None)
+            prompt = conv.get_prompt()
+
+            contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
+
+            conv = conv_templates[self.conv_template].copy()
+            conv.append_message(conv.roles[0], qs)
+            conv.append_message(conv.roles[1], continuation)
+            prompt = conv.get_prompt()
+
+            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
+            attention_masks = input_ids.ne(self.tokenizer.pad_token_id).long().cuda()
+
+            labels = input_ids.clone()
+            # Context part no need to calculate for loss
+            labels[0, : contxt_id.shape[1]] = -100
+
+            with torch.inference_mode():
+                outputs = self.model(input_ids=input_ids, labels=labels, images=videos, modalities="video")
+
+            loss = outputs["loss"]
+            # loss = torch.exp(loss)
+            logits = outputs["logits"]
+            greedy_tokens = logits.argmax(dim=-1)
+            cont_toks = input_ids[:, contxt_id.shape[1] :]  # [1, seq]
+            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]]  # [1, seq]
+            max_equal = (greedy_tokens == cont_toks).all()
+            res.append((float(loss.item()), bool(max_equal)))
+            pbar.update(1)
+        pbar.close()
+        return res
+
+    def flatten(self, input):
+        new_list = []
+        for i in input:
+            for j in i:
+                new_list.append(j)
+        return new_list
+
+    def generate_until(self, requests) -> List[str]:
+        res = []
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+
+        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+            # if self.task_dict[task][split][doc_id]["duration"] != "short":
+            #
+            #     res.append("A")
+            #     pbar.update(1)
+            #     continue
+            # encode, pad, and truncate contexts for this batch
+            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+            visuals = self.flatten(visuals)
+
+            num_video_frames = self.model.config.num_video_frames
+            videos = []
+
+            if self.max_frames_num == 0:
+                images = [Image.new("RGB", (448, 448), (0, 0, 0))] * num_video_frames
+                video = process_images(images, self.model.image_processor, self.model.config).half().cuda()
+                videos.append(video)
+            else:
+                for visual in visuals:
+                    # images, video_loading_succeed = LazySupervisedDataset._load_video(visual, num_video_frames, self.model)
+
+                    if self.video_decode_backend == "decord":
+                        images = self.load_video(visual, num_video_frames)
+                    elif self.video_decode_backend == "pyav":
+                        images = read_video_pyav(visual, num_frm=num_video_frames)
+
+                    video = process_images(images, self.model.image_processor, self.model.config).half().cuda()
+                    videos.append(video)
+
+            qs = f"