Merge pull request #138 from EvolvingLMMs-Lab/internal_main_dev
[Sync Features] add vila, add wildvision, add vibe-eval, add interleave bench
Luodian authored Jul 13, 2024
2 parents a60e4e0 + e31cd78 commit c65118d
Showing 70 changed files with 3,308 additions and 486 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ temp
__pycache__
.ipynb_checkpoints
temp
.DS_STORE
# IPython
profile_default/
ipython_config.py
4 changes: 3 additions & 1 deletion README.md
@@ -12,7 +12,9 @@

## Announcement

- [2024-06] 🎬🎬 The `lmms-eval/v0.2` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details
- [2024-07] 👨‍💻👨‍💻 The `lmms-eval/v0.2.1` has been upgraded to support more models, including [LongVA](https://github.com/EvolvingLMMs-Lab/LongVA), [InternVL-2](https://github.com/OpenGVLab/InternVL), and [VILA](https://github.com/NVlabs/VILA), as well as more evaluation tasks, e.g. [Details Captions](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/136), [MLVU](https://arxiv.org/abs/2406.04264), [WildVision-Bench](https://huggingface.co/datasets/WildVision/wildvision-arena-data), [VITATECS](https://github.com/lscpku/VITATECS), and [LLaVA-Interleave-Bench](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/).

- [2024-06] 🎬🎬 The `lmms-eval/v0.2.0` has been upgraded to support video evaluations for video models like LLaVA-NeXT Video and Gemini 1.5 Pro across tasks such as EgoSchema, PerceptionTest, VideoMME, and more. Please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for more details

- [2024-03] 📝📝 We have released the first version of `lmms-eval`, please refer to the [blog](https://lmms-lab.github.io/posts/lmms-eval-0.1/) for more details

1 change: 0 additions & 1 deletion lmms_eval/__main__.py
@@ -165,7 +165,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
# reset logger
eval_logger.remove()
eval_logger.add(sys.stdout, colorize=True, level=args.verbosity)
eval_logger.add(sys.stderr, level=args.verbosity)
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

4 changes: 3 additions & 1 deletion lmms_eval/api/samplers.py
@@ -37,7 +37,9 @@ def get_context(self, doc, num_fewshot):
+ (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
else self.doc_to_target(doc)
if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str)
else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
for doc in selected_docs
]
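
The samplers.py hunk only re-wraps the nested conditional expression that resolves a few-shot target; the logic is unchanged. For readability, the same branches can be written as an explicit helper; the sketch below is hypothetical (not part of this commit) and simply restates the expression above:

```python
# Hypothetical helper restating the target-resolution branches above.
def resolve_target(self, doc):
    target = self.doc_to_target(doc)
    if isinstance(target, list):
        return str(target[0])                    # first entry of a list target
    if self.config.doc_to_choice is None or isinstance(target, str):
        return target                            # already a plain string
    return str(self.doc_to_choice(doc)[target])  # target is an index into the choices
```
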
7 changes: 1 addition & 6 deletions lmms_eval/evaluator.py
@@ -327,12 +327,7 @@ def evaluate(
# hack: remove image columns to avoid loading images and speed up postprocessing
# reason: doc_iterator will actually load the image if it's in the doc.
docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
if "d170" not in task_name \
and "dc100" not in task_name \
and "dc200" not in task_name \
and "llava_wilder" not in task_name \
and "livebench" not in task_name \
and "wildvision" not in task_name:
if "d170" not in task_name and "dc100" not in task_name and "dc200" not in task_name and "llava_wilder" not in task_name and "livebench" not in task_name and "wildvision" not in task_name:
remove_cols = []
features = docs.features
# If it is an Image instance or a Sequence of Image instance. Remove it
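
The evaluator.py hunk only reflows the task-name filter onto a single line; behaviour is unchanged: for tasks outside that list, image columns are dropped before iterating docs. A minimal sketch of that pruning step using the Hugging Face `datasets` API is shown below (the helper name is hypothetical):

```python
from datasets import Image, Sequence

# Hypothetical standalone version of the pruning step: remove any Image feature
# (or Sequence of Image features) so iterating docs does not decode the images.
def strip_image_columns(docs):
    remove_cols = [
        name
        for name, feature in docs.features.items()
        if isinstance(feature, Image) or (isinstance(feature, Sequence) and isinstance(feature.feature, Image))
    ]
    return docs.remove_columns(remove_cols) if remove_cols else docs
```
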
17 changes: 10 additions & 7 deletions lmms_eval/models/__init__.py
@@ -4,6 +4,10 @@
from loguru import logger
import sys

import hf_transfer

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

logger.remove()
logger.add(sys.stdout, level="WARNING")

@@ -25,6 +29,7 @@
"llava_sglang": "LlavaSglang",
"idefics2": "Idefics2",
"internvl": "InternVLChat",
"internvl2": "InternVL2",
"gemini_api": "GeminiAPI",
"reka": "Reka",
"from_log": "FromLog",
@@ -33,14 +38,16 @@
"tinyllava": "TinyLlava",
"llava_hf": "LlavaHf",
"longva": "LongVA",
"llava_hf": "LlavaHf",
"longva": "LongVA",
"vila": "VILA",
}

for model_name, model_class in AVAILABLE_MODELS.items():
try:
exec(f"from .{model_name} import {model_class}")
except ImportError as e:
# logger.warning(f"Failed to import {model_class} from {model_name}: {e}")
pass
logger.warning(f"Failed to import {model_class} from {model_name}: {e}")

if os.environ.get("LMMS_EVAL_PLUGINS", None):
# Allow specifying other packages to import models from
@@ -50,8 +57,4 @@
try:
exec(f"from {plugin}.models.{model_name} import {model_class}")
except ImportError:
pass

import hf_transfer

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
logger.warning(f"Failed to import {model_class} from {model_name}")
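
With this change the registry imports log a warning instead of failing silently, which makes missing optional dependencies visible. The `LMMS_EVAL_PLUGINS` branch imports extra models from external packages with the same module-name-to-class-name convention; a minimal sketch of such a plugin package, assuming it exposes its own `AVAILABLE_MODELS` mapping under `<plugin>.models` (package, module, and class names below are hypothetical):

```python
# my_lmms_plugin/models/__init__.py  (hypothetical plugin package)
# Module name -> class name, mirroring lmms_eval's own registry above.
AVAILABLE_MODELS = {
    "my_model": "MyModel",
}

# my_lmms_plugin/models/my_model.py
from lmms_eval.api.model import lmms


class MyModel(lmms):
    """Skeleton model wrapper; fill in generate_until/loglikelihood as needed."""
    ...
```

Setting `LMMS_EVAL_PLUGINS=my_lmms_plugin` should then let the loader run `from my_lmms_plugin.models.my_model import MyModel`, as in the `exec` call above.
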
6 changes: 3 additions & 3 deletions lmms_eval/models/batch_gpt4.py
@@ -59,7 +59,7 @@ def __init__(
api_key: str = API_KEY,
api_url: str = API_URL,
modality: str = "image",
max_frames_for_video: int = 10,
max_frames_num: int = 10,
timeout: int = 120,
**kwargs,
) -> None:
@@ -69,7 +69,7 @@ def __init__(
# Here we just use the same token as llava for convenience
self.model_version = model_version
self.modality = modality
self.max_frames_for_video = max_frames_for_video
self.max_frames_num = max_frames_num
self.image_token = "<image>"
self.timeout = timeout

@@ -128,7 +128,7 @@ def generate_until(self, requests):
img = self.encode_image(visual)
imgs.append(img)
elif self.modality == "video":
frames = self.encode_video(visual, self.max_frames_for_video)
frames = self.encode_video(visual, self.max_frames_num)
imgs.extend(frames)

messages = []
47 changes: 28 additions & 19 deletions lmms_eval/models/claude.py
@@ -40,6 +40,7 @@ def __init__(
image_token: str = "<image>", # Use to separate interleaved image and text
system_prompt: str = "", # Whether you want some special system prompt here
modality: str = "image",
max_frames_num: int = 10,
continual_mode: bool = False,
response_persistent_folder: str = None,
**kwargs,
@@ -49,20 +50,24 @@
self.image_token = image_token
self.system_prompt = system_prompt
self.modality = modality
self.max_frames_num = max_frames_num

self.continual_mode = continual_mode
if self.continual_mode and response_persistent_folder is None:
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
self.response_persistent_folder = response_persistent_folder
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

if os.path.exists(self.response_persistent_file):
with open(self.response_persistent_file, "r") as f:
self.response_cache = json.load(f)
self.cache_mode = "resume"
else:
self.response_cache = {}
self.cache_mode = "start"
if self.continual_mode:
if response_persistent_folder is None:
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")

os.makedirs(response_persistent_folder, exist_ok=True)
self.response_persistent_folder = response_persistent_folder
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

if os.path.exists(self.response_persistent_file):
with open(self.response_persistent_file, "r") as f:
self.response_cache = json.load(f)
self.cache_mode = "resume"
else:
self.response_cache = {}
self.cache_mode = "start"

accelerator = Accelerator()
if accelerator.num_processes > 1:
@@ -81,7 +86,7 @@

def encode_image(self, image):
output_buffer = BytesIO()
image.save(output_buffer, format="PNG")
image.save(output_buffer, format="JPEG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
return base64_str
@@ -129,18 +134,18 @@ def shrink_image_to_file_size(self, img: Image, max_file_size=4838990) -> Image:
def encode_video(self, video_path):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_for_video, dtype=int)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
img = Image.fromarray(frame)
output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
img.save(output_buffer, format="JPEG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
base64_frames.append(f"data:image/jpeg;base64,{base64_str}")
base64_frames.append(f"{base64_str}")

return base64_frames

Expand All @@ -154,7 +159,7 @@ def generate_until(self, requests) -> List[str]:
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"media_type": "image/jpeg",
},
}
empty_text_block = {"type": "text"}
@@ -218,10 +223,12 @@ def generate_until(self, requests) -> List[str]:

if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if gen_kwargs["max_new_tokens"] > 4096:
gen_kwargs["max_new_tokens"] = 4096
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "top_p" not in gen_kwargs or gen_kwargs["top_p"] is None:
gen_kwargs["top_p"] = 1
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1

@@ -238,11 +245,13 @@ def generate_until(self, requests) -> List[str]:
pbar.update(1)
continue

response_text = message.content[0].text
res.append(message.content[0].text)
pbar.update(1)

###################### CONTINUAL MODE ######################
if self.continual_mode is True: # Cache the response
response_text = message.content[0].text
doc_uuid = f"{task}___{split}___{doc_id}"
self.response_cache[doc_uuid] = response_text
with open(self.response_persistent_file, "w") as f:
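
The claude.py hunks gate the cache setup behind `continual_mode` (creating the cache folder if needed), switch image and frame encoding to JPEG, cap `max_new_tokens` at 4096, and persist responses keyed by task, split, and doc id so interrupted runs can resume. A hedged sketch of that lookup-then-persist pattern as a standalone helper (the helper name is hypothetical; the `task___split___doc_id` key comes from the code above):

```python
import json

# Hypothetical helper showing the continual-mode flow: reuse a cached response
# when resuming, otherwise call the API and persist the new response to disk.
def cached_generate(self, task, split, doc_id, call_api):
    doc_uuid = f"{task}___{split}___{doc_id}"
    if self.continual_mode and self.cache_mode == "resume" and doc_uuid in self.response_cache:
        return self.response_cache[doc_uuid]
    response_text = call_api()
    if self.continual_mode:
        self.response_cache[doc_uuid] = response_text
        with open(self.response_persistent_file, "w") as f:
            json.dump(self.response_cache, f)
    return response_text
```
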
4 changes: 3 additions & 1 deletion lmms_eval/models/gemini_api.py
@@ -31,7 +31,7 @@
class GeminiAPI(lmms):
def __init__(
self,
model_version: str = "gemini-1.5-flash-latest",
model_version: str = "gemini-1.5-pro",
modality: str = "image",
timeout: int = 120,
continual_mode: bool = False,
@@ -46,6 +46,8 @@ def __init__(
if self.continual_mode and response_persistent_folder is None:
raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
self.response_persistent_folder = response_persistent_folder
if not os.path.exists(self.response_persistent_folder):
os.makedirs(self.response_persistent_folder)
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

if os.path.exists(self.response_persistent_file):