diff --git a/LICENSE b/LICENSE
index 206be3eb..37e142af 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,14 @@
-MIT License
+MIT License for Non-Commercial Use
-Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt.
+Copyright (c) 2023 Bo Li, Yuanhan Zhang, Liangyu Chen, Jinghao Wang, Fanyi Pu, Jingkang Yang, Chunyuan Li, Ziwei Liu
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+S-Lab, Nanyang Technological University
+Microsoft Research, Redmond
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+Permission is hereby granted, free of charge, to any person obtaining a copy of the Otter model and MIMIC-IT Dataset (the "Software"), to use, copy, modify, merge, and distribute copies of the Software, subject to the following conditions:
+1. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+2. The Software may not be used for commercial purposes. For the purposes of this license, commercial use includes, but is not limited to: integration of the Software into a product or service that generates revenue, incorporation of the Software into a commercial offering, or using the Software in the course of performing services for which payment is received.
+3. Redistributions of the Software must retain the above copyright notice, this list of conditions, and the following disclaimer.
+4. Neither the names of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this Software without specific prior written permission.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index a2e7d44c..b0b84f5f 100644
--- a/README.md
+++ b/README.md
@@ -5,20 +5,25 @@
-
Bo Li*1
-
Yuanhan Zhang*,1
-
Liangyu Chen*,1
-
Jinghao Wang*,1
-
Fanyi Pu*,1
+
Bo Li*,β ,1
+
Yuanhan Zhang*,β ,1
+
Liangyu Chen*,1
+
Jinghao Wang*,1
+
Fanyi Pu*,1
Jingkang Yang1
Chunyuan Li2
-
Ziwei Liu1
+
Ziwei Liu✉,1
1S-Lab, Nanyang Technological University
2Microsoft Research, Redmond
+
+ β Co-Project Lead
+ * Equal Contribution
+ ✉ Corresponding Author
+
-----------------
@@ -44,37 +49,40 @@
- [Checkpoints v0.2 (video version, trained on MIMIC-IT all videos, upcoming)]()
- [Checkpoints v0.3 (Otter-E, visual assistant version, upcoming)]()
-Otter v0.1 supports multiple images inputs as in-context examples, which is **the first multi-modal instruction tuned model** that supports to organize inputs this way.
-
-Otter v0.2 supports videos inputs (frames are arranged as original Flamingo's implementation) and multiple images inputs (they serve as in-context examples for each other).
+Otter v0.1 supports multiple images inputs as in-context examples, which is **the first multi-modal instruction tuned model** that supports to organize inputs this way.
-Huge accolades to [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) and [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) team for the work on this great architecture.
+Otter v0.2 supports videos inputs (frames are arranged as original Flamingo's implementation) and multiple images inputs (they serve as in-context examples for each other).
-**Eval Results:** [Multi-Modal Arena](http://vlarena.opengvlab.com/) | Multi-Modal AGI Benchmark (Upcoming)
-
-
-
-
-

-
+**Eval Results:** [Multi-Modal Arena](http://vlarena.opengvlab.com/) | Multi-Modal AGI Benchmark (upcoming)
## π¦Ύ Update
+**[2023-06-23]**
+1. 𧨠[Download MIMIC-IT Dataset](https://entuedu-my.sharepoint.com/:f:/g/personal/libo0013_e_ntu_edu_sg/Eo9bgNV5cjtEswfA-HfjNNABiKsjDzSWAl5QYAlRZPiuZA?e=M9isDT). For more details on navigating the dataset, please refer to [MIMIC-IT Dataset README](mimic-it/README.md).
+2. ποΈ [Run Otter Locally](./pipeline/demo). You can run our model locally with at least 16G GPU mem for tasks like image/video tagging and captioning and identifying harmful content. We fix a bug related to video inference where `frame tensors` were mistakenly unsqueezed to a wrong `vision_x`. You can now try running it again with the updated version.
+ > Make sure to adjust the `sys.path.append("../..")` correctly to access `otter.modeling_otter` in order to launch the model.
+
**[2023-06-08]**
1. Introducing Project Otter's brand new homepage: https://otter-ntu.github.io/. Check it out now!
-2. Check our [paper](https://arxiv.org/abs/2306.05425) introducing MIMIC-IT in details. Meet MIMIC-IT, the first multimodal in-context instruction tuning dataset with 2.8M instructions! Designed to create diverse vision-language instructions that align with real-world visual content, MIMIC-IT spans across seven image and video datasets covering a vast array of scenes. From general scene understanding to spotting subtle differences and enhancing egocentric view comprehension for AR headsets, our MIMIC-IT dataset has it all. Discover more about the MIMIC-IT dataset now!
+2. Check our [paper](https://arxiv.org/abs/2306.05425) introducing MIMIC-IT in details. Meet MIMIC-IT, the first multimodal in-context instruction tuning dataset with 2.8M instructions! From general scene understanding to spotting subtle differences and enhancing egocentric view comprehension for AR headsets, our MIMIC-IT dataset has it all.
3. Stay tuned for our upcoming Otter Model v0.2, trained on the MIMIC-IT dataset! With the ability to understand daily scenes, reason in context, spot differences in observations, and act as an egocentric assistant. Checkout conceptual demo video at [Youtube](https://www.youtube.com/watch?v=K8o_LKGQJhs) or [Bilibili](https://www.bilibili.com/video/BV1Bo4y1T7SN/?share_source=copy_web&vd_source=477facaaaa60694f67a784f5eaa905ad)!
-**[2023-05-14]**
+
-## 𦦠Overview
+
+

+
+
+## 𦦠Why In-Context Instruction Tuning?
+
+Large Language Models (LLMs) have demonstrated exceptional universal aptitude as few/zero-shot learners for numerous tasks, owing to their pre-training on extensive text data. Among these LLMs, GPT-3 stands out as a prominent model with significant capabilities. Additionally, variants of GPT-3, namely InstrctGPT and ChatGPT, have proven effective in interpreting natural language instructions to perform complex real-world tasks, thanks to instruction tuning.
-Large Language Models (LLMs) have exhibited exceptional universal aptitude as few/zero-shot learners for numerous tasks, thanks to their pre-training on large-scale text data. GPT-3 is a prominent LLM that has showcased significant capabilities in this regard. Furthermore, variants of GPT-3, namely InstrctGPT and ChatGPT, equipped with instruction tuning, have proven effective in interpreting natural language instructions to perform complex real-world tasks. In this paper, we propose to introduce instruction tuning into multi-modal models, motivated by the Flamingo model's upstream interleaved format pretraining dataset. We adopt a similar approach to construct our **MI**-**M**odal **I**n-**C**ontext **I**nstruction **T**uning (**MIMIC-IT**) dataset. We then introduce 𦦠Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following ability and in-context learning. We integrate both OpenFlamingo and Otter into Hugging Face Transformers for more researchers to incorporate the models into their customized training and inference pipelines.
+Motivated by the upstream interleaved format pretraining of the Flamingo model, we present 𦦠Otter, a multi-modal model based on OpenFlamingo (the open-sourced version of DeepMind's Flamingo). We train our Otter in an in-context instruction tuning way on our proposed **MI**-**M**odal **I**n-**C**ontext **I**nstruction **T**uning (**MIMIC-IT**) dataset. Otter showcases improved instruction-following and in-context learning ability in both images and videos.
## π MIMIC-IT Dataset Details
@@ -82,7 +90,7 @@ Large Language Models (LLMs) have exhibited exceptional universal aptitude as fe

-MIMIC-IT covers a vast array of real-life scenarios that empower Vision-Language Models (VLMs) to not only comprehend general scenes, but also to reason about context and astutely differentiate between observations. MIMIC-IT also enables the application of egocentric visual assistant model that can serve that can answer your questions like **Hey, Do you think I left my keys on the table?**. In addition to English, MIMIC-IT is also multilingual, supporting Chinese, Korean, Japanese, German, French, Spanish, and Arabic, thereby allowing a larger global audience to altogether enjoy from the convenience brought about by advancements in artificial intelligence.
+MIMIC-IT enables the application of egocentric visual assistant model that can serve that can answer your questions like **Hey, Do you think I left my keys on the table?**. Harness the power of MIMIC-IT to unlock the full potential of your AI-driven visual assistant and elevate your interactive vision-language tasks to new heights.
@@ -90,8 +98,6 @@ MIMIC-IT covers a vast array of real-life scenarios that empower Vision-Language
We also introduce **Syphus**, an automated pipeline for generating high-quality instruction-response pairs in multiple languages. Building upon the framework proposed by LLaVA, we utilize ChatGPT to generate instruction-response pairs based on visual content. To ensure the quality of the generated instruction-response pairs, our pipeline incorporates system messages, visual annotations, and in-context examples as prompts for ChatGPT.
-
-
For more details, please check the [MIMIC-IT dataset](mimic-it/README.md).
@@ -101,9 +107,9 @@ For more details, please check the [MIMIC-IT dataset](mimic-it/README.md).
-Otter is designed to support multi-modal in-context instruction tuning based on the OpenFlamingo model, which involves conditioning the language model on the corresponding media, such as an image that corresponds to a caption or an instruction-response pair.
+Otter is designed to support multi-modal in-context instruction tuning based on the OpenFlamingo model, which involves conditioning the language model on the corresponding media, such as an image that corresponds to a caption or an instruction-response pair.
-We train Otter on MIMIC-IT dataset with approximately 2.8 million in-context instruction-response pairs, which are structured into a cohesive template to facilitate various tasks.
+We train Otter on MIMIC-IT dataset with approximately 2.8 million in-context instruction-response pairs, which are structured into a cohesive template to facilitate various tasks. Otter supports videos inputs (frames are arranged as original Flamingo's implementation) and multiple images inputs as in-context examples, which is **the first multi-modal instruction tuned model**.
The following template encompasses images, user instructions, and model-generated responses, utilizing the `User` and `GPT` role labels to enable seamless user-assistant interactions.
@@ -129,35 +135,12 @@ For more details, please refer to our [paper](https://arxiv.org/abs/2306.05425)'
## ποΈ Environments
1. Compare cuda version returned by nvidia-smi and nvcc --version. They need to match. Or at least, the version get by nvcc --version should be <= the version get by nvidia-smi.
-2. Install the pytorch that matches your cuda version. (e.g. cuda 11.7 torch 2.0.0). We have successfully run this code on cuda 11.1 torch 1.10.1 and cuda 11.7 torch 2.0.0. Version compatible reference:[here](https://pytorch.org/) or [here](https://pytorch.org/get-started/previous-versions/).
+2. Install the pytorch that matches your cuda version. (e.g. cuda 11.7 torch 2.0.0). We have successfully run this code on cuda 11.1 torch 1.10.1 and cuda 11.7 torch 2.0.0. You can refer to PyTorch's documentation, [Latest](https://pytorch.org/) or [Previous](https://pytorch.org/get-started/previous-versions/).
3. You may install via `conda env create -f environment.yml`. Especially to make sure the `transformers>=4.28.0`, `accelerate>=0.18.0`.
## π€ Hugging Face Model
-You can use the 𦩠Flamingo model / 𦦠Otter model as a π€ Hugging Face model with only a few lines! One-click and then model configs/weights are downloaded automatically.
-
-``` python
-from flamingo import FlamingoModel
-flamingo_model = FlamingoModel.from_pretrained("luodian/openflamingo-9b-hf", device_map=auto)
-
-from otter import OtterModel
-otter_model = OtterModel.from_pretrained("luodian/otter-9b-hf", device_map=auto)
-```
-
-Previous [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) was developed with [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel) (DDP) on A100 cluster. Loading OpenFlamingo-9B to GPU requires **at least 33G GPU memory**, which is only available on A100 GPUs.
-
-In order to allow more researchers without access to A100 machines to try training OpenFlamingo, we wrap the OpenFlamingo model into a π€ hugging Face model ([Jinghao](https://king159.github.io/) has submitted a [PR](https://github.com/huggingface/transformers/pull/23063) to the /huggingface/transformers!). Via `device_map=auto`, the large model is sharded across multiple GPUs when loading and training. This can help researchers who do not have access to A100-80G GPUs to achieve similar throughput in training, testing on 4x RTX-3090-24G GPUs, and model deployment on 2x RTX-3090-24G GPUs. Specific details are below (may vary depending on the CPU and disk performance, as we conducted training on different machines).
-
-
-

-
-
-
-
-Our Otter model is also developed in this way and it's deployed on the π€ Hugging Face model hub. Our model can be hosted on two RTX-3090-24G GPUs and achieve a similar speed to one A100-80G machine.
+After configuring environment, you can use the 𦩠Flamingo model / 𦦠Otter model as a π€ Hugging Face model with only a few lines! One-click and then model configs/weights are downloaded automatically. Please refer to [Huggingface Otter/Flamingo](./docs/huggingface_compatible.md) for details.
## βοΈ Training
@@ -193,21 +176,6 @@ pipeline/train/instruction_following.py \
--warmup_steps_ratio=0.01 \
```
-## π Checkpoints
-
-For details, you may refer to the [model card](docs/model_card.md).
-
-## πͺ© Web Demo
-
-We host our [Otter-9B Demo](https://otter.cliangyu.com/) via dual RTX-3090-24G GPUs. Launch your own demo by following the [demo instructions](docs/demo.md).
-
-## π Incoming Features
-
-We are working towards offering these features to our users. However, we have encountered some issues in the process. If you have the solutions to these issues, we would be grateful if you could submit a pull request with your code. Your contribution would be highly appreciated.
-
-- [x] `xformers` support: for saving GPU memory and training speedup. issue [#35](https://github.com/Luodian/PET-VLM/issues/35)
-- [ ] `load_in_8bit` support: for saving GPU memory and training speedup.
-
## π Citation
If you found this repository useful, please consider citing:
@@ -233,6 +201,8 @@ If you found this repository useful, please consider citing:
We thank [Jack Hessel](https://jmhessel.com/) for the advise and support, as well as the [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) team for their great contribution to the open source community.
+Huge accolades to [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) and [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) team for the work on this great architecture.
+
### π Related Projects
- [LLaVA: Visual Instruction Tuning](https://github.com/haotian-liu/LLaVA)
diff --git a/docs/demo.md b/docs/demo.md
deleted file mode 100644
index 4cccdf48..00000000
--- a/docs/demo.md
+++ /dev/null
@@ -1,250 +0,0 @@
-## πͺ© Serving Demo
-
-We will show you how to host a demo on your own computer using gradio.
-
-## Preparation
-
-### Download the checkpoints
-
-The 𦦠Otter checkpoint and the 𦩠Open Flamingo checkpoint can be auto-downloaded with the code below.
-
-## Start Demo
-
-### Launch a controller
-
-```Shell
-python -m pipeline.serve.controller --host 0.0.0.0 --port 10000
-```
-
-### Launch a model worker
-
-```Shell
-# Init our 𦦠Otter model on GPU
-CUDA_VISIBLE_DEVICES=0,1 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model_name otter --checkpoint_path luodian/otter-9b-hf --num_gpus 2 --limit_model_concurrency 200
-# Init our 𦦠Otter video model on CPU
-CUDA_VISIBLE_DEVICES=0,1 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40002 --worker http://localhost:40002 --model_name otter_video --checkpoint_path checkpoint/otter9B_DC_fullset_16frames/ --num_gpus 2 --limit_model_concurrency 200 --load_bit 16
-# Init original open flamingo model on GPU
-CUDA_VISIBLE_DEVICES=2,3 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model_name open_flamingo --checkpoint_path luodian/openflamingo-9b-hf --num_gpus 2 --limit_model_concurrency 200
-
-# Init original open flamingo model on CPU
-python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model_name open_flamingo_original --checkpoint_path luodian/openflamingo-9b-hf --num_gpus 0
-```
-
-Wait until the process finishes loading the model and you see "Uvicorn running on ...".
-
-### Launch a gradio web server
-
-```Shell
-# Image demo
-python -m pipeline.serve.gradio_web_server --controller http://localhost:10000 --port 7861
-# Video demo
-python -m pipeline.serve.gradio_web_server_video --controller http://localhost:10000 --port 7862
-```
-
-Now, you can open your browser and chat with the model!
-
-## Mini Demo
-
-Here is an example of multi-modal ICL (in-context learning) with 𦦠Otter. We provide two demo images with corresponding instructions and answers, then we ask the model to generate an answer given our instruct. You may change your instruction and see how the model responds.
-
-``` python
-import requests
-import torch
-import transformers
-from PIL import Image
-from otter.modeling_otter import OtterForConditionalGeneration
-
-model = OtterForConditionalGeneration.from_pretrained(
- "luodian/otter-9b-hf", device_map="auto"
-)
-tokenizer = model.text_tokenizer
-image_processor = transformers.CLIPImageProcessor()
-demo_image_one = Image.open(
- requests.get(
- "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
- ).raw
-)
-demo_image_two = Image.open(
- requests.get(
- "http://images.cocodataset.org/test-stuff2017/000000028137.jpg", stream=True
- ).raw
-)
-query_image = Image.open(
- requests.get(
- "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", stream=True
- ).raw
-)
-vision_x = (
- image_processor.preprocess(
- [demo_image_one, demo_image_two, query_image], return_tensors="pt"
- )["pixel_values"]
- .unsqueeze(1)
- .unsqueeze(0)
-)
-model.text_tokenizer.padding_side = "left"
-lang_x = model.text_tokenizer(
- [
- "
User: what does the image describe? GPT: two cats sleeping.<|endofchunk|>User: what does the image describe? GPT: a bathroom sink.<|endofchunk|>User: what does the image describe? GPT:"
- ],
- return_tensors="pt",
-)
-generated_text = model.generate(
- vision_x=vision_x.to(model.device),
- lang_x=lang_x["input_ids"].to(model.device),
- attention_mask=lang_x["attention_mask"].to(model.device),
- max_new_tokens=256,
- num_beams=3,
- no_repeat_ngram_size=3,
-)
-
-print("Generated text: ", model.text_tokenizer.decode(generated_text[0]))
-```
-
-An example for video.
-``` python
-import mimetypes
-import os
-from io import BytesIO
-from typing import Union
-import cv2
-import requests
-import torch
-import transformers
-from PIL import Image
-from torchvision.transforms import Compose, Resize, ToTensor
-from tqdm import tqdm
-import sys
-
-from otter.modeling_otter import OtterForConditionalGeneration
-
-# Disable warnings
-requests.packages.urllib3.disable_warnings()
-
-# ------------------- Utility Functions -------------------
-
-
-def get_content_type(file_path):
- content_type, _ = mimetypes.guess_type(file_path)
- return content_type
-
-
-# ------------------- Image and Video Handling Functions -------------------
-
-
-def extract_frames(video_path, num_frames=128):
- video = cv2.VideoCapture(video_path)
- total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
- frame_step = total_frames // num_frames
- frames = []
-
- for i in range(num_frames):
- video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
- ret, frame = video.read()
- if ret:
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- frame = Image.fromarray(frame).convert("RGB")
- frames.append(frame)
-
- video.release()
- return frames
-
-
-def get_image(url: str) -> Union[Image.Image, list]:
- if "://" not in url: # Local file
- content_type = get_content_type(url)
- else: # Remote URL
- content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
-
- if "image" in content_type:
- if "://" not in url: # Local file
- return Image.open(url)
- else: # Remote URL
- return Image.open(requests.get(url, stream=True, verify=False).raw)
- elif "video" in content_type:
- video_path = "temp_video.mp4"
- if "://" not in url: # Local file
- video_path = url
- else: # Remote URL
- with open(video_path, "wb") as f:
- f.write(requests.get(url, stream=True, verify=False).content)
- frames = extract_frames(video_path)
- if "://" in url: # Only remove the temporary video file if it was downloaded
- os.remove(video_path)
- return frames
- else:
- raise ValueError("Invalid content type. Expected image or video.")
-
-
-# ------------------- OTTER Prompt and Response Functions -------------------
-
-
-def get_formatted_prompt(prompt: str) -> str:
- return f"User: {prompt} GPT:"
-
-
-def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
- if isinstance(input_data, Image.Image):
- vision_x = (
- image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
- )
- elif isinstance(input_data, list): # list of video frames
- vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
- else:
- raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
-
- lang_x = model.text_tokenizer(
- [
- get_formatted_prompt(prompt),
- ],
- return_tensors="pt",
- )
-
- generated_text = model.generate(
- vision_x=vision_x.to(model.device),
- lang_x=lang_x["input_ids"].to(model.device),
- attention_mask=lang_x["attention_mask"].to(model.device),
- max_new_tokens=512,
- num_beams=3,
- no_repeat_ngram_size=3,
- )
- parsed_output = (
- model.text_tokenizer.decode(generated_text[0])
- .split("")[-1]
- .lstrip()
- .rstrip()
- .split("<|endofchunk|>")[0]
- .lstrip()
- .rstrip()
- .lstrip('"')
- .rstrip('"')
- )
- return parsed_output
-
-
-# ------------------- Main Function -------------------
-
-if __name__ == "__main__":
- model = OtterForConditionalGeneration.from_pretrained(
- "luodian/otter-9b-dc-hf",
- )
- model.text_tokenizer.padding_side = "left"
- tokenizer = model.text_tokenizer
- image_processor = transformers.CLIPImageProcessor()
- model.eval()
-
- while True:
- video_url = "demo.mp4" # Replace with the path to your video file
-
- frames_list = get_image(video_url)
-
- prompts_input = input("Enter prompts (comma-separated): ")
- prompts = [prompt.strip() for prompt in prompts_input.split(",")]
-
- for prompt in prompts:
- print(f"\nPrompt: {prompt}")
- response = get_response(frames_list, prompt, model, image_processor)
- print(f"Response: {response}")
-
- if prompts_input.lower() == "quit":
- break
-```
diff --git a/docs/huggingface_compatible.md b/docs/huggingface_compatible.md
new file mode 100644
index 00000000..22fbfc1c
--- /dev/null
+++ b/docs/huggingface_compatible.md
@@ -0,0 +1,26 @@
+## π€ Hugging Face Model
+
+You can use the 𦩠Flamingo model / 𦦠Otter model as a π€ Hugging Face model with only a few lines! One-click and then model configs/weights are downloaded automatically.
+
+``` python
+from flamingo import FlamingoModel
+flamingo_model = FlamingoModel.from_pretrained("luodian/openflamingo-9b-hf", device_map=auto)
+
+from otter import OtterModel
+otter_model = OtterModel.from_pretrained("luodian/otter-9b-hf", device_map=auto)
+```
+
+Previous [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) was developed with [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel) (DDP) on A100 cluster. Loading OpenFlamingo-9B to GPU requires **at least 33G GPU memory**, which is only available on A100 GPUs.
+
+In order to allow more researchers without access to A100 machines to try training OpenFlamingo, we wrap the OpenFlamingo model into a π€ hugging Face model ([Jinghao](https://king159.github.io/) has submitted a [PR](https://github.com/huggingface/transformers/pull/23063) to the /huggingface/transformers!). Via `device_map=auto`, the large model is sharded across multiple GPUs when loading and training. This can help researchers who do not have access to A100-80G GPUs to achieve similar throughput in training, testing on 4x RTX-3090-24G GPUs, and model deployment on 2x RTX-3090-24G GPUs. Specific details are below (may vary depending on the CPU and disk performance, as we conducted training on different machines).
+
+
+

+
+
+
+
+Our Otter model is also developed in this way and it's deployed on the π€ Hugging Face model hub. Our model can be hosted on two RTX-3090-24G GPUs and achieve a similar speed to one A100-80G machine.
\ No newline at end of file
diff --git a/docs/model_card.md b/docs/model_card.md
deleted file mode 100644
index 80055741..00000000
--- a/docs/model_card.md
+++ /dev/null
@@ -1,25 +0,0 @@
----
-language: en
-datasets:
-- multi-instruct
----
-
-# Otter-9B
-
-[Code](https://github.com/Luodian/PET-VLM) | [Demo](https://otter.cliangyu.com/)
-
-Otter is an instruction-following large multi-modal model built upon Open-Flamingo-9B. Through in-context instruction following, Otter is able to perform tasks more aligned with human preferences and more accurate.
-
-## Model Details
-
-Following the same setting as Flamingo, we freeze the pretrained vision encoder and language model, and only train Perceiver modules and cross-attention layers. We add one special token `` and resize the input and output embedding of the language model. This special token is used to separate the instruction and answer when calculating the causal loss ans used as a beginning token when generating the answer.
-
-Our training data will be released soon.
-
-## Uses
-
-Otter-9B is intended to be used **for academic research purposes only.** Commercial use is prohibited, in line with LLaMA's non-commercial license.
-
-### Bias, Risks, and Limitations
-
-This model may generate inaccurate or offensive outputs, reflecting biases in its training data and pretrained priors.
diff --git a/docs/server_host.md b/docs/server_host.md
new file mode 100644
index 00000000..acee0868
--- /dev/null
+++ b/docs/server_host.md
@@ -0,0 +1,44 @@
+## πͺ© Serving Demo
+
+We will show you how to host a demo on your own computer using gradio.
+
+## Preparation
+
+### Download the checkpoints
+
+The 𦦠Otter checkpoint and the 𦩠Open Flamingo checkpoint can be auto-downloaded with the code below.
+
+## Start Demo
+
+### Launch a controller
+
+```Shell
+python -m pipeline.serve.controller --host 0.0.0.0 --port 10000
+```
+
+### Launch a model worker
+
+```Shell
+# Init our 𦦠Otter model on GPU
+CUDA_VISIBLE_DEVICES=0,1 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model_name otter --checkpoint_path luodian/otter-9b-hf --num_gpus 2 --limit_model_concurrency 200
+# Init our 𦦠Otter video model on CPU
+CUDA_VISIBLE_DEVICES=0,1 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40002 --worker http://localhost:40002 --model_name otter_video --checkpoint_path checkpoint/otter9B_DC_fullset_16frames/ --num_gpus 2 --limit_model_concurrency 200 --load_bit 16
+# Init original open flamingo model on GPU
+CUDA_VISIBLE_DEVICES=2,3 python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model_name open_flamingo --checkpoint_path luodian/openflamingo-9b-hf --num_gpus 2 --limit_model_concurrency 200
+
+# Init original open flamingo model on CPU
+python -m pipeline.serve.model_worker --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model_name open_flamingo_original --checkpoint_path luodian/openflamingo-9b-hf --num_gpus 0
+```
+
+Wait until the process finishes loading the model and you see "Uvicorn running on ...".
+
+### Launch a gradio web server
+
+```Shell
+# Image demo
+python -m pipeline.serve.gradio_web_server --controller http://localhost:10000 --port 7861
+# Video demo
+python -m pipeline.serve.gradio_web_server_video --controller http://localhost:10000 --port 7862
+```
+
+Now, you can open your browser and chat with the model!
diff --git a/mimic-it/README.md b/mimic-it/README.md
index df52e17c..1aaca4d8 100644
--- a/mimic-it/README.md
+++ b/mimic-it/README.md
@@ -1,110 +1,118 @@
-
+
+- [π³ MIMIC-IT Overview](#-mimic-it-overview)
+- [Using MIMIC-IT Dataset](#using-mimic-it-dataset)
+ - [Convert It](#convert-it)
+ - [Download It](#download-it)
+ - [Eggs (Coming Soon)](#eggs-coming-soon)
+- [Syphus: the hero behind MIMIC-IT](#syphus-the-hero-behind-mimic-it)
+ - [Syphus on your own dataset](#syphus-on-your-own-dataset)
+- [Multilingual Instruction-Response Pairs](#multilingual-instruction-response-pairs)
+
## π³ MIMIC-IT Overview
-High-quality instructions are essential for the zero-shot performance of large language models on interactive natural language tasks. For interactive vision-language tasks involving intricate visual scenes, a large quantity of diverse and creative instructions should be imperative to tune vision-language models (VLMs). Nevertheless, the current availability of vision-language instructions in terms of quantity, diversity, and creativity remains limited, posing challenges to the generalization of interactive VLMs. Here we present **MIMIC-IT**, a dataset comprising 2.8M multi-modal instructions-response pairs based on images and videos. Each instruction-response pair is accompanied by multi-modal in-context information, forming conversational contexts aimed at empowering VLMs in perception, reasoning, and planning. The instruction-response collection process, dubbed as **Syphus**, is scaled using an automatic annotation pipeline that combines human expertise with GPT's capabilities.
+MIMIC-IT offers a diverse and extensive dataset of 2.8M multimodal instruction-response pairs, designed to enhance the performance of Vision-Language Models (VLMs) in real-life scenarios, enabling VLMs to excel in perception, reasoning, and planning while also catering to a multilingual audience.
+
+MIMIC-IT enables the application of egocentric visual assistant model that can serve that can answer your questions like **Hey, Do you think I left my keys on the table?**. Harness the power of MIMIC-IT to unlock the full potential of your AI-driven visual assistant and elevate your interactive vision-language tasks to new heights.
-MIMIC-IT covers a vast array of real-life scenarios that empower Vision-Language Models (VLMs) to not only comprehend general scenes, but also to reason about context and astutely differentiate between observations. MIMIC-IT also enables the application of egocentric visual assistant model that can serve that can answer your questions like **Hey, Do you think I left my keys on the table?**. In addition to English, MIMIC-IT is also multilingual, supporting Chinese, Korean, Japanese, German, French, Spanish, and Arabic, thereby allowing a larger global audience to altogether enjoy from the convenience brought about by advancements in artificial intelligence.
+MIMIC-IT provides multilingual instructions, supporting English, Chinese, Korean, Japanese, German, French, Spanish, and Arabic, thereby allowing a larger global audience to altogether enjoy from the convenience brought about by advancements in artificial intelligence.
-
+
-## Dataset Statistics
-
-| **Visual Sources (Scenes)** | **In-context** | **#Clips/Images** | **#Uni. Inst.** | **#Instances** |
-|---------------------------|----------------|------------------|-----------------|----------------|
-| COCO (General) | lang./vis. | - / 81K | 261K | 345K |
-| SD~(Surveillance) | lang./vis. | - / 9K | 10K | 15K |
-| SN~(Indoor Ego.) | lang./vis. | - / 0.5K | 4.8K | 6K |
-| DC~(General) | lang./vis. | 16K / 1M | 40K | 62K |
-| VIST~(Story) | lang./vis. | - / 16K | 32K | 33K |
-| TVC~(TV) | lang./vis. | 86K / 577K | 86K | 92K |
-| E4D~(General Ego.) | lang./vis. | 400K / 6.4M | 1.8M | 2.4M |
-| Total | lang./vis. | 502K / 8.1M | 2.2M | 2.8M |
-
-## Download Links
-
-The initial release includes LA and DC instruction-response pairs for the MIMIC-IT dataset. We plan to release additional datasets with a larger number of instruction pairs and more information after further examination.
+## Using MIMIC-IT Dataset
-We are contacting the image sources (those public datasets we used) to ask if we can directly release their image/video data in our Otter training format (base64 format within a large JSON file), we will put these data in following link if there would not be any legal/license issue.
+You can following the steps to obtain the MIMIC-IT dataset. Each task (e.g. `DC`, `LA`) in MIMIC-IT is composed of three parts, including:
+1. `xx.json` file: the images in base64 format.
+2. `xx_instructions.json` file: the instruction-response pairs (also includes image ids and related instructions ids for each instruction-response pair) for each task.
+3. `xx_train.json` file: the customized related instruction-response pairs for each instruction.
-This process may take some time. If you are interested in using this data, please leave an issue in this repository or email drluodian@gmail.com, and we will keep you updated.
+The following steps will introduce you how to gather them together.
-Additionally, we are in the process of providing the scripts used to convert public dataset images and extract specific frames from corresponding videos into the MIMIC-IT input format. This will help map the original dataset to our annotations UUIDs (e.g. from COCO's `000000215677.jpg` -> ours `LA_00_IMG_000000215677`).
+### Convert It
+You may need to refer to the [Convert-It](./convert-it/README.md) to convert the image sources from public dataset to the format of `xx.json`. If you find it is hard to download from the original image sources, you can refer to the [Eggs](#eggs) section to seek help there.
-| Scenes | Images/Videos | Size | Annotations | Size |
-| :--- | :---: | :---: | :---: | :---: |
-| **LA In-context** | Processing | 5.2GB |[link](https://entuedu-my.sharepoint.com/:u:/r/personal/libo0013_e_ntu_edu_sg/Documents/MIMIC-IT-Release/LA_instructions.json.zip?csf=1&web=1&e=SvaKh3) | 269.3MB |
-| **Dense Caption** | Processing | 86.4GB |[link](https://entuedu-my.sharepoint.com/:u:/r/personal/libo0013_e_ntu_edu_sg/Documents/MIMIC-IT-Release/DC_instructions.json.zip?csf=1&web=1&e=jM4gGB) | 269.1MB |
-| **TV Caption** | Processing | 17.0GB | Cleaning | 55.6MB |
-| **Visual Story Telling** | Processing | 16.2GB |Cleaning | 33.4MB |
-| **Scene Navigation (Indoor Event Planning)** | Processing | 2.3GB |Cleaning | 7.6MB |
-| **Spot The Difference (COCO's General Difference)** | Processing | 5.2GB |Cleaning | 80.5MB |
-| **Spot The Difference (Subtle Difference)** | Processing | 3.1GB |Cleaning | 5.0MB |
-| **EGO4D** | Processing | ~500GB |Cleaning | 3.2GB |
+### Download It
-The data is available on [NTU-Onedrive](https://entuedu-my.sharepoint.com/:f:/g/personal/libo0013_e_ntu_edu_sg/Eo9bgNV5cjtEswfA-HfjNNABiKsjDzSWAl5QYAlRZPiuZA?e=M9isDT). The JSON files are compressed into ZIP files to save space. After downloading, unzip the files and verify the MD5 checksums to ensure their integrity. The MD5 Checksums for the released annotations are:
+You can download the `instructions.json` and `train.json` files, from our provided [OneDrive folder](https://entuedu-my.sharepoint.com/:f:/g/personal/libo0013_e_ntu_edu_sg/Eo9bgNV5cjtEswfA-HfjNNABiKsjDzSWAl5QYAlRZPiuZA?e=M9isDT).
+| Tasks/Scenes | Zip File MD5 Checksum | Unzipped File Size |
+| :--- | :---: | :---: |
+| **LA-In-Context** | fdc2427451bcfd8a04bab7a1c2305259 | 338 MB |
+| **DC** | 0d373a5f511fd1d9f03ac81bb12e04fe | 171 MB |
+| **TVC** | 122b5cb0bd51c658625b7ea8c7d8c04c | 230 MB |
+| **VST** | 988569e39aaa24da0df547644514b0d4 | 32 MB |
+| **SN** | 1c4751c5b2c0bcaaeb94dbc5fb39e7a6 | 8 MB |
+| **SD (General Diff)** | TBD | 81 MB |
+| **SD (Subtle Diff)** | 5175198daebb997672a21307e8b18a96 | 5 MB |
+| **E4D (1st Part)** | 504b779dbc852c943adbe7862d6924d7 | 710 MB/3.2 GB |
+After downloading, unzip the files and place them in the `mimicit_data` folder. The folder structure should be as follows:
-1. LA_instructions.json (json not the zip file) -> f9bc559391d15727b35f3df306b12e31
-2. DC_instructions.json -> bb0d1f9f7d100c99869f79d13b3a3beb
-
-The MIMIC-IT dataset is stored in the following format:
+```bash
+mimicit_data/DC/DC_instructions.json
+mimicit_data/DC/DC_train.json
+```
+The `DC_instructions.json` includes a meta object with version, time, and author information. The data object contains instruction-response pairs, each with a unique identifier (e.g., "DC_INS_00001"). Each pair consists of an instruction, an answer, an array of associated image IDs, and an array of related instruction IDs (which can be arranged as in-context examples).
```json
-{
- "meta": {
- "version": "0.0.1",
- "time": "2023-06",
- "author": "ntu"
- },
- "data": {
- "DC_04_INS_00001": {
- "instruction": "Who is the main focus of the video?",
- "answer": "The main focus of the video is a police officer riding a horse down the street.",
- "image_ids": [
- "DC_04_IMG_v_N1c3C_Npr-E_0000",
- "DC_04_IMG_v_N1c3C_Npr-E_0001",
- ...
- "DC_04_IMG_v_N1c3C_Npr-E_0067"
- ],
- "rel_ins_ids": [
- "DC_04_INS_00002",
- "DC_04_INS_00003",
- ...
- "DC_04_INS_00008"
- ]
- },
- ...
- }
+{
+ "meta":{"virson":"0.0.1","time":"2023-06","author":"ntu"},
+ "data": {
+ "DC_INS_00001": {
+ "instruction":"Who is the main focus of the video?",
+ "answer":"The main focus of the video is a police officer riding a horse down the street.",
+ "image_ids":["DC_IMG_v_N1c3C_Npr-E_0000","DC_IMG_v_N1c3C_Npr-E_0001","DC_IMG_v_N1c3C_Npr-E_0002",..."],
+ "rel_ins_ids":["DC_INS_00002","DC_INS_00003","DC_INS_00004","DC_INS_00005","DC_INS_00006","DC_INS_00007","DC_INS_00008"]
+ },
+ }
+ ...
}
```
-This JSON file includes a meta object with version, time, and author information. The data object contains instruction-response pairs, each with a unique identifier (e.g., "DC_04_INS_00001"). Each pair consists of an instruction, an answer, an array of associated image IDs, and an array of related instruction IDs (which can be arranged as in-context examples).
+The `DC_train.json` contains instructions IDs and their associated related instruction IDs. Each instruction is associated with its related instructions. We provide it for more flexibly define each instruction's related instructions. It serves for different in-context learning objectives. In default, the related instructions ids are from `rel_ins_ids` at `DC_instructions.json`. But you can define your own related instructions ids for each instruction by just creating your own `DC_train.json` file.
-## Multilingual Instruction-Response Pairs
+```json
+{
+ "DC_INS_00001": ["DC_INS_00002", "DC_INS_00003", "DC_INS_00004", "DC_INS_00005", "DC_INS_00006", "DC_INS_00007", "DC_INS_00008"],
+ ...
+}
+```
-We will release multilingual instruction-response pairs in the following languages:
+### Eggs (Coming Soon)
-
-
-
+Things could be tricky since some image/video sources are not easy to get the access to download them. We also provide the converted `xx.json` files for you to download directly. You need to agree the same terms and conditions as the original dataset, as well as recognize and appreciate the contributions made by these data sources. Please refer to [Google form]() to apply for the access to download the converted `xx.json` files.
-## Syphus Overview
+
+## Syphus: the hero behind MIMIC-IT
-Syphus, an automated pipeline for generating high-quality instruction-response pairs in multiple languages. Building upon the framework proposed by LLaVA, we utilize ChatGPT to generate instruction-response pairs based on visual content. To ensure the quality of the generated instruction-response pairs, our pipeline incorporates system messages, visual annotations, and in-context examples as prompts for ChatGPT. System messages define the desired tone and style of the generated instruction-response pairs, while visual annotations provide essential image information such as bounding boxes and image descriptions. In-context examples assist ChatGPT in learning within the context. During cold-start stage, in-context examples are collected by prompting ChatGPT solely through system messages and visual annotations, employing a heuristic approach. This stage concludes only when a satisfactory in-context examples are identified. In step 4, once the instruction-response pairs are obtained, the pipeline expands them into Chinese (zh), Japanese (ja), Spanish (es), German (de), French (fr), Korean (ko), and Arabic (ar).
+Embracing Syphus, an automated pipeline that generates top-tier instruction-response pairs in various languages.
+
+Syphus builds on the LLaVA framework and uses ChatGPT to produce pairs based on visual content. It ensures quality by incorporating system messages for tone and style, visual annotations for essential image information, and in-context examples to assist ChatGPT in contextual learning. During the cold-start stage, in-context examples are collected using a heuristic approach with system messages and visual annotations. This stage concludes only when a satisfactory in-context examples are identified.
+
+Finally, the pipeline expands the instruction-response pairs into languages like Chinese, Japanese, Spanish, German, French, Korean, and Arabic.
-## Syphus on your own dataset
+### Syphus on your own dataset
-We provide source code of the framework in [syphus](Syphus) folder. You can use it to generate instruction-response pairs on your own dataset following the steps below.
+We provide source code of the framework in [syphus](./syphus/) folder. You can use it to generate instruction-response pairs on your own dataset following the steps below.
1. Configure openai key. Create the following environment variables in your system.
@@ -129,7 +137,13 @@ export OPENAI_API_ENGINE="chatgpt0301"
6. You are done! Run the following command to generate instruction-response pairs on your own dataset.
``` bash
-python main.py --name YourDataset.your_dataset --num_threads 4
+python main.py --name YourDataset.your_dataset --num_threads 64
```
-## π Citation
+## Multilingual Instruction-Response Pairs
+
+We will release multilingual instruction-response pairs in the following languages:
+
+
+
+
diff --git a/mimic-it/convert-it/README.md b/mimic-it/convert-it/README.md
new file mode 100644
index 00000000..ec9d2b66
--- /dev/null
+++ b/mimic-it/convert-it/README.md
@@ -0,0 +1,153 @@
+# Convert It
+
+This guide provides detailed instructions on how to convert various datasets from their public sources to our required format, including LLaVA-In-Context, Dense Captions, Visual Storytelling, TV Captions, Scene Navigation, Spot The Difference, and EGO4D. By following the specified steps, users can easily set up on these datasets. The output for each dataset will be saved in a corresponding JSON file named `.json` in the `output` folder.
+
+## LLaVA-In-Context
+
+Download the [coco2017](https://cocodataset.org/#download) images (coco2014 might also be work), put the images in a folder with the path ``. Download the [meta](XXX) for the training image ids, put the meta file at the path ``.
+
+The folder structure should be like this:
+
+```plain
+/
+ annotations/
+ val2017/
+ train2017/
+ 000000498792.jpg
+ XXXXXXXXXXXX.jpg
+ ...
+```
+
+Run the following command (the `--num_threads` is optional):
+
+```bash
+python main.py --name=2d.Llava --image_path= --image_root=/train2017 [--num_threads=]
+```
+
+The output will be saved in `output/LA.json`.
+
+
+## Dense Captions
+
+Download the [Dense Captions](https://cs.stanford.edu/people/ranjaykrishna/densevid/) videos in [ActivityNet](http://activity-net.org/challenges/2016/download.html#c3d), put the videos in a folder with the path ``
+
+The folder structure should be like this:
+
+```plain
+/
+ .mp4
+ ...
+```
+
+Run the following command:
+
+```bash
+python main.py --name=video.DenseCaptions --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/DC.json`.
+
+## Visual Storytelling
+
+Download the [Visual Storytelling Dataset](https://visionandlanguage.net/VIST/dataset.html) and extract the `train.story-in-sequence.json` to a path, let `` be the path of the json file, and run the following command:
+
+```bash
+python main.py --name=video.VisualStorytelling --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/VST.json`.
+
+## TV Captions
+
+Download the [TV Captions video frames (3FPS)](https://tvqa.cs.unc.edu/download_tvqa.html#tvqa-download-4) and extract the `zip` to a path, let `` be the path of the extracted folder.
+
+The folder structure should be like this:
+
+```plain
+/
+ bbt_frames/
+ ...
+ castle_frames/
+ ...
+ house_frames/
+ ...
+ met_frames/
+ ...
+```
+
+Run the following command:
+
+```bash
+python main.py --name=video.TVCaptions --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/TV.json`.
+
+## Scene Navigation
+
+Download the ScanNet v2 dataset from the [official website](http://www.scan-net.org/), let `` be the path of the dataset
+
+The folder structure should be like this:
+
+```plain
+/
+ scene0000_00/
+ color/
+ 000000.jpg
+ ...
+ ...
+ ...
+```
+
+
+Run the following command:
+
+```bash
+python main.py --name=3d.SceneNavigation --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/SN.json`.
+
+## Spot The Difference (Subtle Difference Version)
+
+Download the Spot The Difference Dataset from [Google Drive](https://drive.google.com/file/d/1OVb4_3Uec_xbyUk90aWC6LFpKsIOtR7v/view?usp=sharing), let `` be the path of the dataset.
+
+The folder structure should be like this:
+
+```plain
+/
+ .jpg
+ ...
+```
+
+Run the following command:
+
+```bash
+python main.py --name=change.SpotTheDifference --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/SD.json`.
+
+## Spot The Difference (COCO General Difference Version)
+
+TBD
+
+## EGO4D
+
+Download the [EGO4D dataset](https://ego4d-data.org/#download), let `` be the path of the dataset.
+
+The folder structure should be like this:
+
+```plain
+/
+ .mp4
+ ...
+```
+
+Run the following command:
+
+```bash
+python main.py --name=fpv.EGO4D --image_path= [--num_threads=]
+```
+
+The output will be saved in `output/E4D.json`.
diff --git a/mimic-it/convert-it/abstract_dataset.py b/mimic-it/convert-it/abstract_dataset.py
new file mode 100644
index 00000000..389790f4
--- /dev/null
+++ b/mimic-it/convert-it/abstract_dataset.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Tuple
+from PIL import Image
+import importlib
+
+AVAILABLE_DATASETS: List[str] = [
+ "change.SpotTheDifference",
+ "video.DenseCaptions",
+ "video.TVCaptions",
+ "video.VisualStoryTelling",
+ "3d.SceneNavigation",
+ "fpv.EGO4D",
+ "2d.Llava",
+]
+
+
+class AbstractDataset(ABC):
+ def __init__(self, name: str, short_name: str, image_path: str, num_threads: int):
+ """
+ Constructor.
+
+ Args:
+ name (str): The name of the dataset.
+ short_name (str): The short name of the dataset.
+ image_path (str): The path to the images of the dataset.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ self.name: str = name
+ self.short_name: str = short_name
+ self.images: Dict[str, Image.Image] = self._load_images(image_path, num_threads)
+
+ @abstractmethod
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Load the images from the videos or albums.
+
+ Args:
+ image_path (str): The path storing the videos or albums.
+ num_thread (int): The number of threads to use for loading the images.
+
+ Returns:
+ Dict[str, Image.Image]: A dictionary of images, where the keys are the IDs of the images.
+ """
+ pass
+
+ def __getitem__(self, key: str) -> Dict[str, Any]:
+ """
+ Return the item at the given index as a dictionary.
+
+ Args:
+ key (str): The key of the item to retrieve.
+
+ Returns:
+ Dict[str, Any]: The item at the given key.
+ """
+ return self.images[key]
+
+ def __iter__(self) -> "AbstractDataset":
+ """
+ Return the iterator object for the dataset.
+
+ Returns:
+ AbstractDataset: The iterator object.
+ """
+ self.keys = iter(self.images.keys())
+ return self
+
+ def __next__(self) -> Tuple[str, Image.Image]:
+ """
+ Return the next item in the iteration.
+
+ Returns:
+ Tuple[str, Image.Image]: The next item as a tuple of key and image.
+
+ Raises:
+ StopIteration: If there are no more items in the iteration.
+ """
+ try:
+ key = next(self.keys)
+ image = self.images[key]
+ return key, image
+ except StopIteration:
+ raise StopIteration
+
+ def __len__(self) -> int:
+ """
+ Return the length of the dataset.
+
+ Returns:
+ int: The length of the dataset.
+ """
+ return len(self.query_inputs)
+
+ def __str__(self) -> str:
+ """
+ Return a string representation of the dataset.
+
+ Returns:
+ str: The string representation of the dataset.
+ """
+ return f"{self.name} dataset"
+
+
+def get_dataset_by_path(path: str, dataset_args: dict[str, str]) -> AbstractDataset:
+ """
+ Get an instance of a dataset class based on the given path.
+
+ Args:
+ path (str): The path to the dataset class in the format ".".
+ dataset_args (Dict[str, str]): Additional arguments to pass to the dataset class constructor.
+
+ Returns:
+ AbstractDataset: An instance of the dataset class.
+
+ Raises:
+ AssertionError: If the given path is not an available dataset.
+ """
+ assert path in AVAILABLE_DATASETS, f"{path} is not an available dataset."
+ module_path, dataset_name = path.split(".")
+ module_path = "datasets." + module_path
+
+ # Import the module and load the class
+
+ imported_module = importlib.import_module(module_path)
+ dataset_class = getattr(imported_module, dataset_name)
+
+ # Instantiate the class and return the instance
+ return dataset_class(**dataset_args)
+
+
+def get_available_datasets() -> List[str]:
+ """
+ Get a list of available dataset paths.
+
+ Returns:
+ List[str]: A list of available dataset paths.
+ """
+ return AVAILABLE_DATASETS
diff --git a/mimic-it/convert-it/datasets/2d.py b/mimic-it/convert-it/datasets/2d.py
new file mode 100644
index 00000000..dd71612e
--- /dev/null
+++ b/mimic-it/convert-it/datasets/2d.py
@@ -0,0 +1,56 @@
+import os
+import json
+
+from abstract_dataset import AbstractDataset
+from PIL import Image
+from tqdm import tqdm
+from glob import glob
+from image_utils import create_folder
+
+
+class Llava(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "Llava",
+ short_name="LA",
+ *,
+ image_root: str,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a Llava in-context dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "Llava".
+ short_name (str): The short name of the dataset. Defaults to "LA".
+ image_path (str): The path containing the dataset images, downloaded from https://drive.google.com/file/d/1OVb4_3Uec_xbyUk90aWC6LFpKsIOtR7v/view?usp=sharing.
+ image_root (str): The path to the coco image train split
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ self.image_root = image_root
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Loads the images from the dataset.
+
+ Args:
+ image_path (str): The path to the dictionary containing the dataset images.
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ dict[str, Image.Image]: A dictionary where the keys are image identifiers and the values are PIL.Image.Image objects.
+ """
+
+ def read_image(file_name) -> Image.Image:
+ return Image.open(file_name)
+
+ images = {}
+ with open(image_path) as f:
+ image_ids = json.load(f).keys()
+
+ for cur_image_id in image_ids:
+ images[cur_image_id] = read_image(f"{self.image_root}/{cur_image_id}.jpg")
+
+ return images
diff --git a/mimic-it/convert-it/datasets/3d.py b/mimic-it/convert-it/datasets/3d.py
new file mode 100644
index 00000000..85b7206d
--- /dev/null
+++ b/mimic-it/convert-it/datasets/3d.py
@@ -0,0 +1,39 @@
+from PIL import Image
+
+from abstract_dataset import AbstractDataset
+
+
+class SceneNavigation(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "SceneNavigation",
+ short_name="SN",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a SceneNavigation dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "SceneNavigation".
+ short_name (str): The short name of the dataset. Defaults to "SN".
+ image_path (str): The directory path of the folder named "scannet_frames_25k" obtained by downloading a compressed file from http://www.scan-net.org/ and extracting it.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Loads the images from the dataset.
+
+ Args:
+ image_path (str): The path to the directory containing the images downloaded from http://www.scan-net.org/.
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ dict[str, Image.Image]: A dictionary where the keys are image identifiers and the values are PIL.Image.Image objects.
+ """
+ from datasets.utils.scene_navigation_utils import process_data
+
+ return process_data(image_path, num_thread)
diff --git a/mimic-it/convert-it/datasets/__init__.py b/mimic-it/convert-it/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mimic-it/convert-it/datasets/change.py b/mimic-it/convert-it/datasets/change.py
new file mode 100644
index 00000000..d942fda0
--- /dev/null
+++ b/mimic-it/convert-it/datasets/change.py
@@ -0,0 +1,95 @@
+import os
+import json
+
+from abstract_dataset import AbstractDataset
+from PIL import Image
+from tqdm import tqdm
+from glob import glob
+from image_utils import create_folder
+
+
+class SpotTheDifference(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "SpotTheDifference",
+ short_name="SD",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a SpotTheDifference dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "SpotTheDifference".
+ short_name (str): The short name of the dataset. Defaults to "SD".
+ image_path (str): The path containing the dataset images, downloaded from https://drive.google.com/file/d/1OVb4_3Uec_xbyUk90aWC6LFpKsIOtR7v/view?usp=sharing.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Loads the images from the dataset.
+
+ Args:
+ image_path (str): The path to the dictionary containing the dataset images.
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ dict[str, Image.Image]: A dictionary where the keys are image identifiers and the values are PIL.Image.Image objects.
+ """
+ file_names = glob(os.path.join(image_path, "*"))
+ names = set()
+ for file_name in file_names:
+ image_name = file_name.split("/")[-1].split(".")[0]
+ id = image_name.split("_")[0]
+ names.add(id)
+ ids = list(sorted(list(names)))[:5]
+
+ jpgs_path = glob(os.path.join(image_path, "*.jpg"))
+ jpegs_path = glob(os.path.join(image_path, "*.jpeg"))
+ pngs_path = glob(os.path.join(image_path, "*.png"))
+ jpgs = set()
+ pngs = set()
+ for path in jpgs_path:
+ jpgs.add(path.split("/")[-1].split(".")[0])
+ for path in pngs_path:
+ pngs.add(path.split("/")[-1].split(".")[0])
+
+ def get_path(file_name):
+ if file_name in jpgs:
+ return os.path.join(image_path, file_name + ".jpg")
+ elif file_name in pngs:
+ # print("file_name", file_name, os.path.join(image_path, file_name + ".png"))
+ return os.path.join(image_path, file_name + ".png")
+ elif file_name in jpegs_path:
+ return os.path.join(image_path, file_name + ".jpeg")
+ else:
+ # print("===================================", file_name)
+ raise Exception("File not found")
+
+ def read_image(file_name) -> Image.Image:
+ return Image.open(file_name)
+
+ file_not_found = []
+
+ images = {}
+
+ for id in tqdm(ids, desc="Reading images"):
+ try:
+ file_1 = get_path(id)
+ file_2 = get_path(id + "_2")
+ # print(file_1, file_2)
+ images[id.zfill(5) + "_1"] = read_image(file_1)
+ images[id.zfill(5) + "_2"] = read_image(file_2)
+ except Exception as e:
+ file_not_found.append(id)
+ print(f"File not found: {id}")
+ # print(f"Error: {e}")
+
+ create_folder("log")
+ with open("log/file_not_found.log", "w") as f:
+ json.dump(file_not_found, f, indent=4)
+
+ return images
diff --git a/mimic-it/convert-it/datasets/fpv.py b/mimic-it/convert-it/datasets/fpv.py
new file mode 100644
index 00000000..a0d3a24c
--- /dev/null
+++ b/mimic-it/convert-it/datasets/fpv.py
@@ -0,0 +1,61 @@
+import os
+
+from PIL import Image
+from glob import glob
+
+from abstract_dataset import AbstractDataset
+from image_utils import frame_video
+
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+
+
+class EGO4D(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "EGO4D",
+ short_name="E4D",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a EGO4D dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "EGO4D".
+ short_name (str): The short name of the dataset. Defaults to "E4D".
+ image_path (str): The directory path of the folder downloaded from https://ego4d-data.org/#download.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Loads the images from the dataset.
+
+ Args:
+ image_path (str): The path to the directory containing the images downloaded from https://ego4d-data.org/#download.
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ dict[str, Image.Image]: A dictionary where the keys are image identifiers and the values are PIL.Image.Image objects.
+ """
+ video_paths = glob(os.path.join(image_path, "*"))
+
+ def get_image(video_path):
+ images = frame_video(video_path)
+ images_dict = {}
+ video_name = os.path.basename(video_path).split(".")[0]
+ for index, image in enumerate(images):
+ images_dict[f"{video_name}_{index:08d}"] = image
+ return images_dict
+
+ final_images_dict = {}
+
+ with ThreadPoolExecutor(max_workers=num_thread) as executor:
+ futures = [executor.submit(get_image, video_path) for video_path in video_paths]
+ for images_dict in tqdm(futures, desc="Processing videos into images", unit="video"):
+ final_images_dict.update(images_dict.result())
+
+ return final_images_dict
diff --git a/mimic-it/convert-it/datasets/utils/scene_navigation_utils.py b/mimic-it/convert-it/datasets/utils/scene_navigation_utils.py
new file mode 100644
index 00000000..e40a4fc1
--- /dev/null
+++ b/mimic-it/convert-it/datasets/utils/scene_navigation_utils.py
@@ -0,0 +1,58 @@
+import os
+
+from PIL import Image
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+
+from PIL import Image
+from image_utils import resize_image
+
+
+def process(cur_dir, img_root):
+ """
+ Process images in a directory.
+
+ Parameters:
+ - cur_dir (str): The current directory name.
+ - img_root (str): The root directory of the images.
+
+ Returns:
+ - dict: A dictionary containing processed images. The keys are unique identifiers
+ for each image, and the values are the processed images.
+
+ """
+ root = os.path.join(img_root, cur_dir, "color")
+ file_list = os.listdir(root)
+ images = {}
+ for cur_file in file_list:
+ file_name = os.path.join(img_root, cur_dir, "color", cur_file)
+ img = Image.open(file_name) # path to file
+ image_id = f"{cur_dir}_color_{cur_file[:-4]}"
+ images[image_id] = resize_image(img)
+ return images
+
+
+def process_data(img_root: str, num_threads: int):
+ """
+ Process images in parallel using multiple threads.
+
+ Parameters:
+ - img_root (str): The root directory of the images.
+ - num_threads (int): The number of threads to use for parallel processing.
+
+ Returns:
+ - dict: A dictionary containing processed images. The keys are unique identifiers
+ for each image, and the values are the processed images.
+
+ """
+ keys = os.listdir(img_root)
+ all_images = {}
+ process_bar = tqdm(total=len(keys), unit="image", desc="Loading images")
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = [executor.submit(process, cur_dir, img_root) for cur_dir in keys]
+ for future in futures:
+ images = future.result()
+ all_images.update(images)
+ process_bar.update(1)
+ process_bar.close()
+ return all_images
diff --git a/mimic-it/convert-it/datasets/utils/visual_story_telling_utils.py b/mimic-it/convert-it/datasets/utils/visual_story_telling_utils.py
new file mode 100644
index 00000000..2a3a6b96
--- /dev/null
+++ b/mimic-it/convert-it/datasets/utils/visual_story_telling_utils.py
@@ -0,0 +1,35 @@
+import requests
+
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from PIL import Image
+from io import BytesIO
+from image_utils import resize_image
+
+
+def get_url(image: dict[str]):
+ if "url_o" in image:
+ return image["url_o"]
+ else:
+ return image["url_m"]
+
+
+def download_single_image(image: dict[str]):
+ url = get_url(image)
+ id = image["id"]
+ pic = requests.get(url)
+ return (
+ id,
+ resize_image(Image.open(BytesIO(pic.content))),
+ )
+
+
+def download(images: list[dict[str]], num_threads: int):
+ output = {}
+ process_bar = tqdm(total=len(images), unit="image", desc="Downloading images")
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ for id, image in executor.map(download_single_image, images):
+ output[id] = image
+ process_bar.update(1)
+ process_bar.close()
+ return output
diff --git a/mimic-it/convert-it/datasets/video.py b/mimic-it/convert-it/datasets/video.py
new file mode 100644
index 00000000..82340266
--- /dev/null
+++ b/mimic-it/convert-it/datasets/video.py
@@ -0,0 +1,175 @@
+import json
+import os
+
+from abstract_dataset import AbstractDataset
+from PIL import Image
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from glob import glob
+from image_utils import frame_video, get_image_name, resize_image
+from natsort import natsorted
+
+
+class DenseCaptions(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "DenseCaptions",
+ short_name="DC",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a DenseCaptions dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "DenseCaptions".
+ short_name (str): The short name of the dataset. Defaults to "DC".
+ image_path (str): The path to the directory containing the dataset images.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Loads the images from the dataset.
+
+ Args:
+ image_path (str): The path to the directory containing the all the videos downloaded from ActivityNet (in mp4 format).
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ dict[str, Image.Image]: A dictionary where the keys are image identifiers and the values are PIL.Image.Image objects.
+ """
+ videos = glob(f"{image_path}/*.mp4")
+ if len(videos) <= 100:
+ raise ValueError("Not enough videos in the dataset, please check the path.")
+ with ThreadPoolExecutor(max_workers=num_thread) as executor:
+ results = {}
+ process_bar = tqdm(total=len(videos), desc="Processing videos into images", unit="video")
+ for video, framed_results in executor.map(lambda x: (get_image_name(x), frame_video(x)), videos):
+ for index, result in enumerate(framed_results):
+ # print("video", video)
+ name = video + "_" + str(index).zfill(4)
+ results[name] = result
+ process_bar.update(1)
+ process_bar.close()
+ return results
+
+
+class VisualStoryTelling(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "VisualStroryTelling",
+ short_name="VST",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a VisualStoryTelling dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "VisualStroryTelling".
+ short_name (str): The short name of the dataset. Defaults to "VST".
+ image_path (str): The json file (train.story-in-sequence.json) containing the dataset images, downloaded from https://visionandlanguage.net/VIST.
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Load the images from the VisualStoryTelling dataset.
+
+ Args:
+ image_path (str): The path to the JSON file containing the dataset images.
+ num_thread (int): The number of threads to use for loading the images.
+
+ Returns:
+ Dict[str, Image.Image]: A dictionary of images, where the keys are the IDs of the images.
+ """
+ from datasets.utils.visual_story_telling_utils import download
+
+ with open(image_path, "r") as f:
+ data = json.load(f)
+ return download(data["images"], num_thread)
+
+
+class TVCaptions(AbstractDataset):
+ def __init__(
+ self,
+ name: str = "TVCaptions",
+ short_name="TVC",
+ *,
+ image_path: str,
+ num_threads: int,
+ ):
+ """
+ Initializes a TVCaptions dataset.
+
+ Args:
+ name (str): The name of the dataset. Defaults to "TVCaptions".
+ short_name (str): The short name of the dataset. Defaults to "TVC".
+ image_path (str): The path to the directory containing the dataset images, downloaded from https://tvqa.cs.unc.edu/download_tvqa.html#tvqa-download-4
+ num_threads (int): The number of threads to use for processing the images.
+ """
+ super().__init__(name, short_name, image_path, num_threads)
+
+ def _load_images(self, image_path: str, num_thread: int) -> dict[str, Image.Image]:
+ """
+ Load the images from the TVCaptions dataset.
+
+ Args:
+ image_path (str): The path to the directory containing the dataset images, downloaded from https://tvqa.cs.unc.edu/download_tvqa.html#tvqa-download-4.
+ num_thread (int): The number of threads to use for loading the images.
+
+ Returns:
+ Dict[str, Image.Image]: A dictionary of images, where the keys are the IDs of the images.
+
+ """
+
+ def get_frames(directory, frames=16):
+ # Generate a list of image filenames
+ image_filenames = natsorted(glob(os.path.join(directory, "*")))
+
+ # Calculate the stride length to achieve an average sample
+ stride = max(1, len(image_filenames) // frames)
+
+ # Initialize the starting index for sampling
+ start_index = stride // 2
+
+ # Sample 16 images evenly
+ sampled_images = [image_filenames[i] for i in range(start_index, len(image_filenames), stride)]
+
+ return sampled_images
+
+ def get_images(frames, frame_name, clip_name):
+ images = {}
+ for frame in frames:
+ image_name = os.path.basename(frame).split(".")[0]
+ if clip_name.startswith(frame_name):
+ image_id = f"{clip_name}_{image_name}"
+ else:
+ image_id = f"{frame_name}_{clip_name}_{image_name}"
+ images[image_id] = resize_image(Image.open(frame))
+ return images
+
+ frames = glob(os.path.join(image_path, "*"))
+ all_images = {}
+ for frame in frames:
+ frame_name = os.path.basename(frame).split("_")[0]
+ clips = glob(os.path.join(frame, "*"))
+ progress_bar = tqdm(total=len(clips), desc=f"Processing clips in {frame_name}", unit="clip")
+ with ThreadPoolExecutor(max_workers=num_thread) as executor:
+
+ def get_images_dict(clip):
+ clip_name = os.path.basename(clip)
+ frames = get_frames(clip)
+ return get_images(frames, frame_name, clip_name)
+
+ for images in executor.map(get_images_dict, clips):
+ all_images.update(images)
+ progress_bar.update(1)
+ progress_bar.close()
+
+ return all_images
diff --git a/mimic-it/convert-it/image_utils.py b/mimic-it/convert-it/image_utils.py
new file mode 100644
index 00000000..6f1b5fa9
--- /dev/null
+++ b/mimic-it/convert-it/image_utils.py
@@ -0,0 +1,156 @@
+import base64
+import os
+import cv2
+
+from io import BytesIO
+from PIL import Image
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+
+
+def get_image_id(image_name: str, dataset_name: str) -> str:
+ """
+ Extracts the image identifier from a given image name.
+
+ Args:
+ image_name (str): The name of the image.
+ dataset_name (str): The name of the dataset.
+
+ Returns:
+ str: The image identifier.
+ """
+ return f"{dataset_name}_IMG_{get_image_name(image_name)}"
+
+
+def resize_image(image, target_size=(224, 224)):
+ """
+ Resizes the given image to the target size using the Lanczos algorithm.
+
+ Args:
+ image (PIL.Image.Image): The input image to be resized.
+ target_size (tuple[int, int]): The target size to which the image should be resized.
+ Defaults to (224, 224).
+
+ Returns:
+ PIL.Image.Image: The resized image.
+ """
+ if image.size != target_size:
+ return image.resize(target_size, Image.LANCZOS)
+ return image
+
+
+def process_image(img: Image.Image):
+ """
+ Processes the input image by resizing it, converting it to RGB mode, and encoding it as base64.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+
+ Returns:
+ str: The base64 encoded string representation of the processed image.
+ """
+ resized_img = resize_image(img)
+ if resized_img.mode != "RGB":
+ resized_img = resized_img.convert("RGB")
+ buffer = BytesIO()
+ resized_img.save(buffer, format="PNG")
+ img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+ return img_base64
+
+
+def get_json_data(images: dict[str, Image.Image], dataset_name: str, num_thread: int) -> dict[str, str]:
+ """
+ Converts a dictionary of images to a JSON-compatible dictionary with base64 encoded strings.
+
+ Args:
+ images (Dict[str, Image.Image]): A dictionary of images, where the keys are image identifiers and the values are PIL.Image.Image objects.
+ dataset_name (str): The name of the dataset.
+ num_threads (int): The number of threads to use for processing the images.
+
+ Returns:
+ Dict[str, str]: A dictionary where the keys are formatted as "{dataset_name}_IMG_{key}" and the values are base64 encoded string representations of the processed images.
+ """
+ with ThreadPoolExecutor(max_workers=num_thread) as executor:
+ process_bar = tqdm(total=len(images), desc="Processing images", unit="image")
+ results = {}
+
+ def process_image_wrapper(args):
+ key, img = args
+ new_key = get_image_id(key, dataset_name)
+ result = process_image(img)
+
+ process_bar.update(1)
+ return new_key, result
+
+ processed_images = executor.map(process_image_wrapper, images.items())
+
+ for key, result in processed_images:
+ results[key] = result
+
+ process_bar.close()
+
+ return results
+
+
+def frame_video(video_file: str, fps=1):
+ """
+ Extracts frames from a video file at a specified frame rate and returns them as base64 encoded strings.
+
+ Args:
+ video_file (str): The path to the video file.
+ fps (int): The frame rate at which frames should be extracted. Defaults to 1 frame per second.
+
+ Returns:
+ List[Image]: A list of PIL.Image.Image objects representing the extracted frames.
+ """
+ if not os.path.exists(video_file):
+ raise FileNotFoundError(f"Video file {video_file} does not exist.")
+
+ cap = cv2.VideoCapture(video_file)
+ video_fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+ frame_count = 0
+ saved_frame_count = 0
+ frames = []
+
+ while cap.isOpened():
+ ret, frame = cap.read()
+
+ if not ret:
+ break
+
+ if frame_count % (video_fps // fps) == 0:
+ # convert frame to base64
+ _, buffer = cv2.imencode(".jpg", frame)
+ frames.append(resize_image(Image.open(BytesIO(buffer))))
+ saved_frame_count += 1
+
+ frame_count += 1
+
+ cap.release()
+
+ return frames
+
+
+def get_image_name(image_path: str) -> str:
+ """
+ Extracts the image name from a given image path.
+
+ Args:
+ image_path (str): The path to the image.
+
+ Returns:
+ str: The image name.
+ """
+ return image_path.split("/")[-1].split(".")[0]
+
+
+def create_folder(folder_name: str):
+ """
+ Creates a folder if it does not already exist.
+
+ Args:
+ folder_name (str): The name of the folder to create.
+ """
+ if not os.path.exists(folder_name):
+ os.makedirs(folder_name)
diff --git a/mimic-it/convert-it/main.py b/mimic-it/convert-it/main.py
new file mode 100644
index 00000000..ff105833
--- /dev/null
+++ b/mimic-it/convert-it/main.py
@@ -0,0 +1,29 @@
+import argparse
+import json
+
+from abstract_dataset import get_dataset_by_path
+from image_utils import get_json_data, create_folder
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--name", type=str, required=True, help="Path to the dataset class.")
+ parser.add_argument("--num_threads", type=int, default=8, help="Number of threads.")
+ parser.add_argument("--image_path", help="Path to the prompt file.")
+ parser.add_argument("--image_root", default=None, help="Path to the image root.")
+
+ args = parser.parse_args()
+ dataset_args = {}
+ if args.image_path is not None:
+ dataset_args["image_path"] = args.image_path
+ if args.num_threads is not None:
+ dataset_args["num_threads"] = args.num_threads
+ if args.image_root is not None:
+ dataset_args["image_root"] = args.image_root
+ dataset = get_dataset_by_path(args.name, dataset_args)
+ dataset_short_name = dataset.short_name
+ dataset = dict(dataset)
+ json_data = get_json_data(dataset, dataset_short_name, args.num_threads)
+ create_folder("output")
+ with open(f"output/{dataset_short_name}.json", "w") as f:
+ json.dump(json_data, f)
diff --git a/mimic-it/docs/mimicit_logo.png b/mimic-it/docs/mimicit_logo.png
deleted file mode 100644
index c3033e50..00000000
Binary files a/mimic-it/docs/mimicit_logo.png and /dev/null differ
diff --git a/mimic-it/syphus/abstract_dataset.py b/mimic-it/syphus/abstract_dataset.py
index 8f95c3cd..fe3c6db2 100644
--- a/mimic-it/syphus/abstract_dataset.py
+++ b/mimic-it/syphus/abstract_dataset.py
@@ -92,10 +92,6 @@ def get_dataset_by_path(path: str, dataset_args: dict[str, str]) -> AbstractData
module_path, dataset_name = path.split(".")
module_path = "datasets." + module_path
- # TODO:remove later, Print module and class names for debugging
- print(f"Loading module: {module_path}")
- print(f"Loading class: {dataset_name}")
-
# Import the module and load the class
imported_module = importlib.import_module(module_path)
dataset_class = getattr(imported_module, dataset_name)
diff --git a/pipeline/eval/__init__.py b/pipeline/demo/__init__.py
similarity index 100%
rename from pipeline/eval/__init__.py
rename to pipeline/demo/__init__.py
diff --git a/pipeline/demo/otter_image.ipynb b/pipeline/demo/otter_image.ipynb
new file mode 100644
index 00000000..5aafb0d5
--- /dev/null
+++ b/pipeline/demo/otter_image.ipynb
@@ -0,0 +1,140 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Otter Image Demo (In-context Learning)\n",
+ "\n",
+ "Here is an example of multi-modal ICL (in-context learning) with 𦦠Otter. We provide two demo images with corresponding instructions and answers, then we ask the model to generate an answer given our instruct. You may change your instruction and see how the model responds.\n",
+ "\n",
+ "You can also try our [online demo](https://otter.cliangyu.com/) to see more in-context learning demonstrations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/envs/otter/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Successfully imported xformers version 0.0.20\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading (β¦)lve/main/config.json: 4.45kB [00:00, 8.91MB/s]\n",
+ "Downloading (β¦)model.bin.index.json: 93.2kB [00:00, 106MB/s]\n",
+ "Downloading shards: 0%| | 0/4 [01:32, ?it/s]\n"
+ ]
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[1], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m\"\u001b[39m\u001b[39m../..\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39motter\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodeling_otter\u001b[39;00m \u001b[39mimport\u001b[39;00m OtterForConditionalGeneration\n\u001b[0;32m---> 10\u001b[0m model \u001b[39m=\u001b[39m OtterForConditionalGeneration\u001b[39m.\u001b[39;49mfrom_pretrained(\u001b[39m\"\u001b[39;49m\u001b[39mluodian/OTTER-9B-LA-InContext\u001b[39;49m\u001b[39m\"\u001b[39;49m, device_map\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mauto\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m 11\u001b[0m tokenizer \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mtext_tokenizer\n\u001b[1;32m 12\u001b[0m image_processor \u001b[39m=\u001b[39m transformers\u001b[39m.\u001b[39mCLIPImageProcessor()\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/transformers/modeling_utils.py:2531\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 2528\u001b[0m \u001b[39m# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.\u001b[39;00m\n\u001b[1;32m 2529\u001b[0m \u001b[39mif\u001b[39;00m is_sharded:\n\u001b[1;32m 2530\u001b[0m \u001b[39m# rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.\u001b[39;00m\n\u001b[0;32m-> 2531\u001b[0m resolved_archive_file, sharded_metadata \u001b[39m=\u001b[39m get_checkpoint_shard_files(\n\u001b[1;32m 2532\u001b[0m pretrained_model_name_or_path,\n\u001b[1;32m 2533\u001b[0m resolved_archive_file,\n\u001b[1;32m 2534\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2535\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 2536\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 2537\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 2538\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 2539\u001b[0m use_auth_token\u001b[39m=\u001b[39;49muse_auth_token,\n\u001b[1;32m 2540\u001b[0m user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m 2541\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2542\u001b[0m subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m 2543\u001b[0m _commit_hash\u001b[39m=\u001b[39;49mcommit_hash,\n\u001b[1;32m 2544\u001b[0m )\n\u001b[1;32m 2546\u001b[0m \u001b[39m# load pt weights early so that we know which dtype to init the model under\u001b[39;00m\n\u001b[1;32m 2547\u001b[0m \u001b[39mif\u001b[39;00m from_pt:\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/transformers/utils/hub.py:934\u001b[0m, in \u001b[0;36mget_checkpoint_shard_files\u001b[0;34m(pretrained_model_name_or_path, index_filename, cache_dir, force_download, proxies, resume_download, local_files_only, use_auth_token, user_agent, revision, subfolder, _commit_hash)\u001b[0m\n\u001b[1;32m 931\u001b[0m \u001b[39mfor\u001b[39;00m shard_filename \u001b[39min\u001b[39;00m tqdm(shard_filenames, desc\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mDownloading shards\u001b[39m\u001b[39m\"\u001b[39m, disable\u001b[39m=\u001b[39m\u001b[39mnot\u001b[39;00m show_progress_bar):\n\u001b[1;32m 932\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 933\u001b[0m \u001b[39m# Load from URL\u001b[39;00m\n\u001b[0;32m--> 934\u001b[0m cached_filename \u001b[39m=\u001b[39m cached_file(\n\u001b[1;32m 935\u001b[0m pretrained_model_name_or_path,\n\u001b[1;32m 936\u001b[0m shard_filename,\n\u001b[1;32m 937\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 938\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 939\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 940\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 941\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 942\u001b[0m use_auth_token\u001b[39m=\u001b[39;49muse_auth_token,\n\u001b[1;32m 943\u001b[0m user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m 944\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 945\u001b[0m subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m 946\u001b[0m _commit_hash\u001b[39m=\u001b[39;49m_commit_hash,\n\u001b[1;32m 947\u001b[0m )\n\u001b[1;32m 948\u001b[0m \u001b[39m# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so\u001b[39;00m\n\u001b[1;32m 949\u001b[0m \u001b[39m# we don't have to catch them here.\u001b[39;00m\n\u001b[1;32m 950\u001b[0m \u001b[39mexcept\u001b[39;00m EntryNotFoundError:\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/transformers/utils/hub.py:417\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, use_auth_token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash)\u001b[0m\n\u001b[1;32m 414\u001b[0m user_agent \u001b[39m=\u001b[39m http_user_agent(user_agent)\n\u001b[1;32m 415\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 416\u001b[0m \u001b[39m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 417\u001b[0m resolved_file \u001b[39m=\u001b[39m hf_hub_download(\n\u001b[1;32m 418\u001b[0m path_or_repo_id,\n\u001b[1;32m 419\u001b[0m filename,\n\u001b[1;32m 420\u001b[0m subfolder\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m \u001b[39mif\u001b[39;49;00m \u001b[39mlen\u001b[39;49m(subfolder) \u001b[39m==\u001b[39;49m \u001b[39m0\u001b[39;49m \u001b[39melse\u001b[39;49;00m subfolder,\n\u001b[1;32m 421\u001b[0m repo_type\u001b[39m=\u001b[39;49mrepo_type,\n\u001b[1;32m 422\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 423\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 424\u001b[0m user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m 425\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 426\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 427\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 428\u001b[0m use_auth_token\u001b[39m=\u001b[39;49muse_auth_token,\n\u001b[1;32m 429\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 430\u001b[0m )\n\u001b[1;32m 432\u001b[0m \u001b[39mexcept\u001b[39;00m RepositoryNotFoundError:\n\u001b[1;32m 433\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 434\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mpath_or_repo_id\u001b[39m}\u001b[39;00m\u001b[39m is not a local folder and is not a valid model identifier \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 435\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlisted on \u001b[39m\u001b[39m'\u001b[39m\u001b[39mhttps://huggingface.co/models\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mIf this is a private repository, make sure to \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 436\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mpass a token having permission to this repo with `use_auth_token` or log in with \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m`huggingface-cli login` and pass `use_auth_token=True`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 438\u001b[0m )\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[39mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 116\u001b[0m kwargs \u001b[39m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[39m=\u001b[39mfn\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, has_token\u001b[39m=\u001b[39mhas_token, kwargs\u001b[39m=\u001b[39mkwargs)\n\u001b[0;32m--> 118\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/huggingface_hub/file_download.py:1364\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)\u001b[0m\n\u001b[1;32m 1361\u001b[0m \u001b[39mwith\u001b[39;00m temp_file_manager() \u001b[39mas\u001b[39;00m temp_file:\n\u001b[1;32m 1362\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39m\"\u001b[39m\u001b[39mdownloading \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m, url, temp_file\u001b[39m.\u001b[39mname)\n\u001b[0;32m-> 1364\u001b[0m http_get(\n\u001b[1;32m 1365\u001b[0m url_to_download,\n\u001b[1;32m 1366\u001b[0m temp_file,\n\u001b[1;32m 1367\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 1368\u001b[0m resume_size\u001b[39m=\u001b[39;49mresume_size,\n\u001b[1;32m 1369\u001b[0m headers\u001b[39m=\u001b[39;49mheaders,\n\u001b[1;32m 1370\u001b[0m expected_size\u001b[39m=\u001b[39;49mexpected_size,\n\u001b[1;32m 1371\u001b[0m )\n\u001b[1;32m 1373\u001b[0m \u001b[39mif\u001b[39;00m local_dir \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 1374\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mStoring \u001b[39m\u001b[39m{\u001b[39;00murl\u001b[39m}\u001b[39;00m\u001b[39m in cache at \u001b[39m\u001b[39m{\u001b[39;00mblob_path\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/huggingface_hub/file_download.py:541\u001b[0m, in \u001b[0;36mhttp_get\u001b[0;34m(url, temp_file, proxies, resume_size, headers, timeout, max_retries, expected_size)\u001b[0m\n\u001b[1;32m 531\u001b[0m displayed_name \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m(β¦)\u001b[39m\u001b[39m{\u001b[39;00mdisplayed_name[\u001b[39m-\u001b[39m\u001b[39m20\u001b[39m:]\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 533\u001b[0m progress \u001b[39m=\u001b[39m tqdm(\n\u001b[1;32m 534\u001b[0m unit\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mB\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 535\u001b[0m unit_scale\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 539\u001b[0m disable\u001b[39m=\u001b[39m\u001b[39mbool\u001b[39m(logger\u001b[39m.\u001b[39mgetEffectiveLevel() \u001b[39m==\u001b[39m logging\u001b[39m.\u001b[39mNOTSET),\n\u001b[1;32m 540\u001b[0m )\n\u001b[0;32m--> 541\u001b[0m \u001b[39mfor\u001b[39;00m chunk \u001b[39min\u001b[39;00m r\u001b[39m.\u001b[39miter_content(chunk_size\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m \u001b[39m*\u001b[39m \u001b[39m1024\u001b[39m \u001b[39m*\u001b[39m \u001b[39m1024\u001b[39m):\n\u001b[1;32m 542\u001b[0m \u001b[39mif\u001b[39;00m chunk: \u001b[39m# filter out keep-alive new chunks\u001b[39;00m\n\u001b[1;32m 543\u001b[0m progress\u001b[39m.\u001b[39mupdate(\u001b[39mlen\u001b[39m(chunk))\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/requests/models.py:816\u001b[0m, in \u001b[0;36mResponse.iter_content..generate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 814\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mraw, \u001b[39m\"\u001b[39m\u001b[39mstream\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 815\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 816\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mraw\u001b[39m.\u001b[39mstream(chunk_size, decode_content\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 817\u001b[0m \u001b[39mexcept\u001b[39;00m ProtocolError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 818\u001b[0m \u001b[39mraise\u001b[39;00m ChunkedEncodingError(e)\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/urllib3/response.py:628\u001b[0m, in \u001b[0;36mHTTPResponse.stream\u001b[0;34m(self, amt, decode_content)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 627\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m is_fp_closed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fp):\n\u001b[0;32m--> 628\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mread(amt\u001b[39m=\u001b[39;49mamt, decode_content\u001b[39m=\u001b[39;49mdecode_content)\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m data:\n\u001b[1;32m 631\u001b[0m \u001b[39myield\u001b[39;00m data\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/urllib3/response.py:567\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt, decode_content, cache_content)\u001b[0m\n\u001b[1;32m 564\u001b[0m fp_closed \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fp, \u001b[39m\"\u001b[39m\u001b[39mclosed\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 566\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_error_catcher():\n\u001b[0;32m--> 567\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fp_read(amt) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m fp_closed \u001b[39melse\u001b[39;00m \u001b[39mb\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 568\u001b[0m \u001b[39mif\u001b[39;00m amt \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 569\u001b[0m flush_decoder \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/site-packages/urllib3/response.py:525\u001b[0m, in \u001b[0;36mHTTPResponse._fp_read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 524\u001b[0m chunk_amt \u001b[39m=\u001b[39m max_chunk_amt\n\u001b[0;32m--> 525\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fp\u001b[39m.\u001b[39;49mread(chunk_amt)\n\u001b[1;32m 526\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m data:\n\u001b[1;32m 527\u001b[0m \u001b[39mbreak\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/http/client.py:463\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[39mif\u001b[39;00m amt \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 461\u001b[0m \u001b[39m# Amount is given, implement using readinto\u001b[39;00m\n\u001b[1;32m 462\u001b[0m b \u001b[39m=\u001b[39m \u001b[39mbytearray\u001b[39m(amt)\n\u001b[0;32m--> 463\u001b[0m n \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mreadinto(b)\n\u001b[1;32m 464\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mmemoryview\u001b[39m(b)[:n]\u001b[39m.\u001b[39mtobytes()\n\u001b[1;32m 465\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 466\u001b[0m \u001b[39m# Amount is not given (unbounded read) so we must check self.length\u001b[39;00m\n\u001b[1;32m 467\u001b[0m \u001b[39m# and self.chunked\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/http/client.py:507\u001b[0m, in \u001b[0;36mHTTPResponse.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 502\u001b[0m b \u001b[39m=\u001b[39m \u001b[39mmemoryview\u001b[39m(b)[\u001b[39m0\u001b[39m:\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlength]\n\u001b[1;32m 504\u001b[0m \u001b[39m# we do not use _safe_read() here because this may be a .will_close\u001b[39;00m\n\u001b[1;32m 505\u001b[0m \u001b[39m# connection, and the user is reading more bytes than will be provided\u001b[39;00m\n\u001b[1;32m 506\u001b[0m \u001b[39m# (for example, reading in 1k chunks)\u001b[39;00m\n\u001b[0;32m--> 507\u001b[0m n \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfp\u001b[39m.\u001b[39;49mreadinto(b)\n\u001b[1;32m 508\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m n \u001b[39mand\u001b[39;00m b:\n\u001b[1;32m 509\u001b[0m \u001b[39m# Ideally, we would raise IncompleteRead if the content-length\u001b[39;00m\n\u001b[1;32m 510\u001b[0m \u001b[39m# wasn't satisfied, but it might break compatibility.\u001b[39;00m\n\u001b[1;32m 511\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_close_conn()\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/socket.py:704\u001b[0m, in \u001b[0;36mSocketIO.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[1;32m 703\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 704\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_sock\u001b[39m.\u001b[39;49mrecv_into(b)\n\u001b[1;32m 705\u001b[0m \u001b[39mexcept\u001b[39;00m timeout:\n\u001b[1;32m 706\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_timeout_occurred \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/ssl.py:1242\u001b[0m, in \u001b[0;36mSSLSocket.recv_into\u001b[0;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[1;32m 1238\u001b[0m \u001b[39mif\u001b[39;00m flags \u001b[39m!=\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m 1239\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 1240\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mnon-zero flags not allowed in calls to recv_into() on \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 1241\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m)\n\u001b[0;32m-> 1242\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mread(nbytes, buffer)\n\u001b[1;32m 1243\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1244\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39mrecv_into(buffer, nbytes, flags)\n",
+ "File \u001b[0;32m/opt/conda/envs/otter/lib/python3.9/ssl.py:1100\u001b[0m, in \u001b[0;36mSSLSocket.read\u001b[0;34m(self, len, buffer)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1099\u001b[0m \u001b[39mif\u001b[39;00m buffer \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m-> 1100\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_sslobj\u001b[39m.\u001b[39;49mread(\u001b[39mlen\u001b[39;49m, buffer)\n\u001b[1;32m 1101\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1102\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sslobj\u001b[39m.\u001b[39mread(\u001b[39mlen\u001b[39m)\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "import requests\n",
+ "import torch\n",
+ "import transformers\n",
+ "from PIL import Image\n",
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"../..\")\n",
+ "from otter.modeling_otter import OtterForConditionalGeneration\n",
+ "\n",
+ "model = OtterForConditionalGeneration.from_pretrained(\"luodian/OTTER-9B-LA-InContext\", device_map=\"auto\")\n",
+ "tokenizer = model.text_tokenizer\n",
+ "image_processor = transformers.CLIPImageProcessor()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "demo_image_one = Image.open(requests.get(\"http://images.cocodataset.org/val2017/000000039769.jpg\", stream=True).raw)\n",
+ "demo_image_two = Image.open(requests.get(\"http://images.cocodataset.org/test-stuff2017/000000028137.jpg\", stream=True).raw)\n",
+ "query_image = Image.open(requests.get(\"http://images.cocodataset.org/test-stuff2017/000000028352.jpg\", stream=True).raw)\n",
+ "vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors=\"pt\")[\"pixel_values\"].unsqueeze(1).unsqueeze(0)\n",
+ "model.text_tokenizer.padding_side = \"left\"\n",
+ "lang_x = model.text_tokenizer(\n",
+ " [\n",
+ " \"User: a photo of GPT: two cats sleeping.<|endofchunk|>User: a photo of GPT: a bathroom sink.<|endofchunk|>User: a photo of GPT:\"\n",
+ " ],\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "bad_words_id = tokenizer([\"User:\", \"GPT1:\", \"GFT:\", \"GPT:\"], add_special_tokens=False).input_ids\n",
+ "generated_text = model.generate(\n",
+ " vision_x=vision_x.to(model.device),\n",
+ " lang_x=lang_x[\"input_ids\"].to(model.device),\n",
+ " attention_mask=lang_x[\"attention_mask\"].to(model.device),\n",
+ " max_new_tokens=512,\n",
+ " num_beams=3,\n",
+ " no_repeat_ngram_size=3,\n",
+ " bad_words_ids=bad_words_id,\n",
+ ")\n",
+ "\n",
+ "print(\"Generated text: \", model.text_tokenizer.decode(generated_text[0]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "otter",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pipeline/demo/otter_video.ipynb b/pipeline/demo/otter_video.ipynb
new file mode 100644
index 00000000..5dc69b66
--- /dev/null
+++ b/pipeline/demo/otter_video.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Otter Video Demo\n",
+ "\n",
+ "Current Otter Video is Otter-v0.2-DC (0612), means itβs trianed on MIMIC-IT-DC at June 12th. The code reads a video and uniformly extracts 16 frames, so avoid using excessively long videos if you want the model to generate specific descriptions.\n",
+ "\n",
+ "If your machine has over 16G GPU memory, you can run our model locally in fp16 mode for tasks like video labeling and identifying harmful content. For machines with over 36G GPU memory (by combining multiple cards with [device_map='auto'](https://huggingface.co/docs/accelerate/usage_guides/big_modeling) to one model different cards), you can run our model in the more accurate fp32 mode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mimetypes\n",
+ "import os\n",
+ "from typing import Union\n",
+ "import cv2\n",
+ "import requests\n",
+ "import torch\n",
+ "import transformers\n",
+ "from PIL import Image\n",
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"../..\")\n",
+ "from otter.modeling_otter import OtterForConditionalGeneration\n",
+ "\n",
+ "# Disable warnings\n",
+ "requests.packages.urllib3.disable_warnings()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ------------------- Utility Functions -------------------\n",
+ "\n",
+ "\n",
+ "def get_content_type(file_path):\n",
+ " content_type, _ = mimetypes.guess_type(file_path)\n",
+ " return content_type\n",
+ "\n",
+ "\n",
+ "# ------------------- Image and Video Handling Functions -------------------\n",
+ "\n",
+ "\n",
+ "def extract_frames(video_path, num_frames=16):\n",
+ " video = cv2.VideoCapture(video_path)\n",
+ " total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n",
+ " frame_step = total_frames // num_frames\n",
+ " frames = []\n",
+ "\n",
+ " for i in range(num_frames):\n",
+ " video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)\n",
+ " ret, frame = video.read()\n",
+ " if ret:\n",
+ " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
+ " frame = Image.fromarray(frame).convert(\"RGB\")\n",
+ " frames.append(frame)\n",
+ "\n",
+ " video.release()\n",
+ " return frames\n",
+ "\n",
+ "\n",
+ "def get_image(url: str) -> Union[Image.Image, list]:\n",
+ " if \"://\" not in url: # Local file\n",
+ " content_type = get_content_type(url)\n",
+ " else: # Remote URL\n",
+ " content_type = requests.head(url, stream=True, verify=False).headers.get(\"Content-Type\")\n",
+ "\n",
+ " if \"image\" in content_type:\n",
+ " if \"://\" not in url: # Local file\n",
+ " return Image.open(url)\n",
+ " else: # Remote URL\n",
+ " return Image.open(requests.get(url, stream=True, verify=False).raw)\n",
+ " elif \"video\" in content_type:\n",
+ " video_path = \"temp_video.mp4\"\n",
+ " if \"://\" not in url: # Local file\n",
+ " video_path = url\n",
+ " else: # Remote URL\n",
+ " with open(video_path, \"wb\") as f:\n",
+ " f.write(requests.get(url, stream=True, verify=False).content)\n",
+ " frames = extract_frames(video_path)\n",
+ " if \"://\" in url: # Only remove the temporary video file if it was downloaded\n",
+ " os.remove(video_path)\n",
+ " return frames\n",
+ " else:\n",
+ " raise ValueError(\"Invalid content type. Expected image or video.\")\n",
+ "\n",
+ "\n",
+ "# ------------------- OTTER Prompt and Response Functions -------------------\n",
+ "\n",
+ "\n",
+ "def get_formatted_prompt(prompt: str) -> str:\n",
+ " return f\"User: {prompt} GPT:\"\n",
+ "\n",
+ "\n",
+ "def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:\n",
+ " if isinstance(input_data, Image.Image):\n",
+ " vision_x = image_processor.preprocess([input_data], return_tensors=\"pt\")[\"pixel_values\"].unsqueeze(1).unsqueeze(0)\n",
+ " elif isinstance(input_data, list): # list of video frames\n",
+ " vision_x = image_processor.preprocess(input_data, return_tensors=\"pt\")[\"pixel_values\"].unsqueeze(0).unsqueeze(0)\n",
+ " else:\n",
+ " raise ValueError(\"Invalid input data. Expected PIL Image or list of video frames.\")\n",
+ "\n",
+ " lang_x = model.text_tokenizer(\n",
+ " [\n",
+ " get_formatted_prompt(prompt),\n",
+ " ],\n",
+ " return_tensors=\"pt\",\n",
+ " )\n",
+ "\n",
+ " bad_words_id = model.text_tokenizer([\"User:\", \"GPT1:\", \"GFT:\", \"GPT:\"], add_special_tokens=False).input_ids\n",
+ " generated_text = model.generate(\n",
+ " vision_x=vision_x.to(model.device, dtype=tensor_dtype),\n",
+ " lang_x=lang_x[\"input_ids\"].to(model.device),\n",
+ " attention_mask=lang_x[\"attention_mask\"].to(model.device),\n",
+ " max_new_tokens=512,\n",
+ " num_beams=3,\n",
+ " no_repeat_ngram_size=3,\n",
+ " bad_words_ids=bad_words_id,\n",
+ " )\n",
+ " parsed_output = (\n",
+ " model.text_tokenizer.decode(generated_text[0])\n",
+ " .split(\"\")[-1]\n",
+ " .lstrip()\n",
+ " .rstrip()\n",
+ " .split(\"<|endofchunk|>\")[0]\n",
+ " .lstrip()\n",
+ " .rstrip()\n",
+ " .lstrip('\"')\n",
+ " .rstrip('\"')\n",
+ " )\n",
+ " return parsed_output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ------------------- Main Function -------------------\n",
+ "load_bit = \"fp32\"\n",
+ "if load_bit == \"fp16\":\n",
+ " precision = {\"torch_dtype\": torch.float16}\n",
+ "elif load_bit == \"bf16\":\n",
+ " precision = {\"torch_dtype\": torch.bfloat16}\n",
+ "elif load_bit == \"fp32\":\n",
+ " precision = {\"torch_dtype\": torch.float32}\n",
+ "\n",
+ "# This model version is trained on MIMIC-IT DC dataset.\n",
+ "model = OtterForConditionalGeneration.from_pretrained(\"luodian/OTTER-9B-DenseCaption\", device_map=\"auto\", **precision)\n",
+ "tensor_dtype = {\"fp16\": torch.float16, \"bf16\": torch.bfloat16, \"fp32\": torch.float32}[load_bit]\n",
+ "\n",
+ "model.text_tokenizer.padding_side = \"left\"\n",
+ "tokenizer = model.text_tokenizer\n",
+ "image_processor = transformers.CLIPImageProcessor()\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "while True:\n",
+ " video_url = input(\"Enter video path: \") # Replace with the path to your video file, could be any common format.\n",
+ "\n",
+ " frames_list = get_image(video_url)\n",
+ "\n",
+ " while True:\n",
+ " prompts_input = input(\"Enter prompts: \")\n",
+ "\n",
+ " if prompts_input.lower() == \"quit\":\n",
+ " break\n",
+ "\n",
+ " print(f\"\\nPrompt: {prompts_input}\")\n",
+ " response = get_response(frames_list, prompts_input, model, image_processor, tensor_dtype)\n",
+ " print(f\"Response: {response}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "otter",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pipeline/demo/otter_video.py b/pipeline/demo/otter_video.py
new file mode 100644
index 00000000..b45402fc
--- /dev/null
+++ b/pipeline/demo/otter_video.py
@@ -0,0 +1,150 @@
+import mimetypes
+import os
+from typing import Union
+import cv2
+import requests
+import torch
+import transformers
+from PIL import Image
+import sys
+
+sys.path.append("../..")
+# make sure you can properly access the otter folder
+from otter.modeling_otter import OtterForConditionalGeneration
+
+# Disable warnings
+requests.packages.urllib3.disable_warnings()
+
+# ------------------- Utility Functions -------------------
+
+
+def get_content_type(file_path):
+ content_type, _ = mimetypes.guess_type(file_path)
+ return content_type
+
+
+# ------------------- Image and Video Handling Functions -------------------
+
+
+def extract_frames(video_path, num_frames=16):
+ video = cv2.VideoCapture(video_path)
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_step = total_frames // num_frames
+ frames = []
+
+ for i in range(num_frames):
+ video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
+ ret, frame = video.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame).convert("RGB")
+ frames.append(frame)
+
+ video.release()
+ return frames
+
+
+def get_image(url: str) -> Union[Image.Image, list]:
+ if "://" not in url: # Local file
+ content_type = get_content_type(url)
+ else: # Remote URL
+ content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
+
+ if "image" in content_type:
+ if "://" not in url: # Local file
+ return Image.open(url)
+ else: # Remote URL
+ return Image.open(requests.get(url, stream=True, verify=False).raw)
+ elif "video" in content_type:
+ video_path = "temp_video.mp4"
+ if "://" not in url: # Local file
+ video_path = url
+ else: # Remote URL
+ with open(video_path, "wb") as f:
+ f.write(requests.get(url, stream=True, verify=False).content)
+ frames = extract_frames(video_path)
+ if "://" in url: # Only remove the temporary video file if it was downloaded
+ os.remove(video_path)
+ return frames
+ else:
+ raise ValueError("Invalid content type. Expected image or video.")
+
+
+# ------------------- OTTER Prompt and Response Functions -------------------
+
+
+def get_formatted_prompt(prompt: str) -> str:
+ return f"User: {prompt} GPT:"
+
+
+def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:
+ if isinstance(input_data, Image.Image):
+ vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
+ elif isinstance(input_data, list): # list of video frames
+ vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(0).unsqueeze(0)
+ else:
+ raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
+
+ lang_x = model.text_tokenizer(
+ [
+ get_formatted_prompt(prompt),
+ ],
+ return_tensors="pt",
+ )
+
+ bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
+ generated_text = model.generate(
+ vision_x=vision_x.to(model.device, dtype=tensor_dtype),
+ lang_x=lang_x["input_ids"].to(model.device),
+ attention_mask=lang_x["attention_mask"].to(model.device),
+ max_new_tokens=512,
+ num_beams=3,
+ no_repeat_ngram_size=3,
+ bad_words_ids=bad_words_id,
+ )
+ parsed_output = (
+ model.text_tokenizer.decode(generated_text[0])
+ .split("")[-1]
+ .lstrip()
+ .rstrip()
+ .split("<|endofchunk|>")[0]
+ .lstrip()
+ .rstrip()
+ .lstrip('"')
+ .rstrip('"')
+ )
+ return parsed_output
+
+
+# ------------------- Main Function -------------------
+load_bit = "fp32"
+if load_bit == "fp16":
+ precision = {"torch_dtype": torch.float16}
+elif load_bit == "bf16":
+ precision = {"torch_dtype": torch.bfloat16}
+elif load_bit == "fp32":
+ precision = {"torch_dtype": torch.float32}
+
+# This model version is trained on MIMIC-IT DC dataset.
+model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-DenseCaption", device_map="auto", **precision)
+tensor_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[load_bit]
+
+model.text_tokenizer.padding_side = "left"
+tokenizer = model.text_tokenizer
+image_processor = transformers.CLIPImageProcessor()
+model.eval()
+
+while True:
+ video_url = input("Enter video path: ") # Replace with the path to your video file, could be any common format.
+
+ frames_list = get_image(video_url)
+
+ while True:
+ prompts_input = input("Enter prompts: ")
+
+ if prompts_input.lower() == "quit":
+ break
+
+ print(f"\nPrompt: {prompts_input}")
+ response = get_response(frames_list, prompts_input, model, image_processor, tensor_dtype)
+ print(f"Response: {response}")
diff --git a/pipeline/eval/benchmark_otter.py b/pipeline/eval/benchmark_otter.py
deleted file mode 100644
index 2462599a..00000000
--- a/pipeline/eval/benchmark_otter.py
+++ /dev/null
@@ -1,215 +0,0 @@
-import requests
-import torch
-import transformers
-import json
-from PIL import Image
-from otter.modeling_otter import OtterForConditionalGeneration
-import argparse
-from tqdm import tqdm
-
-requests.packages.urllib3.disable_warnings()
-
-
-def get_image(url: str) -> Image.Image:
- """
- Get image from url
-
- Args:
- url (str): url of the image
-
- Returns:
- Image.Image: PIL Image
- """
- return Image.open(requests.get(url, stream=True, verify=False).raw)
-
-
-def get_formatted_prompt(prompt: str) -> str:
- """
- Format prompt for GPT
-
- Args:
- prompt (str): prompt to be formatted
-
- Returns:
- str: formatted prompt
- """
- return f" User: {prompt} GPT:"
-
-
-def get_response(url: str, prompt: str, model=None, image_processor=None) -> str:
- """
- Get the response of single image and prompt from the model
-
- Args:
- url (str): url of the image
- prompt (str): the prompt (no need to be formatted)
-
- Returns:
- str: response of the model
- """
- query_image = get_image(url)
- vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
- lang_x = model.text_tokenizer(
- [
- get_formatted_prompt(prompt),
- ],
- return_tensors="pt",
- )
- generated_text = model.generate(
- vision_x=vision_x.to(model.device),
- lang_x=lang_x["input_ids"].to(model.device),
- attention_mask=lang_x["attention_mask"].to(model.device),
- max_new_tokens=256,
- num_beams=3,
- no_repeat_ngram_size=3,
- )
- parsed_output = (
- model.text_tokenizer.decode(generated_text[0])
- .split("")[1]
- .lstrip()
- .rstrip()
- .split("<|endofchunk|>")[0]
- .lstrip()
- .rstrip()
- .lstrip('"')
- .rstrip('"')
- )
- return parsed_output
-
-
-def generate_html(output_file, model_version_or_tag):
- import json
-
- # Load the data from the JSON file
- with open(output_file, "r") as f:
- data = json.load(f)
-
- # Start the HTML file
- html = """
-
-
-
- Benchmarking various ver. of Otter
-
-
-
- {}
- """
-
- html = html.format(model_version_or_tag)
-
- # Add headers
- html += """
-
-
-
Image
-
-
-
Instruction
-
-
-
Response
-
-
- """
-
- # Add the data to the HTML
- for item in data:
- html += """
-
-
-

-
-
- {instruction}
-
-
- {response}
-
-
- """.format(
- **item
- )
-
- # Close the HTML tags
- html += """
-
-
- """
-
- # Write the HTML string to a file
- output_html_path = output_file.replace(".json", ".html")
- with open(output_html_path, "w") as f:
- f.write(html)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--model_path_or_name",
- type=str,
- default="luodian/otter-9b-hf",
- help="Path or name of the model (HF format)",
- )
- parser.add_argument(
- "--model_version_or_tag",
- type=str,
- default="apr25_otter",
- help="Version or tag of the model",
- )
- parser.add_argument(
- "--input_file",
- type=str,
- default="evaluation/sample_questions.json",
- help="Path of the input file",
- )
- args = parser.parse_args()
-
- model = OtterForConditionalGeneration.from_pretrained(args.model_path_or_name, device_map="auto")
- model.text_tokenizer.padding_side = "left"
- tokenizer = model.text_tokenizer
- image_processor = transformers.CLIPImageProcessor()
-
- responses = []
- with open(args.input_file) as f:
- data = json.load(f)
- progress_bar = tqdm(total=len(data["input"]))
- for item in data["input"]:
- print("=" * 50)
- print(f"Processing {item['image']} with prompt {item['instruction']}")
- response = get_response(item["image"], item["instruction"], model, image_processor)
- print(f"Response: {response}")
- responses.append(
- {
- "image": item["image"],
- "instruction": item["instruction"],
- "response": response,
- }
- )
- progress_bar.update(1)
- json.dump(
- responses,
- open(f"./evaluation/{args.model_version_or_tag}_outputs.json", "w"),
- indent=4,
- )
-
- generate_html(
- f"./evaluation/{args.model_version_or_tag}_outputs.json",
- args.model_version_or_tag,
- )
diff --git a/pipeline/eval/classification.py b/pipeline/eval/classification.py
deleted file mode 100644
index 07a9a82d..00000000
--- a/pipeline/eval/classification.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from typing import Dict, Sequence, Tuple
-import re
-import numpy as np
-import torch
-
-
-def postprocess_classification_generation(predictions) -> str:
- return re.split("Prompt|Completion", predictions, 1)[0]
-
-
-def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
- """Compute the accuracy of a sequence of predictions."""
-
- def _preprocess_fn(s):
- """Function to preprocess both targets and predictions."""
- return s.lower()
-
- is_correct = [_preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"]) for x in predictions]
-
- return np.mean(is_correct).item()
-
-
-def compute_shifted_logits_and_labels(logits: torch.Tensor, encodings, tokenizer, eoc_token_id) -> Tuple[torch.Tensor, torch.Tensor]:
- """Helper function to compute shifted logits and labels.
-
- This allows for straightforward computation of the loss on shift_logits
- and shift_labels such that the nth element of logits computes the n-1th
- element of the original labels (in the outputs, the nth element of logits
- corresponds to the nth element of the labels).
-
- Elements in shift_labels that correspond to inputs are masked with values
- of -100 (by default in hf, loss is only computed on token IDs >= 0).
-
- Returns: tuple containing two elements:
- shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
- shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
- """
-
- labels = encodings["input_ids"].clone()
-
- # convert padding and EOC tokens to -100 so they are ignored in loss
- labels[labels == tokenizer.pad_token_id] = -100
- labels[labels == eoc_token_id] = -100
-
- # Convert all tokens in prefix until separator to -100 so they are
- # ignored in loss
- for idx in range(len(labels)):
- # Find the location of the last token of prefix *from right*,
- # since the first non-padding token of the sequence will also be
- # eos_token (because bos_token and eos_token are the same for
- # the tokenizer).
- end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
- labels[idx, : end_of_prefix + 1] = -100
-
- # Shift so that tokens < n predict n. The shifted tensors both have
- # shape [batch_size, seq_len - 1].
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- return shift_logits, shift_labels
-
-
-def compute_per_sample_probs(encodings, tokenizer, logits: torch.Tensor, eoc_token_id) -> torch.Tensor:
- """Helper function to compute per-sample probability of the input sequence.
-
- Assumes is used to separate inputs from targets in the
- prompt text
- """
- shift_logits, shift_labels = compute_shifted_logits_and_labels(logits, encodings, tokenizer, eoc_token_id)
-
- # Tuple of tensors for unmasked label tokens. The first element of the
- # tuple contains the batch indices; the second element contains the
- # sequence indices.
- unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
- # Tensor where the i^th element is the token_id corresponding to the i^th
- # element of unmasked_indices
- unmasked_token_ids = shift_labels[unmasked_indices]
-
- # 3d tensor of [batch_idx, sequence_position, token_id] for unmasked tokens.
- target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
- target_idxs = target_idxs.to(shift_logits.device)
-
- # Sanity check that every element in batch has at least one unmasked
- # target token
- assert torch.all(torch.bincount(target_idxs[:, 0]) != 0), "At least one element in batch has no unmasked target tokens."
-
- # Renormalize over tokens to make sure they are proper probabilities via
- # softmax over the token dimension.
- shift_probs = torch.nn.functional.softmax(shift_logits, 2)
-
- # Compute the probability of the target sequence (as the product of the
- # probability of the individual tokens in the sequence).
- target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
- for i, j, k in target_idxs:
- target_probs[i] *= shift_probs[i, j, k]
-
- return target_probs
-
-
-def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
- """Helper function to compute per-sample classification loss.
-
- Assumes is used to separate inputs from targets in the
- prompt text
- """
- shift_logits, shift_labels = compute_shifted_logits_and_labels(logits, encodings, tokenizer, eoc_token_id)
-
- device = shift_logits.device
-
- # Loss is computed token-wise, on Tensors of shape
- # [batch_size * (seq_len - 1), vocab_size]
- # and returns a loss tensor of shape
- # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
- # in this computation.
- loss = torch.nn.functional.cross_entropy(
- shift_logits.view(-1, shift_logits.size(-1)),
- shift_labels.view(-1).to(device),
- reduction="none",
- )
-
- # Reshape to [batch_size, seq_len - 1]
- loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()
-
- # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
- # that should be ignored in the loss.
- loss_mask = (shift_labels != -100).int().cpu()
-
- loss *= loss_mask
-
- # Compute per-element loss : sum loss over all (unmasked) tokens and
- # divide by number of variable tokens to obtain tensor of
- # shape [batch_size,]
- loss = loss.sum(dim=1) / (shift_labels != -100).sum(dim=1).float()
- return loss
diff --git a/pipeline/eval/coco_metric.py b/pipeline/eval/coco_metric.py
deleted file mode 100644
index 2a1c8842..00000000
--- a/pipeline/eval/coco_metric.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from pycocoevalcap.eval import COCOEvalCap
-from pycocotools.coco import COCO
-
-
-def compute_cider(
- result_path,
- annotations_path="/data/yfcc-tmp/data/mscoco/annotations/captions_train2017.json",
-):
- # create coco object and coco_result object
- coco = COCO(annotations_path)
- coco_result = coco.loadRes(result_path)
-
- # create coco_eval object by taking coco and coco_result
- coco_eval = COCOEvalCap(coco, coco_result)
- coco_eval.params["image_id"] = coco_result.getImgIds()
- coco_eval.evaluate()
-
- return coco_eval.eval
-
-
-def postprocess_captioning_generation(predictions):
- return predictions.split("Output", 1)[0]
diff --git a/pipeline/eval/eval_datasets.py b/pipeline/eval/eval_datasets.py
deleted file mode 100644
index b251d03f..00000000
--- a/pipeline/eval/eval_datasets.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import json
-import os
-
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision.datasets import ImageFolder
-
-from .imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
-
-
-class COCOFlickrDataset(Dataset):
- def __init__(
- self,
- image_dir_path="/mmfs1/gscratch/efml/anasa2/data/coco/train2017/",
- annotations_path="/mmfs1/gscratch/efml/anasa2/data/coco/annotations/captions_train2017.json",
- is_flickr=False,
- ):
- self.image_dir_path = image_dir_path
- self.annotations = json.load(open(annotations_path))["annotations"]
- self.is_flickr = is_flickr
-
- def __len__(self):
- return len(self.annotations)
-
- def get_img_path(self, idx):
- if self.is_flickr:
- return f"{self.image_dir_path}/{self.annotations[idx]['image_id']}.jpg"
- else:
- return f"{self.image_dir_path}/COCO_train2014_{self.annotations[idx]['image_id']:012d}.jpg"
-
- def __getitem__(self, idx):
- image = Image.open(self.get_img_path(idx))
- caption = self.annotations[idx]["caption"]
- return {
- "image": image,
- "caption": caption,
- "image_id": self.annotations[idx]["image_id"],
- }
-
-
-class VQADataset(Dataset):
- def __init__(
- self,
- image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
- question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
- annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
- vqa_dataset="vqa",
- ):
- self.questions = json.load(open(question_path, "r"))["questions"]
- self.answers = json.load(open(annotations_path, "r"))["annotations"]
- self.image_dir_path = image_dir_path
- self.vqa_dataset = vqa_dataset
-
- def __len__(self):
- return len(self.questions)
-
- def get_img_path(self, question):
- if self.vqa_dataset == "vqa":
- return os.path.join(self.image_dir_path, f"COCO_train2014_{question['image_id']:012d}.jpg")
- elif self.vqa_dataset == "ok_vqa":
- return os.path.join(self.image_dir_path, f"COCO_train2014_{question['image_id']:012d}.jpg")
- else:
- raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
-
- def __getitem__(self, idx):
- question = self.questions[idx]
- answers = self.answers[idx]
- img_path = self.get_img_path(question)
- image = Image.open(img_path)
- return {
- "image": image,
- "question": question["question"],
- "answers": [a["answer"] for a in answers["answers"]],
- "question_id": question["question_id"],
- }
-
-
-class ImageNetDataset(ImageFolder):
- """Class to represent the ImageNet1k dataset."""
-
- def __init__(self, root, **kwargs):
- super().__init__(root=root, **kwargs)
-
- def __getitem__(self, idx):
- sample, target = super().__getitem__(idx)
- target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
- return {
- "image": sample,
- "class_id": target, # numeric ID of the ImageNet class
- "class_name": target_label, # human-readable name of ImageNet class
- }
diff --git a/pipeline/eval/evaluate.py b/pipeline/eval/evaluate.py
deleted file mode 100644
index 88bced03..00000000
--- a/pipeline/eval/evaluate.py
+++ /dev/null
@@ -1,873 +0,0 @@
-import argparse
-import json
-from math import ceil
-import os
-import random
-import uuid
-from collections import defaultdict
-from typing import Callable
-
-import more_itertools
-import numpy as np
-import torch
-from coco_metric import compute_cider, postprocess_captioning_generation
-from eval_datasets import COCOFlickrDataset, VQADataset, ImageNetDataset
-from tqdm import tqdm
-
-from .ok_vqa_utils import postprocess_ok_vqa_generation
-from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
-from .classification import (
- compute_per_sample_probs,
- compute_per_sample_loss,
-)
-from .imagenet_utils import (
- openai_imagenet_classnames,
- IMAGENET_1K_CLASS_ID_TO_LABEL,
-)
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
-parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
-parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
-parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
-parser.add_argument("--checkpoint_path", type=str, required=True)
-parser.add_argument(
- "--cross_attn_every_n_layers",
- type=int,
- default=1,
- help="how often to add a cross-attention layer after each transformer layer",
-)
-parser.add_argument("--results_file", type=str, default=None, help="JSON file to save results")
-
-# Trial arguments
-parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
-parser.add_argument(
- "--num_trials",
- type=int,
- default=1,
- help="Number of trials to run for each shot using different demonstrations",
-)
-parser.add_argument(
- "--trial_seeds",
- nargs="+",
- default=[0],
- help="Seeds to use for each trial for picking demonstrations and eval sets",
-)
-parser.add_argument("--num_samples", type=int, default=5000, help="Number of samples to evaluate on")
-
-parser.add_argument("--batch_size", type=int, default=8)
-parser.add_argument("--device", type=int, default=0)
-
-# Per-dataset evaluation flags
-parser.add_argument(
- "--eval_coco",
- action="store_true",
- default=False,
- help="Whether to evaluate on COCO.",
-)
-parser.add_argument(
- "--eval_vqav2",
- action="store_true",
- default=False,
- help="Whether to evaluate on VQAV2.",
-)
-parser.add_argument(
- "--eval_ok_vqa",
- action="store_true",
- default=False,
- help="Whether to evaluate on OK-VQA.",
-)
-parser.add_argument(
- "--eval_imagenet",
- action="store_true",
- default=False,
- help="Whether to evaluate on ImageNet.",
-)
-
-parser.add_argument(
- "--eval_flickr30",
- action="store_true",
- default=False,
- help="Whether to evaluate on Flickr30.",
-)
-
-# Dataset arguments
-
-## Flickr30 Dataset
-parser.add_argument(
- "--flickr_image_dir_path",
- type=str,
- help="Path to the flickr30/flickr30k_images directory.",
- default=None,
-)
-parser.add_argument(
- "--flickr_annotations_json_path",
- type=str,
- help="Path to the dataset_flickr30k_coco_style.json file.",
- default=None,
-)
-
-## COCO Dataset
-parser.add_argument(
- "--coco_image_dir_path",
- type=str,
- help="Path to the flickr30/flickr30k_images directory.",
- default=None,
-)
-parser.add_argument(
- "--coco_annotations_json_path",
- type=str,
- default=None,
-)
-
-## VQAV2 Dataset
-parser.add_argument(
- "--vqav2_image_dir_path",
- type=str,
- default=None,
-)
-parser.add_argument(
- "--vqav2_questions_json_path",
- type=str,
- default=None,
-)
-parser.add_argument(
- "--vqav2_annotations_json_path",
- type=str,
- default=None,
-)
-
-## OK-VQA Dataset
-parser.add_argument(
- "--ok_vqa_image_dir_path",
- type=str,
- help="Path to the vqav2/train2014 directory.",
- default=None,
-)
-parser.add_argument(
- "--ok_vqa_questions_json_path",
- type=str,
- help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
- default=None,
-)
-parser.add_argument(
- "--ok_vqa_annotations_json_path",
- type=str,
- help="Path to the v2_mscoco_train2014_annotations.json file.",
- default=None,
-)
-
-## Imagenet dataset
-parser.add_argument("--imagenet_root", type=str, default="/tmp")
-
-
-def main():
- args = parser.parse_args()
-
- # TODO: load hf model
- flamingo = None
- tokenizer = None
- image_processor = None
-
- checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
- flamingo.load_state_dict(checkpoint, strict=False)
- flamingo.to(args.device if args.device >= 0 else "cpu")
-
- results = defaultdict(list)
-
- if args.eval_flickr30:
- print("Evaluating on Flickr30...")
- for shot in args.shots:
- scores = []
- for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
- cider_score = evaluate_coco_flickr(
- model=flamingo,
- tokenizer=tokenizer,
- image_processor=image_processor,
- batch_size=args.batch_size,
- image_dir_path=args.flickr_image_dir_path,
- annotations_json_path=args.flickr_annotations_json_path,
- num_samples=args.num_samples,
- num_shots=shot,
- device=args.device,
- seed=seed,
- is_flickr=True,
- )
- print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
- scores.append(cider_score)
- print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
- results["flickr30"].append({"shots": shot, "trials": scores, "mean": np.mean(scores)})
- results = defaultdict(list)
-
- if args.eval_coco:
- print("Evaluating on COCO...")
- for shot in args.shots:
- scores = []
- for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
- cider_score = evaluate_coco_flickr(
- model=flamingo,
- tokenizer=tokenizer,
- image_processor=image_processor,
- batch_size=args.batch_size,
- image_dir_path=args.coco_image_dir_path,
- annotations_json_path=args.coco_annotations_json_path,
- num_samples=args.num_samples,
- num_shots=shot,
- device=args.device,
- seed=seed,
- )
- print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
- scores.append(cider_score)
- print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
- results["coco"].append({"shots": shot, "trials": scores, "mean": np.mean(scores)})
-
- if args.eval_ok_vqa:
- print("Evaluating on OK-VQA...")
- for shot in args.shots:
- scores = []
- for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
- ok_vqa_score = evaluate_vqa(
- model=flamingo,
- tokenizer=tokenizer,
- image_processor=image_processor,
- batch_size=args.batch_size,
- num_samples=args.num_samples,
- num_shots=shot,
- device=args.device,
- seed=seed,
- image_dir_path=args.ok_vqa_image_dir_path,
- questions_json_path=args.ok_vqa_questions_json_path,
- annotations_json_path=args.ok_vqa_annotations_json_path,
- vqa_dataset="ok_vqa",
- )
- print(f"Shots {shot} Trial {trial} OK-VQA score: {ok_vqa_score}")
- scores.append(ok_vqa_score)
- print(f"Shots {shot} Mean OK-VQA score: {np.mean(scores)}")
- results["ok_vqa"].append({"shots": shot, "trials": scores, "mean": np.mean(scores)})
-
- if args.eval_vqav2:
- print("Evaluating on VQAv2...")
- for shot in args.shots:
- scores = []
- for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
- vqa_score = evaluate_vqa(
- model=flamingo,
- tokenizer=tokenizer,
- image_processor=image_processor,
- batch_size=args.batch_size,
- num_samples=args.num_samples,
- num_shots=shot,
- device=args.device,
- seed=seed,
- image_dir_path=args.vqav2_image_dir_path,
- questions_json_path=args.vqav2_questions_json_path,
- annotations_json_path=args.vqav2_annotations_json_path,
- vqa_dataset="vqa",
- )
- print(f"Shots {shot} Trial {trial} VQA score: {vqa_score}")
- scores.append(vqa_score)
- print(f"Shots {shot} Mean VQA score: {np.mean(scores)}")
- results["vqav2"].append({"shots": shot, "trials": scores, "mean": np.mean(scores)})
-
- if args.eval_imagenet:
- print("Evaluating on ImageNet...")
- for shot in args.shots:
- scores = []
- for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
- imagenet_score = evaluate_imagenet(
- model=flamingo,
- tokenizer=tokenizer,
- image_processor=image_processor,
- batch_size=args.batch_size,
- num_samples=args.num_samples,
- num_shots=shot,
- device=args.device,
- seed=seed,
- imagenet_root=args.imagenet_root,
- )
- print(f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}")
- scores.append(imagenet_score)
- print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
- results["imagenet"].append({"shots": shot, "trials": scores, "mean": np.mean(scores)})
-
- if args.results_file is not None:
- with open(args.results_file, "w") as f:
- json.dump(results, f)
-
-
-def get_random_indices(num_samples, query_set_size, full_dataset, seed):
- if num_samples + query_set_size > len(full_dataset):
- raise ValueError(f"num_samples + num_shots must be less than {len(full_dataset)}")
-
- # get a random subset of the dataset
- np.random.seed(seed)
- random_indices = np.random.choice(len(full_dataset), num_samples + query_set_size, replace=False)
- return random_indices
-
-
-def prepare_eval_samples_and_dataset(full_dataset, random_indices, query_set_size):
- # get in context samples
- in_context_samples = [full_dataset[i] for i in random_indices[:query_set_size]]
- eval_dataset = torch.utils.data.Subset(full_dataset, random_indices[query_set_size:])
- return in_context_samples, eval_dataset
-
-
-def get_context_images(image_processor, in_context_samples, num_shots):
- if num_shots > 0:
- context_images = [image_processor(s["image"]).unsqueeze(0) for s in in_context_samples]
- context_images = torch.cat(context_images, dim=0)
- context_images = context_images.unsqueeze(1).unsqueeze(0)
- else:
- context_images = None
- return context_images
-
-
-def get_context_text(
- get_prompt: Callable[[dict], str],
- in_context_samples,
- effective_num_shots,
- num_shots,
-) -> str:
- context_text = "".join([get_prompt(s) for s in in_context_samples]) if effective_num_shots > 0 else ""
-
- if num_shots == 0:
- context_text = context_text.replace("", "")
- return context_text
-
-
-def prepare_batch_images(batch, image_processor, context_images, num_shots):
- batch_images = None
- for b, sample_imgs in zip(batch, context_images):
- b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
- b_image = torch.cat([sample_imgs, b_image], dim=1) if num_shots > 0 else b_image
-
- if batch_images is None:
- batch_images = b_image
- else:
- batch_images = torch.cat([batch_images, b_image], dim=0)
- return batch_images
-
-
-def sample_batch_demos_from_query_set(query_set, num_samples, batch_size):
- return [random.sample(query_set, num_samples) for _ in range(batch_size)]
-
-
-def get_outputs(
- model,
- batch_images,
- device,
- attention_mask,
- max_generation_length,
- num_beams,
- length_penalty,
- input_ids,
-):
- with torch.inference_mode():
- outputs = model.generate(
- batch_images.to(device if device >= 0 else "cpu"),
- input_ids.to(device if device >= 0 else "cpu"),
- attention_mask=attention_mask.to(device if device >= 0 else "cpu"),
- max_new_tokens=max_generation_length,
- num_beams=num_beams,
- length_penalty=length_penalty,
- )
-
- outputs = outputs[:, len(input_ids[0]) :]
- return outputs
-
-
-def evaluate_coco_flickr(
- model,
- tokenizer,
- image_processor,
- batch_size,
- image_dir_path,
- annotations_json_path,
- seed=42,
- max_generation_length=20,
- num_beams=3,
- length_penalty=-2.0,
- num_samples=5000,
- query_set_size=2048,
- num_shots=8,
- device=-1,
- is_flickr=False,
-):
- """Evaluate a model on COCO dataset.
-
- Args:
- model (nn.Module): model to evaluate
- tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
- image_processor : image processor for the model
- batch_size (int): batch size
- image_dir_path (str, optional): path to the directory containing the images.
- annotations_json_path (str, optional): path to the json file containing the annotations.
- seed (int, optional): seed for random number generator. Defaults to 42.
- max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
- num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
- length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
- num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
- query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
- num_shots (int, optional): number of in-context samples to use. Defaults to 8.
- device (int, optional): device to use. Defaults to -1.
- num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
- is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
-
- Returns:
- float: CIDEr score
-
- """
-
- full_dataset = COCOFlickrDataset(
- image_dir_path=image_dir_path,
- annotations_path=annotations_json_path,
- is_flickr=is_flickr,
- )
- effective_num_shots = num_shots if num_shots > 0 else 2
- random_indices = get_random_indices(num_samples, query_set_size, full_dataset, seed)
-
- in_context_samples, eval_dataset = prepare_eval_samples_and_dataset(
- full_dataset=full_dataset,
- random_indices=random_indices,
- query_set_size=query_set_size,
- )
-
- model.eval()
-
- def get_prompt(sample):
- return f"Output:{sample['caption'].strip()}<|endofchunk|>"
-
- predictions = defaultdict()
-
- desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
-
- for batch in more_itertools.chunked(tqdm(eval_dataset, desc=desc), batch_size):
- batch_demo_samples = sample_batch_demos_from_query_set(in_context_samples, effective_num_shots, len(batch))
-
- context_images = [
- get_context_images(
- image_processor=image_processor,
- in_context_samples=batch_demo_samples[i],
- num_shots=num_shots,
- )
- for i in range(len(batch))
- ]
-
- context_text = [
- get_context_text(
- get_prompt,
- in_context_samples=batch_demo_samples[i],
- effective_num_shots=effective_num_shots,
- num_shots=num_shots,
- )
- for i in range(len(batch))
- ]
-
- batch_images = prepare_batch_images(
- batch=batch,
- image_processor=image_processor,
- context_images=context_images,
- num_shots=num_shots,
- )
-
- batch_text = [f"{context_text[i]}Output:" for i in range(len(batch))]
-
- tokenizer.padding_side = "left"
- encodings = tokenizer(
- batch_text,
- padding="longest",
- truncation=True,
- return_tensors="pt",
- max_length=2000,
- )
- input_ids = encodings["input_ids"]
- attention_mask = encodings["attention_mask"]
-
- outputs = get_outputs(
- model=model,
- batch_images=batch_images,
- device=device,
- attention_mask=attention_mask,
- max_generation_length=max_generation_length,
- num_beams=num_beams,
- length_penalty=length_penalty,
- input_ids=input_ids,
- )
- new_predictions = [postprocess_captioning_generation(out).replace('"', "") for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)]
-
- for i, sample in enumerate(batch):
- predictions[sample["image_id"]] = {
- "caption": new_predictions[i],
- }
-
- # save the predictions to a temporary file
- random_uuid = str(uuid.uuid4())
- results_path = f"flickrresults_{random_uuid}.json" if is_flickr else f"cocoresults_{random_uuid}.json"
- with open(results_path, "w") as f:
- f.write(
- json.dumps(
- [{"image_id": k, "caption": predictions[k]["caption"]} for k in predictions],
- indent=4,
- )
- )
-
- metrics = compute_cider(
- result_path=results_path,
- annotations_path=annotations_json_path,
- )
-
- # delete the temporary file
- os.remove(results_path)
-
- return metrics["CIDEr"] * 100.0
-
-
-def evaluate_vqa(
- model,
- tokenizer,
- image_processor,
- batch_size,
- image_dir_path,
- questions_json_path,
- annotations_json_path,
- seed=42,
- max_generation_length=5,
- num_beams=3,
- length_penalty=-2.0,
- num_samples=5000,
- query_set_size=2048,
- num_shots=8,
- device=-1,
- vqa_dataset="vqa",
-):
- """
- Evaluate a model on VQA datasets. Currently supports VQA v2.0.
-
- Args:
- model (nn.Module): model to evaluate
- tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
- image_processor : image processor for the model
- batch_size (int): batch size
- image_dir_path (str): path to image directory
- questions_json_path (str): path to questions json file
- annotations_json_path (str): path to annotations json file
- seed (int, optional): random seed. Defaults to 42.
- max_generation_length (int, optional): max generation length. Defaults to 5.
- num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
- length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
- num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
- query_set_size (int, optional): size of the query set. Defaults to 2048.
- num_shots (int, optional): number of shots to use. Defaults to 8.
- device (int, optional): device to use. Defaults to -1 (cpu).
- num_workers (int, optional): number of workers to use. Defaults to 4.
- vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
- Returns:
- float: accuracy score
- """
-
- full_dataset = VQADataset(
- image_dir_path=image_dir_path,
- question_path=questions_json_path,
- annotations_path=annotations_json_path,
- vqa_dataset=vqa_dataset,
- )
-
- effective_num_shots = num_shots if num_shots > 0 else 2
-
- if num_samples + effective_num_shots > len(full_dataset):
- raise ValueError(f"num_samples + num_shots must be less than or equal to {len(full_dataset)}")
-
- random_indices = get_random_indices(num_samples, query_set_size, full_dataset, seed)
-
- def get_prompt(sample, train=True):
- return f"Question:{sample['question'].strip()} Short Answer:{sample['answers'][0].strip() if train else ''}{'<|endofchunk|>' if train else ''}"
-
- in_context_samples, eval_dataset = prepare_eval_samples_and_dataset(
- full_dataset=full_dataset,
- random_indices=random_indices,
- query_set_size=query_set_size,
- )
-
- model.eval()
- predictions = []
-
- for batch in more_itertools.chunked(tqdm(eval_dataset, desc="Running inference"), batch_size):
- batch_demo_samples = sample_batch_demos_from_query_set(in_context_samples, effective_num_shots, len(batch))
-
- context_images = [
- get_context_images(
- image_processor=image_processor,
- in_context_samples=batch_demo_samples[i],
- num_shots=num_shots,
- )
- for i in range(len(batch))
- ]
-
- context_text = [
- get_context_text(
- get_prompt,
- in_context_samples=batch_demo_samples[i],
- effective_num_shots=effective_num_shots,
- num_shots=num_shots,
- )
- for i in range(len(batch))
- ]
-
- batch_images = prepare_batch_images(
- batch=batch,
- image_processor=image_processor,
- context_images=context_images,
- num_shots=num_shots,
- )
-
- batch_text = [context_text[i] + get_prompt(s, train=False) for i, s in enumerate(batch)]
-
- tokenizer.padding_side = "left"
- encodings = tokenizer(
- batch_text,
- return_tensors="pt",
- padding="longest",
- truncation=True,
- max_length=2000,
- )
- input_ids = encodings["input_ids"].to(device if device >= 0 else "cpu")
- attention_mask = encodings["attention_mask"].to(device if device >= 0 else "cpu")
-
- outputs = get_outputs(
- model=model,
- batch_images=batch_images,
- device=device,
- attention_mask=attention_mask,
- max_generation_length=max_generation_length,
- num_beams=num_beams,
- length_penalty=length_penalty,
- input_ids=input_ids,
- )
-
- process_function = postprocess_vqa_generation if vqa_dataset == "vqa" else postprocess_ok_vqa_generation
-
- new_predictions = [process_function(out) for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)]
-
- predictions.extend([{"answer": p, "question_id": sample["question_id"]} for p, sample in zip(new_predictions, batch)])
- # save the predictions to a temporary file
- random_uuid = str(uuid.uuid4())
- with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
- f.write(json.dumps(predictions, indent=4))
-
- acc = compute_vqa_accuracy(
- f"{vqa_dataset}results_{random_uuid}.json",
- questions_json_path,
- annotations_json_path,
- )
-
- # delete the temporary file
- os.remove(f"{vqa_dataset}results_{random_uuid}.json")
-
- return acc
-
-
-def evaluate_imagenet(
- model,
- tokenizer,
- image_processor,
- batch_size: int,
- imagenet_root: str,
- seed: int = 42,
- num_samples: int = 5000,
- num_shots: int = 8,
- device: int = -1,
-):
- """
- Evaluate a model on ImageNet dataset.
-
- Args:
- model: model to evaluate
- tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
- image_processor : image processor for the model
- batch_size (int): batch size
- imagenet_root (str): path to imagenet root for the specified split.
- seed (int, optional): random seed. Defaults to 42.
- num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
- num_shots (int, optional): number of shots to use. Defaults to 8.
- device (int, optional): device to use. Defaults to -1 (cpu).
-
- Returns:
- float: accuracy score
- """
-
- full_dataset = ImageNetDataset(root=imagenet_root)
-
- effective_num_shots = num_shots if num_shots > 0 else 2
-
- if num_samples + effective_num_shots > len(full_dataset):
- raise ValueError(f"num_samples + num_shots must be less than or equal to " f"{len(full_dataset)} ")
-
- random_indices = get_random_indices(num_samples, effective_num_shots, full_dataset, seed)
-
- eoc_token = "<|endofchunk|>"
- eoc_token_id = tokenizer.additional_special_tokens_ids[tokenizer.additional_special_tokens.index(eoc_token)]
-
- # Padding from right allows efficient precomputing of context activations.
- tokenizer.padding_side = "right"
-
- def _imagenet_prompt(class_name, is_context: bool = True):
- """Construct an imagenet prompt for a given label."""
- prefix = "A photo of a "
- if is_context:
- return prefix + class_name.strip()
- else:
- # Not a context example; insert EOS token before the class name
- # so that we can compute the loss on the class name tokens only.
- return prefix + tokenizer.eos_token + class_name.strip()
-
- def get_imagenet_prompt(x: dict, is_context: bool = True) -> str:
- """Construct an ImageNet prompt for an example, using its label."""
- return _imagenet_prompt(x["class_name"], is_context=is_context)
-
- in_context_samples, eval_dataset = prepare_eval_samples_and_dataset(
- full_dataset=full_dataset,
- random_indices=random_indices,
- query_set_size=effective_num_shots, # NOTE: here we replace query_set_size with effective_num_shots but this is not the ideal evaluation setting.
- # TODO: We should add a query_set_size argument to the function and use it to randomly sample the context for each example.
- # This will be more consistent with the evaluation setting in the paper but will require some reworking of the caching.
- )
-
- device = device if device >= 0 else "cpu"
-
- model.eval()
- # Predictions based on the class target sequence with the maximal
- # predicted probability
- predictions_max_prob = []
- # Predictions based on the class target sequence with the minimal loss on
- # the model logits
- predictions_min_loss = []
- labels = []
-
- context_images = [
- get_context_images(
- image_processor=image_processor,
- in_context_samples=in_context_samples,
- num_shots=num_shots,
- )
- for _ in range(batch_size)
- ]
-
- context_text = get_context_text(
- get_imagenet_prompt,
- in_context_samples=in_context_samples,
- effective_num_shots=effective_num_shots,
- num_shots=num_shots,
- )
-
- # kwargs to use when calling tokenizer
- tokenizer_kwargs = {
- "return_tensors": "pt",
- "padding": True,
- "truncation": True,
- "max_length": 256,
- }
-
- for i, batch in enumerate(more_itertools.chunked(eval_dataset, batch_size)):
- print(f"processing batch {i} of {ceil(len(eval_dataset) / batch_size)}")
- batch_per_class_probs = []
- batch_per_class_losses = []
- batch_images = prepare_batch_images(
- batch=batch,
- image_processor=image_processor,
- context_images=context_images,
- num_shots=num_shots,
- )
-
- # Process the images only once.
- batch_images = batch_images.to(device)
- model._encode_vision_x(vision_x=batch_images)
-
- # Process the context text only once.
- context_encodings = tokenizer([context_text] * batch_size, **tokenizer_kwargs)
- context_ids = context_encodings["input_ids"].to(device)
- context_len = context_ids.shape[-1]
- context_precomputed = model(
- None,
- context_ids,
- use_cached_vision_x=True,
- clear_conditioned_layers=False,
- use_cache=True,
- )
-
- # For each ImageNet class, construct the output prompt, compute a
- # forward pass, and store the results.
- for imagenet_class_name in tqdm(openai_imagenet_classnames):
- batch_text = [context_text + _imagenet_prompt(imagenet_class_name, False) + eoc_token] * batch_size
-
- full_batch_encodings = tokenizer(batch_text, **tokenizer_kwargs)
-
- # full_batch_input_ids has shape [batch_size, seq_len], but we
- # only need to run inference on the [batch_size,
- # context_len:] inputs that have not been precomputed and
- # vary per class.
- full_batch_input_ids = full_batch_encodings["input_ids"].to(device)
- full_batch_attention_mask = full_batch_encodings["attention_mask"].to(device)
-
- # Sanity check that the encoded inputs with context are the same
- # as the encoded context alone, for every example in the batch
- assert torch.all(context_ids[0, :] == full_batch_input_ids[:, :context_len]).item()
-
- # Clone the nested structure of the past key values
- past_key_values = tuple([tuple([x.clone() for x in inner]) for inner in context_precomputed.past_key_values])
-
- # Compute the outputs without recomputing context representations.
- outputs = model(
- vision_x=None,
- lang_x=full_batch_input_ids[:, context_len:],
- attention_mask=full_batch_attention_mask,
- use_cached_vision_x=True,
- clear_conditioned_layers=False,
- past_key_values=past_key_values,
- use_cache=True,
- )
-
- logits = torch.concat((context_precomputed.logits, outputs.logits), 1)
-
- per_sample_probs = compute_per_sample_probs(
- encodings=full_batch_encodings,
- tokenizer=tokenizer,
- logits=logits,
- eoc_token_id=eoc_token_id,
- )
- per_sample_loss = compute_per_sample_loss(
- encodings=full_batch_encodings,
- tokenizer=tokenizer,
- logits=logits,
- eoc_token_id=eoc_token_id,
- )
- batch_per_class_probs.append(per_sample_probs.detach())
- batch_per_class_losses.append(per_sample_loss.detach())
-
- # Tensor of shape [batch_size, 1000] where the [i,j]th element is
- # the (probability or loss) for batch element i on imagenet class j.
- batch_probs = torch.stack(batch_per_class_probs, 1)
- batch_losses = torch.stack(batch_per_class_losses, 1)
-
- predictions_max_prob.extend(torch.argmax(batch_probs, 1).detach().tolist())
- predictions_min_loss.extend(torch.argmin(batch_losses, 1).detach().tolist())
- labels.extend(x["class_id"] for x in batch)
-
- acc_max_prob = (np.array(predictions_max_prob) == np.array(labels)).mean()
- acc_min_loss = (np.array(predictions_min_loss) == np.array(labels)).mean()
- print(f"[DEBUG] ImageNet accuracy with max prob method is {acc_max_prob}")
- print(f"[DEBUG] ImageNet accuracy with min loss method is {acc_min_loss}")
- print(f"[DEBUG] printing ImageNet predictions and labels:")
- for yhat_prob, yhat_loss, y in zip(predictions_max_prob, predictions_min_loss, labels):
- print(
- " " * 30 + f"label: {IMAGENET_1K_CLASS_ID_TO_LABEL[y]}"
- f"\nprediction (max prob method): "
- f"{IMAGENET_1K_CLASS_ID_TO_LABEL[yhat_prob]}"
- f"\nprediction (min loss method): "
- f"{IMAGENET_1K_CLASS_ID_TO_LABEL[yhat_loss]}\n"
- "#" * 25
- )
- return acc_max_prob
-
-
-if __name__ == "__main__":
- main()
diff --git a/pipeline/eval/imagenet_utils.py b/pipeline/eval/imagenet_utils.py
deleted file mode 100644
index 41ec90fd..00000000
--- a/pipeline/eval/imagenet_utils.py
+++ /dev/null
@@ -1,1005 +0,0 @@
-# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
-openai_imagenet_classnames = [
- "tench",
- "goldfish",
- "great white shark",
- "tiger shark",
- "hammerhead shark",
- "electric ray",
- "stingray",
- "rooster",
- "hen",
- "ostrich",
- "brambling",
- "goldfinch",
- "house finch",
- "junco",
- "indigo bunting",
- "American robin",
- "bulbul",
- "jay",
- "magpie",
- "chickadee",
- "American dipper",
- "kite (bird of prey)",
- "bald eagle",
- "vulture",
- "great grey owl",
- "fire salamander",
- "smooth newt",
- "newt",
- "spotted salamander",
- "axolotl",
- "American bullfrog",
- "tree frog",
- "tailed frog",
- "loggerhead sea turtle",
- "leatherback sea turtle",
- "mud turtle",
- "terrapin",
- "box turtle",
- "banded gecko",
- "green iguana",
- "Carolina anole",
- "desert grassland whiptail lizard",
- "agama",
- "frilled-necked lizard",
- "alligator lizard",
- "Gila monster",
- "European green lizard",
- "chameleon",
- "Komodo dragon",
- "Nile crocodile",
- "American alligator",
- "triceratops",
- "worm snake",
- "ring-necked snake",
- "eastern hog-nosed snake",
- "smooth green snake",
- "kingsnake",
- "garter snake",
- "water snake",
- "vine snake",
- "night snake",
- "boa constrictor",
- "African rock python",
- "Indian cobra",
- "green mamba",
- "sea snake",
- "Saharan horned viper",
- "eastern diamondback rattlesnake",
- "sidewinder rattlesnake",
- "trilobite",
- "harvestman",
- "scorpion",
- "yellow garden spider",
- "barn spider",
- "European garden spider",
- "southern black widow",
- "tarantula",
- "wolf spider",
- "tick",
- "centipede",
- "black grouse",
- "ptarmigan",
- "ruffed grouse",
- "prairie grouse",
- "peafowl",
- "quail",
- "partridge",
- "african grey parrot",
- "macaw",
- "sulphur-crested cockatoo",
- "lorikeet",
- "coucal",
- "bee eater",
- "hornbill",
- "hummingbird",
- "jacamar",
- "toucan",
- "duck",
- "red-breasted merganser",
- "goose",
- "black swan",
- "tusker",
- "echidna",
- "platypus",
- "wallaby",
- "koala",
- "wombat",
- "jellyfish",
- "sea anemone",
- "brain coral",
- "flatworm",
- "nematode",
- "conch",
- "snail",
- "slug",
- "sea slug",
- "chiton",
- "chambered nautilus",
- "Dungeness crab",
- "rock crab",
- "fiddler crab",
- "red king crab",
- "American lobster",
- "spiny lobster",
- "crayfish",
- "hermit crab",
- "isopod",
- "white stork",
- "black stork",
- "spoonbill",
- "flamingo",
- "little blue heron",
- "great egret",
- "bittern bird",
- "crane bird",
- "limpkin",
- "common gallinule",
- "American coot",
- "bustard",
- "ruddy turnstone",
- "dunlin",
- "common redshank",
- "dowitcher",
- "oystercatcher",
- "pelican",
- "king penguin",
- "albatross",
- "grey whale",
- "killer whale",
- "dugong",
- "sea lion",
- "Chihuahua",
- "Japanese Chin",
- "Maltese",
- "Pekingese",
- "Shih Tzu",
- "King Charles Spaniel",
- "Papillon",
- "toy terrier",
- "Rhodesian Ridgeback",
- "Afghan Hound",
- "Basset Hound",
- "Beagle",
- "Bloodhound",
- "Bluetick Coonhound",
- "Black and Tan Coonhound",
- "Treeing Walker Coonhound",
- "English foxhound",
- "Redbone Coonhound",
- "borzoi",
- "Irish Wolfhound",
- "Italian Greyhound",
- "Whippet",
- "Ibizan Hound",
- "Norwegian Elkhound",
- "Otterhound",
- "Saluki",
- "Scottish Deerhound",
- "Weimaraner",
- "Staffordshire Bull Terrier",
- "American Staffordshire Terrier",
- "Bedlington Terrier",
- "Border Terrier",
- "Kerry Blue Terrier",
- "Irish Terrier",
- "Norfolk Terrier",
- "Norwich Terrier",
- "Yorkshire Terrier",
- "Wire Fox Terrier",
- "Lakeland Terrier",
- "Sealyham Terrier",
- "Airedale Terrier",
- "Cairn Terrier",
- "Australian Terrier",
- "Dandie Dinmont Terrier",
- "Boston Terrier",
- "Miniature Schnauzer",
- "Giant Schnauzer",
- "Standard Schnauzer",
- "Scottish Terrier",
- "Tibetan Terrier",
- "Australian Silky Terrier",
- "Soft-coated Wheaten Terrier",
- "West Highland White Terrier",
- "Lhasa Apso",
- "Flat-Coated Retriever",
- "Curly-coated Retriever",
- "Golden Retriever",
- "Labrador Retriever",
- "Chesapeake Bay Retriever",
- "German Shorthaired Pointer",
- "Vizsla",
- "English Setter",
- "Irish Setter",
- "Gordon Setter",
- "Brittany dog",
- "Clumber Spaniel",
- "English Springer Spaniel",
- "Welsh Springer Spaniel",
- "Cocker Spaniel",
- "Sussex Spaniel",
- "Irish Water Spaniel",
- "Kuvasz",
- "Schipperke",
- "Groenendael dog",
- "Malinois",
- "Briard",
- "Australian Kelpie",
- "Komondor",
- "Old English Sheepdog",
- "Shetland Sheepdog",
- "collie",
- "Border Collie",
- "Bouvier des Flandres dog",
- "Rottweiler",
- "German Shepherd Dog",
- "Dobermann",
- "Miniature Pinscher",
- "Greater Swiss Mountain Dog",
- "Bernese Mountain Dog",
- "Appenzeller Sennenhund",
- "Entlebucher Sennenhund",
- "Boxer",
- "Bullmastiff",
- "Tibetan Mastiff",
- "French Bulldog",
- "Great Dane",
- "St. Bernard",
- "husky",
- "Alaskan Malamute",
- "Siberian Husky",
- "Dalmatian",
- "Affenpinscher",
- "Basenji",
- "pug",
- "Leonberger",
- "Newfoundland dog",
- "Great Pyrenees dog",
- "Samoyed",
- "Pomeranian",
- "Chow Chow",
- "Keeshond",
- "brussels griffon",
- "Pembroke Welsh Corgi",
- "Cardigan Welsh Corgi",
- "Toy Poodle",
- "Miniature Poodle",
- "Standard Poodle",
- "Mexican hairless dog (xoloitzcuintli)",
- "grey wolf",
- "Alaskan tundra wolf",
- "red wolf or maned wolf",
- "coyote",
- "dingo",
- "dhole",
- "African wild dog",
- "hyena",
- "red fox",
- "kit fox",
- "Arctic fox",
- "grey fox",
- "tabby cat",
- "tiger cat",
- "Persian cat",
- "Siamese cat",
- "Egyptian Mau",
- "cougar",
- "lynx",
- "leopard",
- "snow leopard",
- "jaguar",
- "lion",
- "tiger",
- "cheetah",
- "brown bear",
- "American black bear",
- "polar bear",
- "sloth bear",
- "mongoose",
- "meerkat",
- "tiger beetle",
- "ladybug",
- "ground beetle",
- "longhorn beetle",
- "leaf beetle",
- "dung beetle",
- "rhinoceros beetle",
- "weevil",
- "fly",
- "bee",
- "ant",
- "grasshopper",
- "cricket insect",
- "stick insect",
- "cockroach",
- "praying mantis",
- "cicada",
- "leafhopper",
- "lacewing",
- "dragonfly",
- "damselfly",
- "red admiral butterfly",
- "ringlet butterfly",
- "monarch butterfly",
- "small white butterfly",
- "sulphur butterfly",
- "gossamer-winged butterfly",
- "starfish",
- "sea urchin",
- "sea cucumber",
- "cottontail rabbit",
- "hare",
- "Angora rabbit",
- "hamster",
- "porcupine",
- "fox squirrel",
- "marmot",
- "beaver",
- "guinea pig",
- "common sorrel horse",
- "zebra",
- "pig",
- "wild boar",
- "warthog",
- "hippopotamus",
- "ox",
- "water buffalo",
- "bison",
- "ram (adult male sheep)",
- "bighorn sheep",
- "Alpine ibex",
- "hartebeest",
- "impala (antelope)",
- "gazelle",
- "arabian camel",
- "llama",
- "weasel",
- "mink",
- "European polecat",
- "black-footed ferret",
- "otter",
- "skunk",
- "badger",
- "armadillo",
- "three-toed sloth",
- "orangutan",
- "gorilla",
- "chimpanzee",
- "gibbon",
- "siamang",
- "guenon",
- "patas monkey",
- "baboon",
- "macaque",
- "langur",
- "black-and-white colobus",
- "proboscis monkey",
- "marmoset",
- "white-headed capuchin",
- "howler monkey",
- "titi monkey",
- "Geoffroy's spider monkey",
- "common squirrel monkey",
- "ring-tailed lemur",
- "indri",
- "Asian elephant",
- "African bush elephant",
- "red panda",
- "giant panda",
- "snoek fish",
- "eel",
- "silver salmon",
- "rock beauty fish",
- "clownfish",
- "sturgeon",
- "gar fish",
- "lionfish",
- "pufferfish",
- "abacus",
- "abaya",
- "academic gown",
- "accordion",
- "acoustic guitar",
- "aircraft carrier",
- "airliner",
- "airship",
- "altar",
- "ambulance",
- "amphibious vehicle",
- "analog clock",
- "apiary",
- "apron",
- "trash can",
- "assault rifle",
- "backpack",
- "bakery",
- "balance beam",
- "balloon",
- "ballpoint pen",
- "Band-Aid",
- "banjo",
- "baluster / handrail",
- "barbell",
- "barber chair",
- "barbershop",
- "barn",
- "barometer",
- "barrel",
- "wheelbarrow",
- "baseball",
- "basketball",
- "bassinet",
- "bassoon",
- "swimming cap",
- "bath towel",
- "bathtub",
- "station wagon",
- "lighthouse",
- "beaker",
- "military hat (bearskin or shako)",
- "beer bottle",
- "beer glass",
- "bell tower",
- "baby bib",
- "tandem bicycle",
- "bikini",
- "ring binder",
- "binoculars",
- "birdhouse",
- "boathouse",
- "bobsleigh",
- "bolo tie",
- "poke bonnet",
- "bookcase",
- "bookstore",
- "bottle cap",
- "hunting bow",
- "bow tie",
- "brass memorial plaque",
- "bra",
- "breakwater",
- "breastplate",
- "broom",
- "bucket",
- "buckle",
- "bulletproof vest",
- "high-speed train",
- "butcher shop",
- "taxicab",
- "cauldron",
- "candle",
- "cannon",
- "canoe",
- "can opener",
- "cardigan",
- "car mirror",
- "carousel",
- "tool kit",
- "cardboard box / carton",
- "car wheel",
- "automated teller machine",
- "cassette",
- "cassette player",
- "castle",
- "catamaran",
- "CD player",
- "cello",
- "mobile phone",
- "chain",
- "chain-link fence",
- "chain mail",
- "chainsaw",
- "storage chest",
- "chiffonier",
- "bell or wind chime",
- "china cabinet",
- "Christmas stocking",
- "church",
- "movie theater",
- "cleaver",
- "cliff dwelling",
- "cloak",
- "clogs",
- "cocktail shaker",
- "coffee mug",
- "coffeemaker",
- "spiral or coil",
- "combination lock",
- "computer keyboard",
- "candy store",
- "container ship",
- "convertible",
- "corkscrew",
- "cornet",
- "cowboy boot",
- "cowboy hat",
- "cradle",
- "construction crane",
- "crash helmet",
- "crate",
- "infant bed",
- "Crock Pot",
- "croquet ball",
- "crutch",
- "cuirass",
- "dam",
- "desk",
- "desktop computer",
- "rotary dial telephone",
- "diaper",
- "digital clock",
- "digital watch",
- "dining table",
- "dishcloth",
- "dishwasher",
- "disc brake",
- "dock",
- "dog sled",
- "dome",
- "doormat",
- "drilling rig",
- "drum",
- "drumstick",
- "dumbbell",
- "Dutch oven",
- "electric fan",
- "electric guitar",
- "electric locomotive",
- "entertainment center",
- "envelope",
- "espresso machine",
- "face powder",
- "feather boa",
- "filing cabinet",
- "fireboat",
- "fire truck",
- "fire screen",
- "flagpole",
- "flute",
- "folding chair",
- "football helmet",
- "forklift",
- "fountain",
- "fountain pen",
- "four-poster bed",
- "freight car",
- "French horn",
- "frying pan",
- "fur coat",
- "garbage truck",
- "gas mask or respirator",
- "gas pump",
- "goblet",
- "go-kart",
- "golf ball",
- "golf cart",
- "gondola",
- "gong",
- "gown",
- "grand piano",
- "greenhouse",
- "radiator grille",
- "grocery store",
- "guillotine",
- "hair clip",
- "hair spray",
- "half-track",
- "hammer",
- "hamper",
- "hair dryer",
- "hand-held computer",
- "handkerchief",
- "hard disk drive",
- "harmonica",
- "harp",
- "combine harvester",
- "hatchet",
- "holster",
- "home theater",
- "honeycomb",
- "hook",
- "hoop skirt",
- "gymnastic horizontal bar",
- "horse-drawn vehicle",
- "hourglass",
- "iPod",
- "clothes iron",
- "carved pumpkin",
- "jeans",
- "jeep",
- "T-shirt",
- "jigsaw puzzle",
- "rickshaw",
- "joystick",
- "kimono",
- "knee pad",
- "knot",
- "lab coat",
- "ladle",
- "lampshade",
- "laptop computer",
- "lawn mower",
- "lens cap",
- "letter opener",
- "library",
- "lifeboat",
- "lighter",
- "limousine",
- "ocean liner",
- "lipstick",
- "slip-on shoe",
- "lotion",
- "music speaker",
- "loupe magnifying glass",
- "sawmill",
- "magnetic compass",
- "messenger bag",
- "mailbox",
- "tights",
- "one-piece bathing suit",
- "manhole cover",
- "maraca",
- "marimba",
- "mask",
- "matchstick",
- "maypole",
- "maze",
- "measuring cup",
- "medicine cabinet",
- "megalith",
- "microphone",
- "microwave oven",
- "military uniform",
- "milk can",
- "minibus",
- "miniskirt",
- "minivan",
- "missile",
- "mitten",
- "mixing bowl",
- "mobile home",
- "ford model t",
- "modem",
- "monastery",
- "monitor",
- "moped",
- "mortar and pestle",
- "graduation cap",
- "mosque",
- "mosquito net",
- "vespa",
- "mountain bike",
- "tent",
- "computer mouse",
- "mousetrap",
- "moving van",
- "muzzle",
- "metal nail",
- "neck brace",
- "necklace",
- "baby pacifier",
- "notebook computer",
- "obelisk",
- "oboe",
- "ocarina",
- "odometer",
- "oil filter",
- "pipe organ",
- "oscilloscope",
- "overskirt",
- "bullock cart",
- "oxygen mask",
- "product packet / packaging",
- "paddle",
- "paddle wheel",
- "padlock",
- "paintbrush",
- "pajamas",
- "palace",
- "pan flute",
- "paper towel",
- "parachute",
- "parallel bars",
- "park bench",
- "parking meter",
- "railroad car",
- "patio",
- "payphone",
- "pedestal",
- "pencil case",
- "pencil sharpener",
- "perfume",
- "Petri dish",
- "photocopier",
- "plectrum",
- "Pickelhaube",
- "picket fence",
- "pickup truck",
- "pier",
- "piggy bank",
- "pill bottle",
- "pillow",
- "ping-pong ball",
- "pinwheel",
- "pirate ship",
- "drink pitcher",
- "block plane",
- "planetarium",
- "plastic bag",
- "plate rack",
- "farm plow",
- "plunger",
- "Polaroid camera",
- "pole",
- "police van",
- "poncho",
- "pool table",
- "soda bottle",
- "plant pot",
- "potter's wheel",
- "power drill",
- "prayer rug",
- "printer",
- "prison",
- "missile",
- "projector",
- "hockey puck",
- "punching bag",
- "purse",
- "quill",
- "quilt",
- "race car",
- "racket",
- "radiator",
- "radio",
- "radio telescope",
- "rain barrel",
- "recreational vehicle",
- "fishing casting reel",
- "reflex camera",
- "refrigerator",
- "remote control",
- "restaurant",
- "revolver",
- "rifle",
- "rocking chair",
- "rotisserie",
- "eraser",
- "rugby ball",
- "ruler measuring stick",
- "sneaker",
- "safe",
- "safety pin",
- "salt shaker",
- "sandal",
- "sarong",
- "saxophone",
- "scabbard",
- "weighing scale",
- "school bus",
- "schooner",
- "scoreboard",
- "CRT monitor",
- "screw",
- "screwdriver",
- "seat belt",
- "sewing machine",
- "shield",
- "shoe store",
- "shoji screen / room divider",
- "shopping basket",
- "shopping cart",
- "shovel",
- "shower cap",
- "shower curtain",
- "ski",
- "balaclava ski mask",
- "sleeping bag",
- "slide rule",
- "sliding door",
- "slot machine",
- "snorkel",
- "snowmobile",
- "snowplow",
- "soap dispenser",
- "soccer ball",
- "sock",
- "solar thermal collector",
- "sombrero",
- "soup bowl",
- "keyboard space bar",
- "space heater",
- "space shuttle",
- "spatula",
- "motorboat",
- "spider web",
- "spindle",
- "sports car",
- "spotlight",
- "stage",
- "steam locomotive",
- "through arch bridge",
- "steel drum",
- "stethoscope",
- "scarf",
- "stone wall",
- "stopwatch",
- "stove",
- "strainer",
- "tram",
- "stretcher",
- "couch",
- "stupa",
- "submarine",
- "suit",
- "sundial",
- "sunglasses",
- "sunglasses",
- "sunscreen",
- "suspension bridge",
- "mop",
- "sweatshirt",
- "swim trunks / shorts",
- "swing",
- "electrical switch",
- "syringe",
- "table lamp",
- "tank",
- "tape player",
- "teapot",
- "teddy bear",
- "television",
- "tennis ball",
- "thatched roof",
- "front curtain",
- "thimble",
- "threshing machine",
- "throne",
- "tile roof",
- "toaster",
- "tobacco shop",
- "toilet seat",
- "torch",
- "totem pole",
- "tow truck",
- "toy store",
- "tractor",
- "semi-trailer truck",
- "tray",
- "trench coat",
- "tricycle",
- "trimaran",
- "tripod",
- "triumphal arch",
- "trolleybus",
- "trombone",
- "hot tub",
- "turnstile",
- "typewriter keyboard",
- "umbrella",
- "unicycle",
- "upright piano",
- "vacuum cleaner",
- "vase",
- "vaulted or arched ceiling",
- "velvet fabric",
- "vending machine",
- "vestment",
- "viaduct",
- "violin",
- "volleyball",
- "waffle iron",
- "wall clock",
- "wallet",
- "wardrobe",
- "military aircraft",
- "sink",
- "washing machine",
- "water bottle",
- "water jug",
- "water tower",
- "whiskey jug",
- "whistle",
- "hair wig",
- "window screen",
- "window shade",
- "Windsor tie",
- "wine bottle",
- "airplane wing",
- "wok",
- "wooden spoon",
- "wool",
- "split-rail fence",
- "shipwreck",
- "sailboat",
- "yurt",
- "website",
- "comic book",
- "crossword",
- "traffic or street sign",
- "traffic light",
- "dust jacket",
- "menu",
- "plate",
- "guacamole",
- "consomme",
- "hot pot",
- "trifle",
- "ice cream",
- "popsicle",
- "baguette",
- "bagel",
- "pretzel",
- "cheeseburger",
- "hot dog",
- "mashed potatoes",
- "cabbage",
- "broccoli",
- "cauliflower",
- "zucchini",
- "spaghetti squash",
- "acorn squash",
- "butternut squash",
- "cucumber",
- "artichoke",
- "bell pepper",
- "cardoon",
- "mushroom",
- "Granny Smith apple",
- "strawberry",
- "orange",
- "lemon",
- "fig",
- "pineapple",
- "banana",
- "jackfruit",
- "cherimoya (custard apple)",
- "pomegranate",
- "hay",
- "carbonara",
- "chocolate syrup",
- "dough",
- "meatloaf",
- "pizza",
- "pot pie",
- "burrito",
- "red wine",
- "espresso",
- "tea cup",
- "eggnog",
- "mountain",
- "bubble",
- "cliff",
- "coral reef",
- "geyser",
- "lakeshore",
- "promontory",
- "sandbar",
- "beach",
- "valley",
- "volcano",
- "baseball player",
- "bridegroom",
- "scuba diver",
- "rapeseed",
- "daisy",
- "yellow lady's slipper",
- "corn",
- "acorn",
- "rose hip",
- "horse chestnut seed",
- "coral fungus",
- "agaric",
- "gyromitra",
- "stinkhorn mushroom",
- "earth star fungus",
- "hen of the woods mushroom",
- "bolete",
- "corn cob",
- "toilet paper",
-]
-# Maps numeric class ids to labels
-IMAGENET_1K_CLASS_ID_TO_LABEL = dict(zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames))
diff --git a/pipeline/eval/ok_vqa_utils.py b/pipeline/eval/ok_vqa_utils.py
deleted file mode 100644
index 51497991..00000000
--- a/pipeline/eval/ok_vqa_utils.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# Those are manual mapping that are not caught by our stemming rules or would
-# would be done incorrectly by our automatic stemming rule. In details,
-# the keys of the _MANUAL_MATCHES dict contains the original word and the value
-# contains the transformation of the word expected by the OKVQA stemming rule.
-# These manual rules were found by checking the `raw_answers` and the `answers`
-# fields of the released OKVQA dataset and checking all things that were not
-# properly mapped by our automatic rules. In particular some of the mapping
-# are sometimes constant, e.g. christmas -> christmas which was incorrectly
-# singularized by our inflection.singularize.
-import re
-import nltk
-from nltk.corpus.reader import VERB
-import inflection
-
-_MANUAL_MATCHES = {
- "police": "police",
- "las": "las",
- "vegas": "vegas",
- "yes": "yes",
- "jeans": "jean",
- "hell's": "hell",
- "domino's": "domino",
- "morning": "morn",
- "clothes": "cloth",
- "are": "are",
- "riding": "ride",
- "leaves": "leaf",
- "dangerous": "danger",
- "clothing": "cloth",
- "texting": "text",
- "kiting": "kite",
- "firefighters": "firefight",
- "ties": "tie",
- "married": "married",
- "teething": "teeth",
- "gloves": "glove",
- "tennis": "tennis",
- "dining": "dine",
- "directions": "direct",
- "waves": "wave",
- "christmas": "christmas",
- "drives": "drive",
- "pudding": "pud",
- "coding": "code",
- "plating": "plate",
- "quantas": "quanta",
- "hornes": "horn",
- "graves": "grave",
- "mating": "mate",
- "paned": "pane",
- "alertness": "alert",
- "sunbathing": "sunbath",
- "tenning": "ten",
- "wetness": "wet",
- "urinating": "urine",
- "sickness": "sick",
- "braves": "brave",
- "firefighting": "firefight",
- "lenses": "lens",
- "reflections": "reflect",
- "backpackers": "backpack",
- "eatting": "eat",
- "designers": "design",
- "curiousity": "curious",
- "playfulness": "play",
- "blindness": "blind",
- "hawke": "hawk",
- "tomatoe": "tomato",
- "rodeoing": "rodeo",
- "brightness": "bright",
- "circuses": "circus",
- "skateboarders": "skateboard",
- "staring": "stare",
- "electronics": "electron",
- "electicity": "elect",
- "mountainous": "mountain",
- "socializing": "social",
- "hamburgers": "hamburg",
- "caves": "cave",
- "transitions": "transit",
- "wading": "wade",
- "creame": "cream",
- "toileting": "toilet",
- "sautee": "saute",
- "buildings": "build",
- "belongings": "belong",
- "stockings": "stock",
- "walle": "wall",
- "cumulis": "cumuli",
- "travelers": "travel",
- "conducter": "conduct",
- "browsing": "brows",
- "pooping": "poop",
- "haircutting": "haircut",
- "toppings": "top",
- "hearding": "heard",
- "sunblocker": "sunblock",
- "bases": "base",
- "markings": "mark",
- "mopeds": "mope",
- "kindergartener": "kindergarten",
- "pies": "pie",
- "scrapbooking": "scrapbook",
- "couponing": "coupon",
- "meetings": "meet",
- "elevators": "elev",
- "lowes": "low",
- "men's": "men",
- "childrens": "children",
- "shelves": "shelve",
- "paintings": "paint",
- "raines": "rain",
- "paring": "pare",
- "expressions": "express",
- "routes": "rout",
- "pease": "peas",
- "vastness": "vast",
- "awning": "awn",
- "boy's": "boy",
- "drunkenness": "drunken",
- "teasing": "teas",
- "conferences": "confer",
- "ripeness": "ripe",
- "suspenders": "suspend",
- "earnings": "earn",
- "reporters": "report",
- "kid's": "kid",
- "containers": "contain",
- "corgie": "corgi",
- "porche": "porch",
- "microwaves": "microwave",
- "batter's": "batter",
- "sadness": "sad",
- "apartments": "apart",
- "oxygenize": "oxygen",
- "striping": "stripe",
- "purring": "pure",
- "professionals": "profession",
- "piping": "pipe",
- "farmer's": "farmer",
- "potatoe": "potato",
- "emirates": "emir",
- "womens": "women",
- "veteran's": "veteran",
- "wilderness": "wilder",
- "propellers": "propel",
- "alpes": "alp",
- "charioteering": "chariot",
- "swining": "swine",
- "illness": "ill",
- "crepte": "crept",
- "adhesives": "adhesive",
- "regent's": "regent",
- "decorations": "decor",
- "rabbies": "rabbi",
- "overseas": "oversea",
- "travellers": "travel",
- "casings": "case",
- "smugness": "smug",
- "doves": "dove",
- "nationals": "nation",
- "mustange": "mustang",
- "ringe": "ring",
- "gondoliere": "gondolier",
- "vacationing": "vacate",
- "reminders": "remind",
- "baldness": "bald",
- "settings": "set",
- "glaced": "glace",
- "coniferous": "conifer",
- "revelations": "revel",
- "personals": "person",
- "daughter's": "daughter",
- "badness": "bad",
- "projections": "project",
- "polarizing": "polar",
- "vandalizers": "vandal",
- "minerals": "miner",
- "protesters": "protest",
- "controllers": "control",
- "weddings": "wed",
- "sometimes": "sometime",
- "earing": "ear",
-}
-
-
-class OKVQAStemmer:
- """Stemmer to match OKVQA v1.1 procedure."""
-
- def __init__(self):
- self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()
-
- def stem(self, input_string):
- """Apply stemming."""
- word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
- stemmed_words = []
- for w, p in word_and_pos:
- if w in _MANUAL_MATCHES:
- w = _MANUAL_MATCHES[w]
- elif w.endswith("ing"):
- w = self._wordnet_lemmatizer.lemmatize(w, VERB)
- elif p.startswith("NNS") or p.startswith("NNPS"):
- w = inflection.singularize(w)
- stemmed_words.append(w)
- return " ".join(stemmed_words)
-
-
-stemmer = OKVQAStemmer()
-
-
-def postprocess_ok_vqa_generation(predictions) -> str:
- prediction = re.split("Question|Answer", predictions, 1)[0]
- prediction_stem = stemmer.stem(prediction)
- return prediction_stem
diff --git a/pipeline/eval/vqa_metric.py b/pipeline/eval/vqa_metric.py
deleted file mode 100644
index fb1e364d..00000000
--- a/pipeline/eval/vqa_metric.py
+++ /dev/null
@@ -1,543 +0,0 @@
-import copy
-import datetime
-import json
-import re
-import sys
-
-# Interface for accessing the VQA dataset.
-
-# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
-# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
-
-# The following functions are defined:
-# VQA - VQA class that loads VQA annotation file and prepares data structures.
-# getQuesIds - Get question ids that satisfy given filter conditions.
-# getImgIds - Get image ids that satisfy given filter conditions.
-# loadQA - Load questions and answers with the specified question ids.
-# showQA - Display the specified questions and answers.
-# loadRes - Load result file and create result object.
-
-# Help on each function can be accessed by: "help(COCO.function)"
-
-
-class VQA:
- def __init__(self, annotation_file=None, question_file=None):
- """
- Constructor of VQA helper class for reading and visualizing questions and answers.
- :param annotation_file (str): location of VQA annotation file
- :return:
- """
- # load dataset
- self.dataset = {}
- self.questions = {}
- self.qa = {}
- self.qqa = {}
- self.imgToQA = {}
- if not annotation_file == None and not question_file == None:
- print("loading VQA annotations and questions into memory...")
- time_t = datetime.datetime.utcnow()
- dataset = json.load(open(annotation_file, "r"))
- questions = json.load(open(question_file, "r"))
- print(datetime.datetime.utcnow() - time_t)
- self.dataset = dataset
- self.questions = questions
- self.createIndex()
-
- def createIndex(self):
- # create index
- print("creating index...")
- imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
- qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
- qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
- for ann in self.dataset["annotations"]:
- imgToQA[ann["image_id"]] += [ann]
- qa[ann["question_id"]] = ann
- for ques in self.questions["questions"]:
- qqa[ques["question_id"]] = ques
- print("index created!")
-
- # create class members
- self.qa = qa
- self.qqa = qqa
- self.imgToQA = imgToQA
-
- def info(self):
- """
- Print information about the VQA annotation file.
- :return:
- """
- for key, value in self.dataset["info"].items():
- print("%s: %s" % (key, value))
-
- def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
- """
- Get question ids that satisfy given filter conditions. default skips that filter
- :param imgIds (int array) : get question ids for given imgs
- quesTypes (str array) : get question ids for given question types
- ansTypes (str array) : get question ids for given answer types
- :return: ids (int array) : integer array of question ids
- """
- imgIds = imgIds if type(imgIds) == list else [imgIds]
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
-
- if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
- anns = self.dataset["annotations"]
- else:
- if not len(imgIds) == 0:
- anns = sum(
- [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
- [],
- )
- else:
- anns = self.dataset["annotations"]
- anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann["question_type"] in quesTypes]
- anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann["answer_type"] in ansTypes]
- ids = [ann["question_id"] for ann in anns]
- return ids
-
- def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
- """
- Get image ids that satisfy given filter conditions. default skips that filter
- :param quesIds (int array) : get image ids for given question ids
- quesTypes (str array) : get image ids for given question types
- ansTypes (str array) : get image ids for given answer types
- :return: ids (int array) : integer array of image ids
- """
- quesIds = quesIds if type(quesIds) == list else [quesIds]
- quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
- ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
-
- if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
- anns = self.dataset["annotations"]
- else:
- if not len(quesIds) == 0:
- anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
- else:
- anns = self.dataset["annotations"]
- anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann["question_type"] in quesTypes]
- anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann["answer_type"] in ansTypes]
- ids = [ann["image_id"] for ann in anns]
- return ids
-
- def loadQA(self, ids=[]):
- """
- Load questions and answers with the specified question ids.
- :param ids (int array) : integer ids specifying question ids
- :return: qa (object array) : loaded qa objects
- """
- if type(ids) == list:
- return [self.qa[id] for id in ids]
- elif type(ids) == int:
- return [self.qa[ids]]
-
- def showQA(self, anns):
- """
- Display the specified annotations.
- :param anns (array of object): annotations to display
- :return: None
- """
- if len(anns) == 0:
- return 0
- for ann in anns:
- quesId = ann["question_id"]
- print("Question: %s" % (self.qqa[quesId]["question"]))
- for ans in ann["answers"]:
- print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
-
- def loadRes(self, resFile, quesFile):
- """
- Load result file and return a result object.
- :param resFile (str) : file name of result file
- :return: res (obj) : result api object
- """
- res = VQA()
- res.questions = json.load(open(quesFile))
- res.dataset["info"] = copy.deepcopy(self.questions["info"])
- res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
- res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
- res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
- res.dataset["license"] = copy.deepcopy(self.questions["license"])
-
- print("Loading and preparing results... ")
- time_t = datetime.datetime.utcnow()
- anns = json.load(open(resFile))
- assert type(anns) == list, "results is not an array of objects"
- annsQuesIds = [ann["question_id"] for ann in anns]
- # print set of question ids that do not have corresponding annotations
-
- # assert set(annsQuesIds) == set(self.getQuesIds()), \
- # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
- for ann in anns:
- quesId = ann["question_id"]
- if res.dataset["task_type"] == "Multiple Choice":
- assert ann["answer"] in self.qqa[quesId]["multiple_choices"], "predicted answer is not one of the multiple choices"
- qaAnn = self.qa[quesId]
- ann["image_id"] = qaAnn["image_id"]
- ann["question_type"] = qaAnn["question_type"]
- ann["answer_type"] = qaAnn["answer_type"]
- print("DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()))
-
- res.dataset["annotations"] = anns
- res.createIndex()
- return res
-
-
-class VQAEval:
- def __init__(self, vqa, vqaRes, n=2):
- self.n = n
- self.accuracy = {}
- self.evalQA = {}
- self.evalQuesType = {}
- self.evalAnsType = {}
- self.vqa = vqa
- self.vqaRes = vqaRes
- self.params = {"question_id": vqaRes.getQuesIds()}
- self.contractions = {
- "aint": "ain't",
- "arent": "aren't",
- "cant": "can't",
- "couldve": "could've",
- "couldnt": "couldn't",
- "couldn'tve": "couldn't've",
- "couldnt've": "couldn't've",
- "didnt": "didn't",
- "doesnt": "doesn't",
- "dont": "don't",
- "hadnt": "hadn't",
- "hadnt've": "hadn't've",
- "hadn'tve": "hadn't've",
- "hasnt": "hasn't",
- "havent": "haven't",
- "hed": "he'd",
- "hed've": "he'd've",
- "he'dve": "he'd've",
- "hes": "he's",
- "howd": "how'd",
- "howll": "how'll",
- "hows": "how's",
- "Id've": "I'd've",
- "I'dve": "I'd've",
- "Im": "I'm",
- "Ive": "I've",
- "isnt": "isn't",
- "itd": "it'd",
- "itd've": "it'd've",
- "it'dve": "it'd've",
- "itll": "it'll",
- "let's": "let's",
- "maam": "ma'am",
- "mightnt": "mightn't",
- "mightnt've": "mightn't've",
- "mightn'tve": "mightn't've",
- "mightve": "might've",
- "mustnt": "mustn't",
- "mustve": "must've",
- "neednt": "needn't",
- "notve": "not've",
- "oclock": "o'clock",
- "oughtnt": "oughtn't",
- "ow's'at": "'ow's'at",
- "'ows'at": "'ow's'at",
- "'ow'sat": "'ow's'at",
- "shant": "shan't",
- "shed've": "she'd've",
- "she'dve": "she'd've",
- "she's": "she's",
- "shouldve": "should've",
- "shouldnt": "shouldn't",
- "shouldnt've": "shouldn't've",
- "shouldn'tve": "shouldn't've",
- "somebody'd": "somebodyd",
- "somebodyd've": "somebody'd've",
- "somebody'dve": "somebody'd've",
- "somebodyll": "somebody'll",
- "somebodys": "somebody's",
- "someoned": "someone'd",
- "someoned've": "someone'd've",
- "someone'dve": "someone'd've",
- "someonell": "someone'll",
- "someones": "someone's",
- "somethingd": "something'd",
- "somethingd've": "something'd've",
- "something'dve": "something'd've",
- "somethingll": "something'll",
- "thats": "that's",
- "thered": "there'd",
- "thered've": "there'd've",
- "there'dve": "there'd've",
- "therere": "there're",
- "theres": "there's",
- "theyd": "they'd",
- "theyd've": "they'd've",
- "they'dve": "they'd've",
- "theyll": "they'll",
- "theyre": "they're",
- "theyve": "they've",
- "twas": "'twas",
- "wasnt": "wasn't",
- "wed've": "we'd've",
- "we'dve": "we'd've",
- "weve": "we've",
- "werent": "weren't",
- "whatll": "what'll",
- "whatre": "what're",
- "whats": "what's",
- "whatve": "what've",
- "whens": "when's",
- "whered": "where'd",
- "wheres": "where's",
- "whereve": "where've",
- "whod": "who'd",
- "whod've": "who'd've",
- "who'dve": "who'd've",
- "wholl": "who'll",
- "whos": "who's",
- "whove": "who've",
- "whyll": "why'll",
- "whyre": "why're",
- "whys": "why's",
- "wont": "won't",
- "wouldve": "would've",
- "wouldnt": "wouldn't",
- "wouldnt've": "wouldn't've",
- "wouldn'tve": "wouldn't've",
- "yall": "y'all",
- "yall'll": "y'all'll",
- "y'allll": "y'all'll",
- "yall'd've": "y'all'd've",
- "y'alld've": "y'all'd've",
- "y'all'dve": "y'all'd've",
- "youd": "you'd",
- "youd've": "you'd've",
- "you'dve": "you'd've",
- "youll": "you'll",
- "youre": "you're",
- "youve": "you've",
- }
- self.manualMap = {
- "none": "0",
- "zero": "0",
- "one": "1",
- "two": "2",
- "three": "3",
- "four": "4",
- "five": "5",
- "six": "6",
- "seven": "7",
- "eight": "8",
- "nine": "9",
- "ten": "10",
- }
- self.articles = ["a", "an", "the"]
-
- self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
- self.commaStrip = re.compile("(\d)(\,)(\d)")
- self.punct = [
- ";",
- r"/",
- "[",
- "]",
- '"',
- "{",
- "}",
- "(",
- ")",
- "=",
- "+",
- "\\",
- "_",
- "-",
- ">",
- "<",
- "@",
- "`",
- ",",
- "?",
- "!",
- ]
-
- def evaluate(self, quesIds=None):
- if quesIds == None:
- quesIds = [quesId for quesId in self.params["question_id"]]
- gts = {}
- res = {}
- for quesId in quesIds:
- gts[quesId] = self.vqa.qa[quesId]
- res[quesId] = self.vqaRes.qa[quesId]
-
- # =================================================
- # Compute accuracy
- # =================================================
- accQA = []
- accQuesType = {}
- accAnsType = {}
- print("computing accuracy")
- step = 0
- for quesId in quesIds:
- for ansDic in gts[quesId]["answers"]:
- ansDic["answer"] = ansDic["answer"].replace("\n", " ")
- ansDic["answer"] = ansDic["answer"].replace("\t", " ")
- ansDic["answer"] = ansDic["answer"].strip()
- resAns = res[quesId]["answer"]
- resAns = resAns.replace("\n", " ")
- resAns = resAns.replace("\t", " ")
- resAns = resAns.strip()
- gtAcc = []
- gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
-
- if len(set(gtAnswers)) > 1:
- for ansDic in gts[quesId]["answers"]:
- ansDic["answer"] = self.processPunctuation(ansDic["answer"])
- ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
- resAns = self.processPunctuation(resAns)
- resAns = self.processDigitArticle(resAns)
-
- for gtAnsDatum in gts[quesId]["answers"]:
- otherGTAns = [item for item in gts[quesId]["answers"] if item != gtAnsDatum]
- matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
- acc = min(1, float(len(matchingAns)) / 3)
- gtAcc.append(acc)
- quesType = gts[quesId]["question_type"]
- ansType = gts[quesId]["answer_type"]
- avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
- accQA.append(avgGTAcc)
- if quesType not in accQuesType:
- accQuesType[quesType] = []
- accQuesType[quesType].append(avgGTAcc)
- if ansType not in accAnsType:
- accAnsType[ansType] = []
- accAnsType[ansType].append(avgGTAcc)
- self.setEvalQA(quesId, avgGTAcc)
- self.setEvalQuesType(quesId, quesType, avgGTAcc)
- self.setEvalAnsType(quesId, ansType, avgGTAcc)
- if step % 100 == 0:
- self.updateProgress(step / float(len(quesIds)))
- step = step + 1
-
- self.setAccuracy(accQA, accQuesType, accAnsType)
- print("Done computing accuracy")
-
- def processPunctuation(self, inText):
- outText = inText
- for p in self.punct:
- if (p + " " in inText or " " + p in inText) or (re.search(self.commaStrip, inText) != None):
- outText = outText.replace(p, "")
- else:
- outText = outText.replace(p, " ")
- outText = self.periodStrip.sub("", outText, re.UNICODE)
- return outText
-
- def processDigitArticle(self, inText):
- outText = []
- tempText = inText.lower().split()
- for word in tempText:
- word = self.manualMap.setdefault(word, word)
- if word not in self.articles:
- outText.append(word)
- else:
- pass
- for wordId, word in enumerate(outText):
- if word in self.contractions:
- outText[wordId] = self.contractions[word]
- outText = " ".join(outText)
- return outText
-
- def setAccuracy(self, accQA, accQuesType, accAnsType):
- self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
- self.accuracy["perQuestionType"] = {
- quesType: round(
- 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
- self.n,
- )
- for quesType in accQuesType
- }
- self.accuracy["perAnswerType"] = {ansType: round(100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n) for ansType in accAnsType}
-
- def setEvalQA(self, quesId, acc):
- self.evalQA[quesId] = round(100 * acc, self.n)
-
- def setEvalQuesType(self, quesId, quesType, acc):
- if quesType not in self.evalQuesType:
- self.evalQuesType[quesType] = {}
- self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
-
- def setEvalAnsType(self, quesId, ansType, acc):
- if ansType not in self.evalAnsType:
- self.evalAnsType[ansType] = {}
- self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
-
- def updateProgress(self, progress):
- barLength = 20
- status = ""
- if isinstance(progress, int):
- progress = float(progress)
- if not isinstance(progress, float):
- progress = 0
- status = "error: progress var must be float\r\n"
- if progress < 0:
- progress = 0
- status = "Halt...\r\n"
- if progress >= 1:
- progress = 1
- status = "Done...\r\n"
- block = int(round(barLength * progress))
- text = "\rFinshed Percent: [{0}] {1}% {2}".format("#" * block + "-" * (barLength - block), int(progress * 100), status)
- sys.stdout.write(text)
- sys.stdout.flush()
-
-
-def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path):
- """Compute the VQA accuracy metric.
-
- Args:
- predictions (List): list of predictions
- ground_truth (List[List]): list of all possible ground truth answers
-
- Returns:
- float: VQA accuracy
- """
- # coding: utf-8
- # dataDir = data_dir
-
- # set up file names and paths
- # versionType = 'v2_' # this should be '' when using VQA v2.0 dataset
- # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
- # taskType = 'OpenEnded'
- # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
- # dataType = 'mscoco'
- # dataSubType = 'train2014'
- # annFile = '%s/%s%s_%s_annotations.json' % (
- # dataDir, versionType, dataType, dataSubType)
- # quesFile = '%s/%s%s_%s_%s_questions.json' % (
- # dataDir, versionType, taskType, dataType, dataSubType)
- # imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType)
- # resultType = res_file_name
- # fileTypes = ['results', 'accuracy',
- # 'evalQA', 'evalQuesType', 'evalAnsType']
-
- # An example result json file has been provided in './Results' folder.
-
- # [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
- # resultType, fileType) for fileType in fileTypes]
-
- # create vqa object and vqaRes object
- vqa = VQA(annotation_json_path, question_json_path)
- vqaRes = vqa.loadRes(result_json_path, question_json_path)
-
- # create vqaEval object by taking vqa and vqaRes
- # n is precision of accuracy (number of places after decimal), default is 2
- vqaEval = VQAEval(vqa, vqaRes, n=2)
-
- # evaluate results
- """
- If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
- By default it uses all the question ids in annotation file
- """
- vqaEval.evaluate()
-
- return vqaEval.accuracy["overall"]
-
-
-def postprocess_vqa_generation(predictions):
- return re.split("Question|Answer", predictions, 1)[0]