diff --git a/libs/py_core/py_core/config.py b/libs/py_core/py_core/config.py
index 8cc13ae..08a3f0b 100644
--- a/libs/py_core/py_core/config.py
+++ b/libs/py_core/py_core/config.py
@@ -7,3 +7,6 @@ class AACessTalkConfig:
     parent_example_translation_dictionary_path: str = path.join(dataset_dir_path, "parent_example_translation_dictionary.csv")
     card_image_directory_path: str = path.join(dataset_dir_path, "cards")
     card_image_table_path: str = path.join(dataset_dir_path, "cards_image_info.csv")
+    card_image_embeddings_path: str = path.join(dataset_dir_path, "cards_image_desc_embeddings.npz")
+
+    embedding_model = 'text-embedding-3-large'
diff --git a/libs/py_core/py_core/processing_tools/generate_image_description.py b/libs/py_core/py_core/processing_tools/generate_image_description.py
index f1b5a38..182468e 100644
--- a/libs/py_core/py_core/processing_tools/generate_image_description.py
+++ b/libs/py_core/py_core/processing_tools/generate_image_description.py
@@ -6,7 +6,9 @@
 from os import listdir, scandir, path
 from time import perf_counter
 
+import numpy
 from chatlib.global_config import GlobalConfig
+from numpy import array
 
 from py_core.config import AACessTalkConfig
 from PIL import ImageOps
@@ -14,7 +16,8 @@
 from py_core.utils.math import cosine_similarity
 from py_core.utils.models import CardImageInfo
 from chatlib.utils import env_helper
-from chatlib.tool.versatile_mapper import ChatCompletionFewShotMapper, ChatCompletionFewShotMapperParams, MapperInputOutputPair
+from chatlib.tool.versatile_mapper import ChatCompletionFewShotMapper, ChatCompletionFewShotMapperParams, \
+    MapperInputOutputPair
 from chatlib.llm.integration import GPTChatCompletionAPI, GeminiAPI, ChatGPTModel
 from openai import OpenAI
 import google.generativeai as genai
@@ -68,12 +71,14 @@ def _load_card_descriptions() -> list[CardImageInfo]:
         print(f"loaded {len(rows)} card info list.")
         return rows
 
+
 def _save_card_descriptions(rows: list[CardImageInfo]):
     with open(AACessTalkConfig.card_image_table_path, "w") as csvfile:
         writer = csv.DictWriter(csvfile, fieldnames=CardImageInfo.model_fields)
         writer.writeheader()
         writer.writerows([row.model_dump() for row in rows])
 
+
 def generate_card_descriptions_all(openai_client: OpenAI):
     rows = _load_card_descriptions()
 
@@ -190,12 +195,13 @@ def generate_card_description_gpt4(info: CardImageInfo, client: OpenAI) -> str:
 
     return response.choices[0].message.content
 
+
 def fix_refused_requests(threshold: float, client: OpenAI):
-    model = 'text-embedding-3-large'
+    model = AACessTalkConfig.embedding_model
 
     rows = _load_card_descriptions()
     suspicious_rows = [row for row in rows if row.description is not None and (
-        ("sorry" in row.description) or ("cannot" in row.description) or ("assistance" in row.description))]
+            ("sorry" in row.description) or ("cannot" in row.description) or ("assistance" in row.description))]
 
     print(f"{len(suspicious_rows)} rows are suspicious.")
 
@@ -228,6 +234,7 @@ def fix_refused_requests(threshold: float, client: OpenAI):
             print(e)
             pass
 
+
 async def generate_short_descriptions_all():
     mapper = ChatCompletionFewShotMapper.make_str_mapper(
         GPTChatCompletionAPI(),
@@ -257,9 +264,11 @@ async def generate_short_descriptions_all():
         if row.description is not None and row.description_brief is None:
             try:
                 print(f"Processing {i}/{len(rows)}...")
-                description = await mapper.run(examples, row.description, ChatCompletionFewShotMapperParams(model=ChatGPTModel.GPT_4_0613, api_params ={}))
-                #print("[Condensed]", description)
-                #print("[Original]", row.description)
+                description = await mapper.run(examples, row.description,
+                                               ChatCompletionFewShotMapperParams(model=ChatGPTModel.GPT_4_0613,
+                                                                                 api_params={}))
+                # print("[Condensed]", description)
+                # print("[Original]", row.description)
                 row.description_brief = description
                 _save_card_descriptions(rows)
             except Exception as e:
@@ -270,7 +279,27 @@ async def generate_short_descriptions_all():
             desc_words = row.description.split(" ")
             desc_brief_words = row.description_brief.split(" ")
 
-            print(f"{i}/{len(rows)} description word count reduced from {len(desc_words)} to {len(desc_brief_words)} ({len(desc_brief_words)/len(desc_words) * 100}%)")
+            print(
+                f"{i}/{len(rows)} description word count reduced from {len(desc_words)} to {len(desc_brief_words)} ({len(desc_brief_words) / len(desc_words) * 100}%)")
+
+
+def cache_description_embeddings_all(client: OpenAI):
+    rows = _load_card_descriptions()
+    chunk_size = 2048
+    embeddings = []
+    for chunk_i in range(0, len(rows), chunk_size):
+        result = client.embeddings.create(input=[row.description_brief.replace("\n", " ") for row in rows[chunk_i : chunk_i + chunk_size]],
+                                          model=AACessTalkConfig.embedding_model,
+                                          dimensions=256
+                                          )
+
+        print("Generated embeddings.")
+        embeddings.extend([datum.embedding for datum in result.data])
+
+    embedding_array = array(embeddings)
+    with open(AACessTalkConfig.card_image_embeddings_path, 'wb') as f:
+        numpy.savez_compressed(f, ids=[row.id for row in rows], descriptions=embedding_array)
+    print("Serialized embeddings to file.")
 
 
 if __name__ == "__main__":
@@ -287,4 +316,6 @@ async def generate_short_descriptions_all():
 
     # generate_card_descriptions_all(openai_client)
     # fix_refused_requests(threshold=0.4, client=openai_client)
-    asyncio.run(generate_short_descriptions_all())
+    # asyncio.run(generate_short_descriptions_all())
+    cache_description_embeddings_all(openai_client)
+
diff --git a/libs/py_core/py_core/system/task/card_recommendation/card_image_retriever.py b/libs/py_core/py_core/system/task/card_recommendation/card_image_retriever.py
new file mode 100644
index 0000000..210ed96
--- /dev/null
+++ b/libs/py_core/py_core/system/task/card_recommendation/card_image_retriever.py
@@ -0,0 +1,3 @@
+class CardImageRetriever:
+    def __init__(self):
+        pass
diff --git a/libs/py_core/py_core/utils/vector_db.py b/libs/py_core/py_core/utils/vector_db.py
index e636c44..2d03173 100644
--- a/libs/py_core/py_core/utils/vector_db.py
+++ b/libs/py_core/py_core/utils/vector_db.py
@@ -11,7 +11,8 @@ class VectorDB:
 
-    def __init__(self, dir_name: str = "embeddings", embedding_model: str = "text-embedding-3-small"):
+    def __init__(self, dir_name: str = "embeddings",
+                 embedding_model: str = "text-embedding-3-small"):
         #self.__client = chromadb.PersistentClient(path.join(AACessTalkConfig.dataset_dir_path, dir_name))
         self.__client = chromadb.Client()
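
Note on the new cache: `cache_description_embeddings_all` writes the card ids under the "ids" key and the 256-dimensional description embeddings under the "descriptions" key of the compressed .npz archive. The diff only stubs `CardImageRetriever`, so the following is a minimal sketch of how a consumer could load that archive and rank cards against a query vector. The helper names and the top-k logic are illustrative assumptions, not part of this change, and the query embedding is assumed to come from the same model with dimensions=256.

import numpy as np

from py_core.config import AACessTalkConfig


def load_cached_embeddings() -> tuple[np.ndarray, np.ndarray]:
    # Hypothetical helper: reads the archive written by cache_description_embeddings_all.
    # savez_compressed stored card ids under "ids" and the embedding matrix under "descriptions".
    with np.load(AACessTalkConfig.card_image_embeddings_path) as data:
        return data["ids"], data["descriptions"]


def top_k_card_ids(query_embedding: np.ndarray, k: int = 5) -> list[str]:
    # Hypothetical helper: ranks all cached cards by cosine similarity to the query,
    # assuming non-zero vectors from the same embedding model (dimensions=256).
    ids, embeddings = load_cached_embeddings()
    norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
    scores = (embeddings @ query_embedding) / norms
    best = np.argsort(scores)[::-1][:k]
    return [str(card_id) for card_id in ids[best]]

With the matrix precomputed on disk, each query reduces to one matrix-vector product; the chromadb-backed VectorDB touched at the end of this diff would be the heavier alternative if persistent indexing or filtering were needed.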