This repository has been archived by the owner on Oct 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
121 lines (96 loc) ยท 4.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from pydantic import BaseModel
from typing import List
import os
# Initialize the FastAPI application.
app = FastAPI()
# CORS configuration.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — consider restricting origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Data model definitions.
class DiaryEntry(BaseModel):
    """Request body for /add_diary: one summarized diary entry for a user."""
    userId: str
    summarizedDiary: str
class Query(BaseModel):
    """Request body for /query: a user's question plus prior chat history.

    FIX: the inline comment on ``chatHistory`` had been split onto its own
    line by a bad paste, leaving a bare non-comment line (a SyntaxError);
    it is rejoined here as a proper comment.
    """
    userId: str
    question: str
    chatHistory: List[str]  # prior chat messages, included for conversational context
# Read the OpenAI API key from the environment (or direct configuration).
# NOTE(review): the default is an empty string — downstream OpenAI calls will
# fail with an auth error rather than at startup; consider failing fast.
openai_api_key = os.getenv("OPENAI_API_KEY", "")
# Initialize embeddings and ChromaDB.
embedding = OpenAIEmbeddings(openai_api_key=openai_api_key)
persist_directory = 'db'
# Load (or initialize) the persistent Chroma vector store.
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding
)
# Define the retriever: top-3 nearest documents per query.
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
# Initialize the QA chain ("stuff" = concatenate retrieved docs into the prompt).
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key),
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True
)
# ์ผ์ ๋ํ ๋ฐ ๊ณต๊ฐ ์ฑ๋ด์ ์ํ ํ๋กฌํํธ ์์ฑ ํจ์
def create_prompt(conversation: list[str], new_question: str) -> str:
system_message = "๋ง์ฝ ์ดํดํ ์ ์๋ ๋ง์ด๋ฉด ์ฌ๊ณผํ๊ณ ๋๋ ๊ฐ์ฑ์ ์ด๊ณ ๊ณต๊ฐ์ ์ํด์ฃผ๋ ๋ฐ๋ปํ ๋ง์์จ๋ฅผ ๊ฐ์ง ์น๊ตฌ์ผ. ์์ฐ์ค๋ฝ๊ฒ ๋ฐ๋ง๋ก ๋ต๋ณํด์ค. ํต์ฌ์ ์ธ ๋ต๋ณ์ด ๋๋๋ฉด ์์ฐ์ค๋ฝ๊ฒ ๊ด๋ จ ์ง๋ฌธํด๋ ์ข๊ณ ๋ฃ๊ณ ๋ง ์์ด๋ ์ข์"
conversation_context = "\n".join(conversation)
prompt = f"{system_message}\n{conversation_context}\n์ฌ์ฉ์: {new_question}\n์น๊ตฌ:"
return prompt
# Endpoint: add a diary entry to the vector store.
@app.post("/add_diary")
async def add_diary(entry: DiaryEntry):
    """Embed one summarized diary entry, tagged with its user id, and persist it."""
    try:
        metadata = {"userId": entry.userId}
        vectordb.add_texts([entry.summarizedDiary], metadatas=[metadata])
        vectordb.persist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Diary embedding is successed."}
# Endpoint: answer a chatbot query with RAG over the user's own diary entries.
@app.post("/query")
async def query_api(query: Query):
    """Answer *query.question* using retrieval restricted to this user's documents.

    Returns ``{"message": <answer>, "sources": <list of source names>}``.
    Raises HTTPException(500) on any failure.
    """
    try:
        user_id = query.userId
        # Use the chat history supplied by the client.
        user_conversation = query.chatHistory
        # Append the new question to the conversation log.
        user_conversation.append(f"์ฌ์ฉ์: {query.question}")
        # Build the prompt from the persona, history, and new question.
        prompt = create_prompt(user_conversation, query.question)
        # Restrict retrieval to this user's documents.
        # BUG FIX: Chroma's retriever takes the metadata filter under the
        # "filter" key of search_kwargs; the previous "metadata_filter" key
        # was silently ignored, so every user's documents were searched.
        retriever_with_user_id = vectordb.as_retriever(
            search_kwargs={"k": 3, "filter": {"userId": user_id}}
        )
        # Build a QA chain bound to the user-scoped retriever.
        qa_chain_with_user_id = RetrievalQA.from_chain_type(
            llm=ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key),
            chain_type="stuff",
            retriever=retriever_with_user_id,
            return_source_documents=True,
        )
        # BUG FIX: invoke the user-scoped chain. The original called the
        # module-level qa_chain here, so the per-user chain built above was
        # constructed and then never used.
        llm_response = qa_chain_with_user_id({"query": prompt})
        result = llm_response['result']
        # Append the model's reply to the conversation log.
        user_conversation.append(f"์น๊ตฌ: {result}")
        # Collect source-document names when the chain returned them.
        if "source_documents" in llm_response:
            sources = [doc.metadata.get('source', 'Unknown source') for doc in llm_response["source_documents"]]
        else:
            sources = ["No source documents available"]
        return {"message": result, "sources": sources}
    except Exception as e:
        print(f"Error occurred: {e}")  # debug output
        raise HTTPException(status_code=500, detail=str(e))
# ์๋ฒ ์คํ: uvicorn main:app --reload