Skip to content

Commit

Permalink
feat: add chat history to ai service (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
anbraten authored Aug 31, 2023
1 parent 98bb2de commit b987872
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ class IssueIndexInfo(BaseModel):
class Conversation(BaseModel):
    """Request/response payload for the /ask endpoint.

    Chat context is tracked server-side per ``chat_id`` (see question.py),
    so no chat_history field is round-tripped through the API anymore.
    The service fills in ``answer`` before returning the object.
    """

    repo_id: int  # repository whose index is queried
    chat_id: int  # key into the server-side chat memory store
    question: str  # the user's question
    answer: str = ""  # populated by the service; empty on the way in
4 changes: 2 additions & 2 deletions ai/pyserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def perform_index(indexInfo: IndexInfo):

@app.post("/ask")
def conversation(conversationInput: Conversation):
    """Answer a question against a repo's index using server-side chat memory.

    ask() keeps per-chat history internally (keyed by chat_id), so only the
    answer string comes back; the caller no longer supplies chat_history.
    """
    conversationInput.answer = ask(
        conversationInput.repo_id,
        conversationInput.chat_id,
        conversationInput.question,
    )
    return conversationInput
34 changes: 27 additions & 7 deletions ai/question.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,26 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from dotenv import load_dotenv

load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))

data_path = os.getenv("DATA_PATH")

# Per-chat LangChain memories, keyed by chat id. Entries live until
# close_chat() drops them.
chatMemories = {}


def _get_chat(chat_id: int):
    """Return the ConversationBufferMemory for *chat_id*, creating it on first use."""
    try:
        return chatMemories[chat_id]
    except KeyError:
        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True, output_key="answer"
        )
        chatMemories[chat_id] = memory
        return memory


def ask(repo_id: int, chat_id: int, question: str):
embeddings = OpenAIEmbeddings(disallowed_special=())

repo_path = os.path.join(data_path, str(repo_id))
Expand All @@ -27,18 +38,27 @@ def ask(repo_id: int, question: str, chat_history=[]):
retriever.search_kwargs["maximal_marginal_relevance"] = True
retriever.search_kwargs["k"] = 20

model = ChatOpenAI(model_name="gpt-3.5-turbo-16k")
qa = ConversationalRetrievalChain.from_llm(model, retriever=retriever)
memory = _get_chat(chat_id)

qa = ConversationalRetrievalChain.from_llm(
llm=ChatOpenAI(temperature=0),
memory=memory,
retriever=retriever,
return_source_documents=True,
)
end = time.time()

chat_history = []
result = qa({"question": question, "chat_history": chat_history})
chat_history.append((question, result["answer"]))
return result["answer"], chat_history
result = qa(question)
return result["answer"]

end = time.time()


def close_chat(chat_id: int):
    """Discard the stored memory for *chat_id*, releasing its history.

    A no-op when the chat id has no stored memory.
    """
    chatMemories.pop(chat_id, None)


if __name__ == "__main__":
    # CLI usage: python question.py <repo_id> <chat_id> <question>
    # ask() now takes (repo_id, chat_id, question); the previous two-argument
    # call broke when chat_id was added to the signature.
    ask(int(sys.argv[1]), int(sys.argv[2]), sys.argv[3])
1 change: 1 addition & 0 deletions server/api/repos/[repo_id]/chat.post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export default defineEventHandler(async (event) => {
method: 'POST',
body: {
repo_id: repo.id,
chat_id: repo.id, // TODO: support opening and closing chats
question: message,
},
});
Expand Down

0 comments on commit b987872

Please sign in to comment.