Implementing CardImageRetriever. Added card vector generation.
yghokim committed Apr 9, 2024
1 parent 8fcf179 commit ba0c8f2
Showing 4 changed files with 47 additions and 9 deletions.
libs/py_core/py_core/config.py (3 additions, 0 deletions)
@@ -7,3 +7,6 @@ class AACessTalkConfig:
parent_example_translation_dictionary_path: str = path.join(dataset_dir_path, "parent_example_translation_dictionary.csv")
card_image_directory_path: str = path.join(dataset_dir_path, "cards")
card_image_table_path: str = path.join(dataset_dir_path, "cards_image_info.csv")
card_image_embeddings_path: str = path.join(dataset_dir_path, "cards_image_desc_embeddings.npz")

embedding_model = 'text-embedding-3-large'
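Both additions are consumed by the generation script in the next file: embedding_model feeds the OpenAI embeddings calls and card_image_embeddings_path is where the cached card vectors are serialized. A minimal, hypothetical usage sketch of the new config entries (not part of this commit):

from os import path

from py_core.config import AACessTalkConfig

# Class-level access, as used elsewhere in this commit.
print(AACessTalkConfig.embedding_model)  # 'text-embedding-3-large'
if not path.exists(AACessTalkConfig.card_image_embeddings_path):
    print("Run cache_description_embeddings_all() to build the card vector cache first.")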
@@ -6,15 +6,18 @@
from os import listdir, scandir, path
from time import perf_counter

import numpy
from chatlib.global_config import GlobalConfig
from numpy import array

from py_core.config import AACessTalkConfig
from PIL import ImageOps

from py_core.utils.math import cosine_similarity
from py_core.utils.models import CardImageInfo
from chatlib.utils import env_helper
from chatlib.tool.versatile_mapper import ChatCompletionFewShotMapper, ChatCompletionFewShotMapperParams, MapperInputOutputPair
from chatlib.tool.versatile_mapper import ChatCompletionFewShotMapper, ChatCompletionFewShotMapperParams, \
MapperInputOutputPair
from chatlib.llm.integration import GPTChatCompletionAPI, GeminiAPI, ChatGPTModel
from openai import OpenAI
import google.generativeai as genai
@@ -68,12 +71,14 @@ def _load_card_descriptions() -> list[CardImageInfo]:
print(f"loaded {len(rows)} card info list.")
return rows


def _save_card_descriptions(rows: list[CardImageInfo]):
with open(AACessTalkConfig.card_image_table_path, "w") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=CardImageInfo.model_fields)
writer.writeheader()
writer.writerows([row.model_dump() for row in rows])


def generate_card_descriptions_all(openai_client: OpenAI):
rows = _load_card_descriptions()

@@ -190,12 +195,13 @@ def generate_card_description_gpt4(info: CardImageInfo, client: OpenAI) -> str:
return response.choices[0].message.content



def fix_refused_requests(threshold: float, client: OpenAI):
model = 'text-embedding-3-large'
model = AACessTalkConfig.embedding_model

rows = _load_card_descriptions()
suspicious_rows = [row for row in rows if row.description is not None and (
("sorry" in row.description) or ("cannot" in row.description) or ("assistance" in row.description))]
("sorry" in row.description) or ("cannot" in row.description) or ("assistance" in row.description))]

print(f"{len(suspicious_rows)} rows are suspicious.")

@@ -228,6 +234,7 @@ def fix_refused_requests(threshold: float, client: OpenAI):
print(e)
pass


async def generate_short_descriptions_all():
mapper = ChatCompletionFewShotMapper.make_str_mapper(
GPTChatCompletionAPI(),
@@ -257,9 +264,11 @@ async def generate_short_descriptions_all():
if row.description is not None and row.description_brief is None:
try:
print(f"Processing {i}/{len(rows)}...")
description = await mapper.run(examples, row.description, ChatCompletionFewShotMapperParams(model=ChatGPTModel.GPT_4_0613, api_params ={}))
#print("[Condensed]", description)
#print("[Original]", row.description)
description = await mapper.run(examples, row.description,
ChatCompletionFewShotMapperParams(model=ChatGPTModel.GPT_4_0613,
api_params={}))
# print("[Condensed]", description)
# print("[Original]", row.description)
row.description_brief = description
_save_card_descriptions(rows)
except Exception as e:
@@ -270,7 +279,27 @@ async def generate_short_descriptions_all():
desc_words = row.description.split(" ")
desc_brief_words = row.description_brief.split(" ")

print(f"{i}/{len(rows)} description word count reduced from {len(desc_words)} to {len(desc_brief_words)} ({len(desc_brief_words)/len(desc_words) * 100}%)")
print(
f"{i}/{len(rows)} description word count reduced from {len(desc_words)} to {len(desc_brief_words)} ({len(desc_brief_words) / len(desc_words) * 100}%)")


def cache_description_embeddings_all(client: OpenAI):
rows = _load_card_descriptions()
chunk_size = 2048
embeddings = []
for chunk_i in range(0, len(rows), chunk_size):
result = client.embeddings.create(input=[row.description_brief.replace("\n", " ") for row in rows[chunk_i : chunk_i + chunk_size]],
model=AACessTalkConfig.embedding_model,
dimensions=256
)

print("Generated embeddings.")
embeddings.extend([datum.embedding for datum in result.data])

embedding_array = array(embeddings)
with open(AACessTalkConfig.card_image_embeddings_path, 'wb') as f:
numpy.savez_compressed(f, ids=[row.id for row in rows], descriptions=embedding_array)
print("Serialized embeddings to file.")


if __name__ == "__main__":
@@ -287,4 +316,6 @@ async def generate_short_descriptions_all():

# generate_card_descriptions_all(openai_client)
# fix_refused_requests(threshold=0.4, client=openai_client)
asyncio.run(generate_short_descriptions_all())
# asyncio.run(generate_short_descriptions_all())
cache_description_embeddings_all(openai_client)
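For reference, a minimal sketch (not part of the commit) of reading the serialized cache back; the key names and the 256-dimensional vectors follow the numpy.savez_compressed call and the dimensions=256 argument above:

import numpy as np

from py_core.config import AACessTalkConfig

# Load the .npz written by cache_description_embeddings_all().
data = np.load(AACessTalkConfig.card_image_embeddings_path)
card_ids = data["ids"]               # one id per card description row
card_vectors = data["descriptions"]  # shape (num_cards, 256)
assert len(card_ids) == card_vectors.shape[0]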

@@ -0,0 +1,3 @@
class CardImageRetriever:
def __init__(self):
pass
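The class added here is still an empty stub, but the commit title indicates a CardImageRetriever is being built on top of the card vectors cached above. The sketch below is illustrative only, not the project's actual implementation: it embeds a query with the same model and dimensionality used by cache_description_embeddings_all and ranks cards by cosine similarity with plain numpy (the py_core.utils.math.cosine_similarity helper imported in the script above could serve the same purpose). The function name and parameters are hypothetical.

import numpy as np
from openai import OpenAI

from py_core.config import AACessTalkConfig


def retrieve_card_ids(query: str, client: OpenAI, k: int = 5) -> list[str]:
    # Load the cached card description embeddings (keys as written by savez_compressed above).
    data = np.load(AACessTalkConfig.card_image_embeddings_path)
    ids, vectors = data["ids"], data["descriptions"]

    # Embed the query with the same model and dimensionality as the cache.
    result = client.embeddings.create(input=[query.replace("\n", " ")],
                                      model=AACessTalkConfig.embedding_model,
                                      dimensions=256)
    query_vec = np.array(result.data[0].embedding)

    # Cosine similarity between the query vector and every cached card vector.
    sims = vectors @ query_vec / (np.linalg.norm(vectors, axis=1) * np.linalg.norm(query_vec))

    # Return the ids of the k most similar cards, best first.
    top_indices = np.argsort(-sims)[:k]
    return [str(ids[i]) for i in top_indices]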
libs/py_core/py_core/utils/vector_db.py (2 additions, 1 deletion)
@@ -11,7 +11,8 @@

class VectorDB:

def __init__(self, dir_name: str = "embeddings", embedding_model: str = "text-embedding-3-small"):
def __init__(self, dir_name: str = "embeddings",
embedding_model: str = "text-embedding-3-small"):
#self.__client = chromadb.PersistentClient(path.join(AACessTalkConfig.dataset_dir_path, dir_name))
self.__client = chromadb.Client()

