Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in the input and output.
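The decoupled data flow can be sketched in a few lines of Python. This is a minimal illustration, not the repository's code. The generator callables stand in for the shared seq2seq architecture (a Flan-Alpaca checkpoint in the commands below). Example, format_qcm, and the prompt strings are hypothetical, and vision features are omitted for brevity. Stage (i) maps question, context, and options to a rationale. Stage (ii) uses the same kind of model with that rationale appended to its input.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical container for one ScienceQA-style example (vision features omitted).
@dataclass
class Example:
    question: str
    context: str
    options: list[str]

# A "model" here is any function from input text to output text. Both stages share
# the architecture but are trained separately, so they are passed in as two callables.
Generator = Callable[[str], str]

def format_qcm(ex: Example) -> str:
    """Render question, context and lettered options (hypothetical layout)."""
    opts = " ".join(f"({chr(ord('A') + i)}) {o}" for i, o in enumerate(ex.options))
    return f"Question: {ex.question}\nContext: {ex.context}\nOptions: {opts}\n"

def two_stage_infer(ex: Example, rationale_model: Generator,
                    answer_model: Generator) -> tuple[str, str]:
    # Stage (i): rationale generation, QCM -> rationale.
    rationale = rationale_model(format_qcm(ex))
    # Stage (ii): answer inference, QCM plus the generated rationale -> answer.
    answer = answer_model(format_qcm(ex) + f"Solution: {rationale}\n")
    return rationale, answer

# Usage with stub "models" that only illustrate the hand-off between stages.
ex = Example("Which is a mammal?", "N/A", ["frog", "whale"])
print(two_stage_infer(ex,
                      lambda p: "Whales nurse their young.",
                      lambda p: "(B)" if "nurse" in p else "(A)"))
```

The stub answer model only returns (B) because the stage-1 rationale is visible in its input. That dependency is what the decoupled framework relies on.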
Install all required Python dependencies:
pip install -r requirements.txt
Download the ScienceQA dataset from the following repository into the images folder:
https://github.com/lupantech/ScienceQA/tree/main/data
The following instructions show how we obtained the vision features.
Download the image files from Google Drive and unzip all the images (train, dev, test) into a single folder (images). The structure should be:
Then extract the ViT features:
python extract_features.py --data_root images --output_dir vision_features --img_type vit
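Some intuition for what --img_type vit produces: a standard Vision Transformer splits each image into non-overlapping square patches. It emits one embedding per patch plus one class token. The helper below only does that arithmetic. The 384-pixel input with 32-pixel patches in the example is an assumption for illustration, because this README does not say which ViT variant extract_features.py loads.

```python
def vit_token_count(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Tokens emitted by a plain ViT for a square image (no pooling, no register tokens)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + (1 if cls_token else 0)

# Assumed example: 384x384 input, 32x32 patches -> 12x12 grid + class token.
print(vit_token_count(384, 32))  # 145
# The common 224x224, 16x16 setting -> 14x14 grid + class token.
print(vit_token_count(224, 16))  # 197
```

The token count is the sequence length of the per-image feature matrix. The hidden size (width of each row) depends on the model size, which we also leave unspecified.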
The processed captions for ScienceQA are available at data/instruct_captions.json.
The following instructions show how we obtained those captions.
Install LAVIS and prepare the Vicuna weights to use InstructBLIP for caption extraction.
Assume that the images are stored in the images folder.
python extract_caption.py
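When --use_caption is set, as in both training commands below, each image's caption is added to the text input. The sketch below shows one way to do that. We assume, without having checked the file, that data/instruct_captions.json maps question IDs to caption strings. The real file may nest this mapping under a top-level key. merge_caption is a hypothetical helper, and the "N/A" fallback for an empty context is likewise an assumption.

```python
import json

def merge_caption(hint: str, qid: str, captions: dict[str, str],
                  use_caption: bool = True) -> str:
    """Build the textual context for one question, optionally appending its image caption."""
    caption = captions.get(qid, "") if use_caption else ""
    # Join only the non-empty parts so we never produce stray spaces.
    context = " ".join(part for part in (hint.strip(), caption.strip()) if part)
    return context or "N/A"  # assumed placeholder when there is no context at all

# Assumed file layout: with open("data/instruct_captions.json") as f: captions = json.load(f)
captions = json.loads('{"1": "A frog sitting on a leaf.", "2": ""}')
print(merge_caption("Look at the picture.", "1", captions))  # Look at the picture. A frog sitting on a leaf.
print(merge_caption("", "2", captions))                      # N/A
```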
# rationale generation
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
--data_root data/ScienceQA/data \
--caption_file data/instruct_captions.json \
--model declare-lab/flan-alpaca-large \
--user_msg rationale --img_type vit \
--bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
--use_caption --use_generate --prompt_format QCM-E \
--output_dir experiments
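The --prompt_format value reads as INPUT-OUTPUT field letters. This command uses QCM-E, and the answer-inference command uses QCMG-A. Our interpretation of the letters is question, context, multiple-choice options, explanation, generated rationale, and answer. The README does not spell this out. The hypothetical parser below makes that reading concrete. Under it, the only input difference between the stages is the added generated rationale (G).

```python
# Our reading of the format letters; not confirmed by this README.
FIELD_NAMES = {
    "Q": "question",
    "C": "context",
    "M": "multiple-choice options",
    "E": "explanation (rationale)",
    "G": "generated rationale",
    "A": "answer",
}

def parse_prompt_format(fmt: str) -> tuple[list[str], list[str]]:
    """Split a format such as 'QCM-E' into input and output field names."""
    inputs, sep, outputs = fmt.partition("-")
    if not sep or not inputs or not outputs:
        raise ValueError(f"expected INPUT-OUTPUT, got {fmt!r}")
    return [FIELD_NAMES[c] for c in inputs], [FIELD_NAMES[c] for c in outputs]

print(parse_prompt_format("QCM-E"))   # stage (i): rationale generation
print(parse_prompt_format("QCMG-A"))  # stage (ii): answer inference
```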
# answer inference
CUDA_VISIBLE_DEVICES=0,1,2,3 python main_central.py \
--data_root data/ScienceQA/data \
--caption_file data/instruct_captions.json \
--model declare-lab/flan-alpaca-large \
--user_msg answer --img_type vit \
--bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
--use_caption --use_generate --prompt_format QCMG-A \
--output_dir experiments \
--eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
--test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json
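The --eval_le and --test_le paths point at the predictions written by the rationale-generation run. If you change a stage-1 hyperparameter, these paths change too. The helper below rebuilds the directory name from the stage-1 flags. It is a hypothetical sketch inferred only from the example path. In particular, we assume bs8 is the per-device --bs 2 multiplied by the four GPUs in CUDA_VISIBLE_DEVICES, and we have not verified this against main.py.

```python
def rationale_run_dir(output_dir: str, model: str, img_type: str, prompt_format: str,
                      lr: float, per_device_bs: int, num_gpus: int,
                      output_len: int, epochs: int, user_msg: str = "rationale") -> str:
    # Assumption: "/" in the Hugging Face model ID becomes "-" in the folder name.
    model_tag = model.replace("/", "-")
    # Assumption: the "bs" field records the total batch size across all GPUs.
    total_bs = per_device_bs * num_gpus
    # Python renders 5e-5 as "5e-05", which matches "lr5e-05" in the example path.
    name = (f"{user_msg}_{model_tag}_{img_type}_{prompt_format}"
            f"_lr{lr}_bs{total_bs}_op{output_len}_ep{epochs}")
    return f"{output_dir}/{name}"

run_dir = rationale_run_dir("experiments", "declare-lab/flan-alpaca-large", "vit", "QCM-E",
                            lr=5e-5, per_device_bs=2, num_gpus=4, output_len=512, epochs=50)
eval_le = f"{run_dir}/predictions_ans_eval.json"
test_le = f"{run_dir}/predictions_ans_test.json"
print(eval_le)
```

With the stage-1 flags above, this reproduces the paths passed to the answer-inference command.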
Parts of our code are adapted from mm-cot, ScienceQA, Transformers, and pytorch-image-models.