Commit: Oryx & Vila

pufanyi committed Dec 27, 2024
1 parent 792cc5b commit 11ad1b1
Showing 2 changed files with 44 additions and 4 deletions.
23 changes: 22 additions & 1 deletion lmms_eval/models/oryx.py
@@ -22,6 +22,14 @@
import os
import sys

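# Resolution limits for Oryx's image/video preprocessing. These appear to mirror the
# values used in Oryx's own inference scripts (assumption), and are exported before the
# `oryx` import below on the assumption that the package reads them at import time.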
os.environ["LOWRES_RESIZE"] = "384x32"
os.environ["VIDEO_RESIZE"] = "0x64"
os.environ["HIGHRES_BASE"] = "0x32"
os.environ["MAXRES"] = "1536"
os.environ["MINRES"] = "0"
os.environ["VIDEO_MAXRES"] = "480"
os.environ["VIDEO_MINRES"] = "288"

try:
from oryx.constants import (
DEFAULT_IM_END_TOKEN,
@@ -38,7 +46,6 @@
tokenizer_image_token,
)
from oryx.model.builder import load_pretrained_model
from oryx.model.language_model.oryx_llama import OryxConfig
except ImportError:
eval_logger.debug("Oryx is not installed. Please install Oryx to use this model.")

@@ -67,6 +74,9 @@ def __init__(
truncate_context=False,
max_frames_num: int = 32,
mm_resampler_type: str = "spatial_pool",
mm_spatial_pool_stride: int = 2,
mm_spatial_pool_out_channels: int = 1024,
mm_spatial_pool_mode: str = "average",
overwrite: bool = True,
video_decode_backend: str = "decord",
**kwargs,
@@ -98,6 +108,9 @@ def __init__(
overwrite_config["mm_resampler_type"] = self.mm_resampler_type
overwrite_config["patchify_video_feature"] = False
overwrite_config["attn_implementation"] = attn_implementation
overwrite_config["mm_spatial_pool_stride"] = mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = mm_spatial_pool_mode

cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

@@ -429,6 +442,8 @@ def generate_until(self, requests) -> List[str]:
gen_kwargs["num_beams"] = 1

try:
print("Videos")
print(videos)
with torch.inference_mode():
if task_type == "video":
output_ids = self.model.generate(
@@ -466,10 +481,16 @@ def generate_until(self, requests) -> List[str]:
res.append(outputs)
pbar.update(1)
except Exception as e:
import traceback

traceback.print_exc()
eval_logger.info(f"{e}")
eval_logger.info(f"Video {visuals} generate failed, check the source")
video_path = "\n".join(visuals)
res.append(f"Video {video_path} generate failed, check the source")
pbar.update(1)
continue
return res

def generate_until_multi_round(self, requests) -> List[str]:
raise NotImplementedError("TODO: Implement multi-round generation")
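For context, a minimal sketch (not part of this commit) of how the new Oryx kwargs could be passed through lmms_eval's comma-separated --model_args string. The checkpoint id and task name are illustrative placeholders, and it is assumed these kwargs are forwarded verbatim to the __init__ signature shown above:

import subprocess

model_args = ",".join(
    [
        "pretrained=THUdyh/Oryx-7B",  # placeholder checkpoint id
        "max_frames_num=32",
        "mm_spatial_pool_stride=2",  # new in this commit
        "mm_spatial_pool_out_channels=1024",  # new in this commit
        "mm_spatial_pool_mode=average",  # new in this commit
    ]
)
subprocess.run(
    ["python", "-m", "lmms_eval", "--model", "oryx", "--model_args", model_args,
     "--tasks", "videomme", "--batch_size", "1"],
    check=True,
)
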
25 changes: 22 additions & 3 deletions lmms_eval/models/vila.py
@@ -1,9 +1,9 @@
import argparse
import copy
import json
import logging
import math
import os
import signal
from datetime import timedelta
from typing import List, Optional, Tuple, Union

@@ -38,7 +38,24 @@
)
from llava.model.builder import load_pretrained_model
except ImportError as e:
eval_logger.debug(f"VILA is not installed. Please install VILA to use this model. Error: {e}")
eval_logger.debug(f"VILA is not installed. Please install VILA to use this model. Error: {e}. You need to make sure the newest repo is for NVILA, so you need to rollback to the commit before the NVILA update.")

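# Maps substrings of known VILA checkpoint names to their conversation templates;
# consumed by the new "auto" conv_template default further down.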
CONVERSATION_MODE_MAPPING = {
"vila1.5-3b": "vicuna_v1",
"vila1.5-8b": "llama_3",
"vila1.5-13b": "vicuna_v1",
"vila1.5-40b": "hermes-2",
"llama-3": "llama_3",
"llama3": "llama_3",
}


def auto_set_conversation_mode(model_name_or_path: str) -> str:
for k, v in CONVERSATION_MODE_MAPPING.items():
if k in model_name_or_path.lower():
print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
return v
raise ValueError(f"Could not determine conversation mode for model name/path `{model_name_or_path}`.")


@register_model("vila")
@@ -58,7 +75,7 @@ def __init__(
"sdpa" if torch.__version__ >= "2.1.2" else "eager"
), # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2". Seems FA2 is not effective during inference: https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
device_map="cuda:0",
conv_template="hermes-2",
conv_template="auto",
use_cache=True,
truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
video_decode_backend="decord",
@@ -101,6 +118,8 @@ def __init__(
self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
self.conv_template = conv_template
if self.conv_template == "auto":
self.conv_template = auto_set_conversation_mode(self.model_name)
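# "auto" (the new default) resolves the template from the checkpoint name via
# auto_set_conversation_mode above.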
self.use_cache = use_cache
self.truncate_context = truncate_context
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
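To make the new auto conversation-mode behavior concrete, here is a self-contained sketch; the mapping and matching logic are copied from the diff above, and the checkpoint paths are illustrative:

CONVERSATION_MODE_MAPPING = {
    "vila1.5-3b": "vicuna_v1",
    "vila1.5-8b": "llama_3",
    "vila1.5-13b": "vicuna_v1",
    "vila1.5-40b": "hermes-2",
    "llama-3": "llama_3",
    "llama3": "llama_3",
}

def auto_set_conversation_mode(model_name_or_path: str) -> str:
    # First key that appears as a substring of the lowercased name/path wins.
    for k, v in CONVERSATION_MODE_MAPPING.items():
        if k in model_name_or_path.lower():
            return v
    raise ValueError(f"Could not determine conversation mode for model name/path `{model_name_or_path}`.")

print(auto_set_conversation_mode("Efficient-Large-Model/VILA1.5-13b"))  # vicuna_v1
print(auto_set_conversation_mode("checkpoints/vila1.5-40b"))  # hermes-2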
