Skip to content

Commit

Permalink
Update: get detailed responses
Browse files Browse the repository at this point in the history
  • Loading branch information
deepakgouda committed Dec 2, 2023
1 parent e79a5a0 commit 25ac195
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
29 changes: 20 additions & 9 deletions models/LLaMA.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ def __init__(self, api_key):
self.min_new_tokens = -1
os.environ["REPLICATE_API_TOKEN"] = api_key

def get_questions(self, caption, num_questions=5):
# caption = "A dog playing with a football in a field"
prompt = f"Given the following caption, generate {num_questions} questions about the picture that you can ask a Visual Question and Answer Model. |{caption}|"

def call_llm(self, prompt):
output = replicate.run(
self.model_name,
input={
Expand All @@ -31,15 +28,29 @@ def get_questions(self, caption, num_questions=5):
"min_new_tokens": self.min_new_tokens,
},
)

res = []
for o in output:
res.append(o)
return "".join(res).splitlines()

question_list = "".join(res).splitlines()
print(question_list)
question_list = [q.strip() for q in question_list if len(q) > 0]
return question_list
def get_questions(self, caption, num_questions=5):
    """Generate VQA-style questions about an image from its caption.

    Prompts the LLM (via ``self.call_llm``) and post-processes its output,
    keeping only lines that end in a question mark and stripping the
    leading "N. " / "N) " numbering the model prepends.

    Parameters:
        caption: image caption to base the questions on.
        num_questions: how many questions to ask the model for.

    Returns:
        list[str]: cleaned question strings (may be fewer than
        ``num_questions`` if the model emits non-question lines).
    """
    prompt = f"Given the following caption, generate {num_questions} questions about the picture that you can ask a Visual Question and Answer Model. |{caption}|"

    lines = self.call_llm(prompt)
    questions = []
    for line in lines:
        text = line.strip()
        # Strip BEFORE checking for "?" so trailing whitespace does not
        # cause a valid question to be dropped (the old code tested the
        # raw last character). Blank/framing lines fail this test.
        if not text.endswith("?"):
            continue
        # Remove a numbering prefix only when one is present. The old
        # fixed `[3:]` slice corrupted two-digit numbering ("10. ...")
        # and unnumbered lines.
        head, _, rest = text.partition(" ")
        if rest and head.rstrip(".)").isdigit():
            text = rest
        questions.append(text)
    return questions

def get_complete_summary(self, qna_list):
    """Ask the LLM for a detailed image summary built from Q&A pairs.

    Parameters:
        qna_list: question/answer strings to fold into the prompt.

    Returns:
        whatever ``self.call_llm`` yields for the summary prompt
        (the LLM output split into lines).
    """
    joined_qna = " ".join(qna_list)
    prompt = f"Given the following questions and answers, generate a detailed summary of the image. |{joined_qna}|"
    return self.call_llm(prompt)


if __name__ == "__main__":
Expand Down
28 changes: 21 additions & 7 deletions models/Mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ def __init__(self, api_key):
self.min_new_tokens = -1
os.environ["REPLICATE_API_TOKEN"] = api_key

def get_questions(self, caption, num_questions=5):
prompt = f"Given the following caption, generate {num_questions} questions about the picture that you can ask a Visual Question and Answer Model. |{caption}|"

def call_llm(self, prompt):
output = replicate.run(
self.model_name,
input={
Expand All @@ -34,10 +32,26 @@ def get_questions(self, caption, num_questions=5):
for o in output:
res.append(o)

question_list = "".join(res).splitlines()
print(question_list)
question_list = [q.strip() for q in question_list if len(q) > 0]
return question_list
return "".join(res).splitlines()

def get_questions(self, caption, num_questions=5):
    """Generate VQA-style questions about an image from its caption.

    Prompts the LLM (via ``self.call_llm``) and post-processes its output,
    keeping only lines that end in a question mark and stripping the
    leading "N. " / "N) " numbering the model prepends.

    Parameters:
        caption: image caption to base the questions on.
        num_questions: how many questions to ask the model for.

    Returns:
        list[str]: cleaned question strings (may be fewer than
        ``num_questions`` if the model emits non-question lines).
    """
    prompt = f"Given the following caption, generate {num_questions} questions about the picture that you can ask a Visual Question and Answer Model. |{caption}|"

    lines = self.call_llm(prompt)
    questions = []
    for line in lines:
        text = line.strip()
        # Strip BEFORE checking for "?" so trailing whitespace does not
        # cause a valid question to be dropped (the old code tested the
        # raw last character). Blank/framing lines fail this test.
        if not text.endswith("?"):
            continue
        # Remove a numbering prefix only when one is present. The old
        # fixed `[3:]` slice corrupted two-digit numbering ("10. ...")
        # and unnumbered lines.
        head, _, rest = text.partition(" ")
        if rest and head.rstrip(".)").isdigit():
            text = rest
        questions.append(text)
    return questions

def get_complete_summary(self, qna_list):
    """Ask the LLM for a detailed image summary built from Q&A pairs.

    Parameters:
        qna_list: question/answer strings to fold into the prompt.

    Returns:
        whatever ``self.call_llm`` yields for the summary prompt
        (the LLM output split into lines).
    """
    joined_qna = " ".join(qna_list)
    prompt = f"Given the following questions and answers, generate a detailed summary of the image. |{joined_qna}|"
    return self.call_llm(prompt)


if __name__ == "__main__":
Expand Down
24 changes: 23 additions & 1 deletion ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from models.LLaMA import LLaMA
from models.Mistral import Mistral

from pseudocode import chooseBestNQuestions

replicate_api_key = ""
llama = LLaMA(replicate_api_key)
mistral = Mistral(replicate_api_key)
Expand All @@ -32,6 +34,9 @@ def text_to_speech(text):
if "file_name" not in st.session_state.keys():
st.session_state.file_name = ""

if "final_response" not in st.session_state.keys():
st.session_state.final_response = ""

uploaded_file = st.file_uploader("Choose the image you would like described.")

model = st.selectbox("Select the model you would like to use.", ["BLIP", "GIT"])
Expand All @@ -43,7 +48,8 @@ def text_to_speech(text):
model_type = model_type.lower()

llm_type = st.selectbox(
"Select the model size you would like to use.", ["LLaMA-2-70B", "Mistral-7B"]
"Select the model you would like to use for summarization.",
["LLaMA-2-70B", "Mistral-7B"],
)
llm_type = llm_type.lower()

Expand All @@ -69,6 +75,22 @@ def text_to_speech(text):
questions_list_2 = mistral.get_questions(model_output, num_questions=5)
print(questions_list_1)
print(questions_list_2)
questions_list = chooseBestNQuestions(
questions_list_1, questions_list_2, 5
)
print(questions_list)

qna_list = []
for question in questions_list:
vqa_output = get_vqa(image, question, model, model_type)
qna_list.append(f"Question: {question}, Answer: {vqa_output}")
print(qna_list)
if llm_type == "mistral-7b":
response = mistral.get_complete_summary(qna_list)
else:
response = llama.get_complete_summary(qna_list)
print(response)
st.session_state.final_response = response
# placeholder.success("Done...")

if st.session_state.model_output != "":
Expand Down

0 comments on commit 25ac195

Please sign in to comment.