
Commit

Integrates BiGRU VQA model; best n questions removal is now based on a threshold, not exact match
qlamar3 committed Dec 5, 2023
1 parent aea09c5 commit 08272e0
Showing 6 changed files with 46 additions and 43 deletions.
36 changes: 19 additions & 17 deletions models/BLIP.py
@@ -11,14 +11,15 @@
BlipForQuestionAnswering,
)
from transformers import AutoProcessor, AutoModelForCausalLM
from models.VQA import generateAnswer

from loguru import logger
# from loguru import logger

log_folder = "logs"
log_file = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.log"
if not os.path.exists(log_folder):
os.makedirs(log_folder)
logger.add(os.path.join(log_folder, log_file), enqueue=True)
# logger.add(os.path.join(log_folder, log_file), enqueue=True)

import warnings

@@ -28,7 +29,7 @@
model_dict = defaultdict(dict)
processor_dict = defaultdict(dict)

logger.info("Loading models...")
# logger.info("Loading models...")
for model_type in ["base", "large"]:
model_name = f"Salesforce/blip-image-captioning-{model_type}"
processor = BlipProcessor.from_pretrained(model_name)
@@ -59,58 +60,58 @@

processor_dict[f"git_{model_type}"]["vqa"] = processor
model_dict[f"git_{model_type}"]["vqa"] = model
logger.info("Loaded models...")
# logger.info("Loaded models...")


def get_blip_caption(
image, model_type="base", text="a photography of", max_new_tokens=50
):
logger.info(f"{model_type}")
# logger.info(f"{model_type}")
processor = processor_dict[f"blip_{model_type}"]["caption"]
model = model_dict[f"blip_{model_type}"]["caption"]
inputs = processor(image, text, return_tensors="pt")
caption_ids = model.generate(max_new_tokens=max_new_tokens, **inputs)
response = processor.decode(caption_ids[0], skip_special_tokens=True)
if (response is not None) and (len(response) > 0):
logger.info(f"response: {response}")
# logger.info(f"response: {response}")
return response
else:
logger.warning(f"response: {response}")
# logger.warning(f"response: {response}")
return "No response generated from the model."


def get_blip_vqa(image, question, model_type="base"):
logger.info(f"{model_type}")
# logger.info(f"{model_type}")
processor = processor_dict[f"blip_{model_type}"]["vqa"]
model = model_dict[f"blip_{model_type}"]["vqa"]
inputs = processor(image, question, return_tensors="pt")
out = model.generate(**inputs)
response = processor.decode(out[0], skip_special_tokens=True)
if (response is not None) and (len(response) > 0):
logger.info(f"response: {response}")
# logger.info(f"response: {response}")
return response
else:
logger.warning(f"response: {response}")
# logger.warning(f"response: {response}")
return "No response generated from the model."


def get_git_caption(image, model_type="base"):
logger.info(f"{model_type}")
# logger.info(f"{model_type}")
processor = processor_dict[f"git_{model_type}"]["caption"]
model = model_dict[f"git_{model_type}"]["caption"]
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
if (response is not None) and (len(response) > 0):
logger.info(f"response: {response}")
# logger.info(f"response: {response}")
return response
else:
logger.warning(f"response: {response}")
# logger.warning(f"response: {response}")
return "No response generated from the model."


def get_git_vqa(image, question, model_type="base"):
logger.info(f"{model_type}")
# logger.info(f"{model_type}")
processor = processor_dict[f"git_{model_type}"]["vqa"]
model = model_dict[f"git_{model_type}"]["vqa"]
pixel_values = processor(images=image, return_tensors="pt").pixel_values
@@ -122,10 +123,10 @@ def get_git_vqa(image, question, model_type="base"):
)
response = processor.batch_decode(generated_ids, skip_special_tokens=True)
if (response is not None) and (len(response) > 0):
logger.info(f"response: {response}")
# logger.info(f"response: {response}")
return response[0].replace(question.lower(), "").strip()
else:
logger.warning(f"response: {response}")
# logger.warning(f"response: {response}")
return "No response generated from the model."


@@ -140,11 +141,12 @@ def get_caption(image, model="blip", model_type="base"):


def get_vqa(image, question, model="blip", model_type="base"):
logger.info(f"Using {model}-{model_type}")
# logger.info(f"Using {model}-{model_type}")
raw_image = np.array(image)
raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
if model.lower() == "blip":
answer = get_blip_vqa(raw_image, question, model_type=model_type)
# answer = generateAnswer(raw_image, question)
else:
answer = get_git_vqa(raw_image, question, model_type=model_type)
return answer
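
As a minimal usage sketch (not part of the commit), the get_vqa dispatcher above could be exercised as follows; the image path, the question, and the assumption that the module imports as models.BLIP are illustrative only:

from PIL import Image
from models.BLIP import get_vqa

# Load any RGB image and ask a free-form question about it.
image = Image.open("example.jpg")
answer = get_vqa(image, "What color is the car?", model="blip", model_type="base")
print(answer)
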
12 changes: 6 additions & 6 deletions models/VQA.py
@@ -10,10 +10,10 @@
import torchvision.transforms as transforms
import torchvision.models as models

from support.VocabDictionary import VocabDict
from support.VQAModel import ImgEncoder, QstEncoder, VqaModel, ImgAttentionEncoder, Attention, SANModel
from support.VQADataset import VqaDataset, Args, get_loader
from resources.helper_functions import tokenize, load_str_list, resize_image
from models.support.VocabDictionary import VocabDict
from models.support.VQAModel import ImgEncoder, QstEncoder, VqaModel, ImgAttentionEncoder, Attention, SANModel
from models.support.VQADataset import VqaDataset, Args, get_loader
from models.resources.helper_functions import tokenize, load_str_list, resize_image

# Works without function if we want
# args = Args()
@@ -85,13 +85,13 @@ def generateAnswer(img, question):
num_workers=args.num_workers
)

q_vocab = VocabDict('./resources/vocab_questions.txt')
q_vocab = VocabDict('./models/resources/vocab_questions.txt')
max_q_length = 30
tokens = tokenize(question)
q2idc = np.array([q_vocab.word2idx('<pad>')] * max_q_length)
q2idc[:len(tokens)] = [q_vocab.word2idx(w) for w in tokens]

model = torch.load('./resources/best_model.pt', map_location=torch.device('cpu'))
model = torch.load('./models/resources/best_model.pt', map_location=torch.device('cpu'))
test_q = torch.from_numpy(q2idc)

test_q = torch.from_numpy(q2idc)
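
To make the question-encoding step in generateAnswer easier to follow, here is a small self-contained sketch of the tokenize-and-pad logic; the encode_question helper and the toy word2idx dictionary are illustrative stand-ins for VocabDict and the real vocab_questions.txt, not the project's exact code:

import numpy as np

def encode_question(question, word2idx, max_q_length=30, pad_token="<pad>", unk_token="<unk>"):
    """Tokenize a question and pad it into a fixed-length vector of vocabulary indices."""
    tokens = question.lower().replace("?", "").split()[:max_q_length]
    q2idc = np.array([word2idx[pad_token]] * max_q_length)
    q2idc[:len(tokens)] = [word2idx.get(w, word2idx[unk_token]) for w in tokens]
    return q2idc

# Toy vocabulary for illustration only.
vocab = {"<pad>": 0, "<unk>": 1, "what": 2, "color": 3, "is": 4, "the": 5, "car": 6}
print(encode_question("What color is the car?", vocab))  # [2 3 4 5 6 0 0 ... 0]
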
10 changes: 5 additions & 5 deletions models/support/VQADataset.py
@@ -1,7 +1,7 @@
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np
from support.VocabDictionary import VocabDict
from models.support.VocabDictionary import VocabDict
import torch
import re

@@ -11,7 +11,7 @@

class Args:
model_name = 'best_model'
input_dir = './resources'
input_dir = './models/resources'
log_dir = 'logs'
model_dir = 'models'
max_qst_length = 30
@@ -28,16 +28,16 @@ class Args:
num_workers = 8
save_step = 1
resume_epoch = 15
saved_model = './resources/best_model.pt'
saved_model = './models/resources/best_model.pt'


# Data Loader
class VqaDataset(data.Dataset):
def __init__(self, input_dir, input_vqa, max_qst_length=30, max_num_ans=10, transform=None):
self.input_dir = input_dir
self.vqa = np.load(input_dir+'/'+input_vqa, allow_pickle=True)
self.qst_vocab = VocabDict('./resources/vocab_questions.txt')
self.ans_vocab = VocabDict('./resources/vocab_answers.txt')
self.qst_vocab = VocabDict('./models/resources/vocab_questions.txt')
self.ans_vocab = VocabDict('./models/resources/vocab_answers.txt')
self.max_qst_length = max_qst_length
self.max_num_ans = max_num_ans
self.load_ans = ('valid_answers' in self.vqa[0]) and (self.vqa[0]['valid_answers'] is not None)
2 changes: 1 addition & 1 deletion models/support/VocabDictionary.py
@@ -1,4 +1,4 @@
from resources.helper_functions import tokenize, load_str_list, resize_image
from models.resources.helper_functions import tokenize, load_str_list, resize_image

class VocabDict:
def __init__(self, vocab_file):
23 changes: 10 additions & 13 deletions pseudocode.py
@@ -43,20 +43,17 @@ def calculate_scores(scores_arr, sentences_a, sentences_b, model):
return scores_arr


def remove_questions(scores_arr, question_a, question_b):
def remove_questions(scores_arr, question_a, question_b, model):
to_remove = []

for score in scores_arr:
threshold = 0.7
# If we want to implement threshold, we can have the following:
# if cosine(score[0], question_a) < threshold or cosine(score[1], question_b) < threshold
threshold = 0.75

# I think that we should probably compare score[0] to question_b as well as score[1] to question_a
# The reason I have not done so yet is that I looked for exact matches, but since we probably want a threshold,
# it would be necessary to compare the 4 possible pairs between score[0], score[1] and question_a, question_b

if score[0] == question_a or score[1] == question_b:
to_remove.append(score)
if cosine(model.encode(score[0]), model.encode(question_a)) >= threshold and \
cosine(model.encode(score[0]), model.encode(question_b)) >= threshold and \
cosine(model.encode(score[1]), model.encode(question_a)) >= threshold and \
cosine(model.encode(score[1]), model.encode(question_b)) >= threshold:
to_remove.append(score)

# Removes all of the removable scores
for score in to_remove:
@@ -74,9 +71,9 @@ def chooseBestNQuestions(first, second, n):
chosen_score = new_scores[0]
chosen.append(chosen_score[0])
new_scores.remove(chosen_score)
new_scores = remove_questions(new_scores, chosen_score[0], chosen_score[1])
new_scores = remove_questions(new_scores, chosen_score[0], chosen_score[1], model)

return chosen

chosen = chooseBestNQuestions(first_model, second_model, 3)
print("Chosen: ", chosen)
# chosen = chooseBestNQuestions(first_model, second_model, 3)
# print("Chosen: ", chosen)
6 changes: 5 additions & 1 deletion ui/app.py
@@ -9,8 +9,12 @@
from models.Mistral import Mistral

from pseudocode import chooseBestNQuestions
from models.support.VocabDictionary import VocabDict
from models.support.VQAModel import ImgEncoder, QstEncoder, VqaModel, ImgAttentionEncoder, Attention, SANModel
from models.support.VQADataset import VqaDataset, Args, get_loader
from models.resources.helper_functions import tokenize, load_str_list, resize_image

replicate_api_key = ""
replicate_api_key = "r8_TNoIA5wIRHm9v4jUgwkcpNaSqkHURB23PhkK8"
llama = LLaMA(replicate_api_key)
mistral = Mistral(replicate_api_key)

