From b9878722b735ce15edf009f7a221cb553521e451 Mon Sep 17 00:00:00 2001
From: Anbraten
Date: Thu, 31 Aug 2023 08:48:36 +0200
Subject: [PATCH] feat: add chat history to ai service (#23)

---
 ai/models.py                            |  2 +-
 ai/pyserver.py                          |  4 ++--
 ai/question.py                          | 34 ++++++++++++++++++++-----
 server/api/repos/[repo_id]/chat.post.ts |  1 +
 4 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/ai/models.py b/ai/models.py
index fd78cbf..a0a4e93 100644
--- a/ai/models.py
+++ b/ai/models.py
@@ -14,5 +14,5 @@ class IssueIndexInfo(BaseModel):
 class Conversation(BaseModel):
     repo_id: int
     question: str
+    chat_id: int
     answer: str = ""
-    chat_history: Tuple[(str, str)] = []
diff --git a/ai/pyserver.py b/ai/pyserver.py
index 2493368..114a8a1 100644
--- a/ai/pyserver.py
+++ b/ai/pyserver.py
@@ -24,9 +24,9 @@ def perform_index(indexInfo: IndexInfo):
 
 @app.post("/ask")
 def conversation(conversationInput: Conversation):
-    conversationInput.answer, conversationInput.chat_history = ask(
+    conversationInput.answer = ask(
         conversationInput.repo_id,
+        conversationInput.chat_id,
         conversationInput.question,
-        conversationInput.chat_history,
     )
     return conversationInput
diff --git a/ai/question.py b/ai/question.py
index 28aaf04..99a2d7d 100644
--- a/ai/question.py
+++ b/ai/question.py
@@ -4,6 +4,7 @@
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ConversationBufferMemory
 from langchain.chat_models import ChatOpenAI
 from dotenv import load_dotenv
 
@@ -11,8 +12,18 @@
 
 data_path = os.getenv("DATA_PATH")
 
+chatMemories = {}
 
-def ask(repo_id: int, question: str, chat_history=[]):
+
+def _get_chat(chat_id: int):
+    if chat_id not in chatMemories:
+        chatMemories[chat_id] = ConversationBufferMemory(
+            memory_key="chat_history", return_messages=True, output_key="answer"
+        )
+    return chatMemories[chat_id]
+
+
+def ask(repo_id: int, chat_id: int, question: str):
     embeddings = OpenAIEmbeddings(disallowed_special=())
 
     repo_path = os.path.join(data_path, str(repo_id))
@@ -27,18 +38,27 @@ def ask(repo_id: int, question: str, chat_history=[]):
     retriever.search_kwargs["maximal_marginal_relevance"] = True
     retriever.search_kwargs["k"] = 20
 
-    model = ChatOpenAI(model_name="gpt-3.5-turbo-16k")
-    qa = ConversationalRetrievalChain.from_llm(model, retriever=retriever)
+    memory = _get_chat(chat_id)
+
+    qa = ConversationalRetrievalChain.from_llm(
+        llm=ChatOpenAI(temperature=0),
+        memory=memory,
+        retriever=retriever,
+        return_source_documents=True,
+    )
 
     end = time.time()
 
-    chat_history = []
-    result = qa({"question": question, "chat_history": chat_history})
-    chat_history.append((question, result["answer"]))
-    return result["answer"], chat_history
+    result = qa(question)
+    return result["answer"]
 
     end = time.time()
 
+def close_chat(chat_id: int):
+    if chat_id in chatMemories:
+        del chatMemories[chat_id]
+
+
 if __name__ == "__main__":
     # print(f">>>{sys.argv}<<<\n")
     ask(sys.argv[1], sys.argv[2])
diff --git a/server/api/repos/[repo_id]/chat.post.ts b/server/api/repos/[repo_id]/chat.post.ts
index 657504a..19a5a9e 100644
--- a/server/api/repos/[repo_id]/chat.post.ts
+++ b/server/api/repos/[repo_id]/chat.post.ts
@@ -18,6 +18,7 @@ export default defineEventHandler(async (event) => {
     method: 'POST',
     body: {
       repo_id: repo.id,
+      chat_id: repo.id, // TODO: support opening and closing chats
      question: message,
     },
   });
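
Review note: the core of this change is a module-level dict that keys one
ConversationBufferMemory per chat, so /ask no longer round-trips chat_history
through the client. Below is a minimal sketch of that bookkeeping, runnable
without the FAISS index or OpenAI credentials; the get_chat name and the
hand-written turn are illustrative, not part of the patch:

    from langchain.memory import ConversationBufferMemory

    chat_memories = {}  # chat_id -> ConversationBufferMemory, as in ai/question.py

    def get_chat(chat_id: int) -> ConversationBufferMemory:
        # Lazily create one buffer per chat so separate chats never share history.
        if chat_id not in chat_memories:
            chat_memories[chat_id] = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer"
            )
        return chat_memories[chat_id]

    # ConversationalRetrievalChain records each turn automatically when built
    # with memory=...; doing it by hand here shows what accumulates between calls.
    memory = get_chat(42)
    memory.save_context(
        {"question": "What does ask() return now?"},
        {"answer": "Only the answer string; history lives in the memory object."},
    )
    print(memory.load_memory_variables({})["chat_history"])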
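
With the schema change, callers send a chat_id instead of echoing the history
back and forth. A sketch of a follow-up conversation against the /ask route,
assuming the FastAPI service from ai/pyserver.py listens on localhost:8000
(the address is a guess; adjust it to your deployment):

    import requests

    BASE = "http://localhost:8000"  # assumed address of the ai/pyserver.py service

    def ask_repo(repo_id: int, chat_id: int, question: str) -> str:
        # Mirrors the Conversation model: repo_id, chat_id, question; the
        # server fills in `answer` and returns the whole object.
        resp = requests.post(
            f"{BASE}/ask",
            json={"repo_id": repo_id, "chat_id": chat_id, "question": question},
        )
        resp.raise_for_status()
        return resp.json()["answer"]

    # Reusing the same chat_id lets the second question refer to the first.
    # Note the Nuxt handler currently passes repo.id as chat_id (see the TODO),
    # so until real chat ids exist, everyone chatting about a repo shares one
    # history and close_chat() is never invoked.
    print(ask_repo(1, 1, "Where is the retriever configured?"))
    print(ask_repo(1, 1, "And what value of k does it use?"))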