diff --git a/lmms_eval/models/internvl.py b/lmms_eval/models/internvl.py index 6238d7fcf..d5b668ce0 100644 --- a/lmms_eval/models/internvl.py +++ b/lmms_eval/models/internvl.py @@ -449,7 +449,10 @@ def _collate(x): split = split[0] batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] flattened_visuals = self.flatten(batched_visuals) - pixel_values = self.load_image(flattened_visuals, self.image_size).cuda().to(torch.bfloat16) + try: + pixel_values = self.load_image(flattened_visuals, self.image_size).cuda().to(torch.bfloat16) + except IndexError: + pixel_values = None gen_kwargs = all_gen_kwargs[0] if "max_new_tokens" not in gen_kwargs: diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py index 509df5838..478fd49cd 100644 --- a/lmms_eval/models/internvl2.py +++ b/lmms_eval/models/internvl2.py @@ -213,13 +213,16 @@ def generate_until(self, requests) -> List[str]: visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] visuals = self.flatten(visuals) if self.modality == "image": - visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals] - pixel_values = torch.cat(visuals, dim=0) - num_patches_list = [visual.size(0) for visual in visuals] if visuals: + visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals] + pixel_values = torch.cat(visuals, dim=0) + num_patches_list = [visual.size(0) for visual in visuals] image_tokens = ["<image>"] * len(visuals) image_tokens = " ".join(image_tokens) contexts = image_tokens + "\n" + contexts + else: + pixel_values = None + num_patches_list = None response, history = self.model.chat(self.tokenizer, pixel_values, contexts, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True) elif self.modality == "video": diff --git a/lmms_eval/tasks/scienceqa/utils.py b/lmms_eval/tasks/scienceqa/utils.py index eed6b26ab..128ebf562 100755 --- a/lmms_eval/tasks/scienceqa/utils.py +++ 
b/lmms_eval/tasks/scienceqa/utils.py @@ -33,8 +33,8 @@ def sqa_doc_to_target(doc): def sqa_process_results(doc, results): # I know this is weird, but it's how llava parse it. - target = sqa_doc_to_target(doc) - pred = results[0] + target = sqa_doc_to_target(doc).strip().lower() + pred = results[0].strip().lower() if pred == target: return {"exact_match": 1.0} # pattern: ^[A-Z]\. .*