-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathqna.py
90 lines (70 loc) · 2.61 KB
/
qna.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import time
import os
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.pgvector import PGVector
from langchain.chains import RetrievalQA
from langchain import OpenAI
def process_data():
    """Load the local text corpus and split it into chunks for embedding.

    Every file matching ``**/*.txt`` under the ``documents`` directory is
    loaded, then split with a ``CharacterTextSplitter`` using a chunk size of
    1000 characters and no overlap.

    Returns:
        The list of split documents produced by
        ``CharacterTextSplitter.split_documents``.
    """
    print("Getting data...")
    loader = DirectoryLoader("documents", glob="**/*.txt")
    documents = loader.load()
    print("Documents loaded.")

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    # Report the split only after it has actually happened, so the progress
    # log never claims a step that has not run (or that failed).
    print("Text splitted.")
    # Dumps every chunk to stdout; handy for inspection on small corpora.
    print(texts)
    return texts
def _pgvector_connection_string():
    """Build the PGVector connection string from ``PGVECTOR_*`` environment variables.

    Reads ``PGVECTOR_DRIVER``, ``PGVECTOR_HOST``, ``PGVECTOR_PORT``,
    ``PGVECTOR_DATABASE``, ``PGVECTOR_USER`` and ``PGVECTOR_PASSWORD``.

    Raises:
        ValueError: If any of the variables is unset, or if ``PGVECTOR_PORT``
            is not an integer.
    """
    suffixes = ("DRIVER", "HOST", "PORT", "DATABASE", "USER", "PASSWORD")
    settings = {
        suffix: os.environ.get("PGVECTOR_" + suffix) for suffix in suffixes
    }

    # Fail early with the full list of missing names instead of an opaque
    # `int(None)` TypeError or a connection string containing None values.
    missing = [
        "PGVECTOR_" + suffix
        for suffix, value in settings.items()
        if value is None
    ]
    if missing:
        raise ValueError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    try:
        port = int(settings["PORT"])
    except ValueError as err:
        raise ValueError(
            "PGVECTOR_PORT must be an integer, got {!r}".format(settings["PORT"])
        ) from err

    return PGVector.connection_string_from_db_params(
        driver=settings["DRIVER"],
        host=settings["HOST"],
        port=port,
        database=settings["DATABASE"],
        user=settings["USER"],
        password=settings["PASSWORD"],
    )


def create_retriever(initialize=False):
    """Create a ``RetrievalQA`` chain backed by a PGVector store.

    The database connection is configured through the ``PGVECTOR_*``
    environment variables, and the collection used is ``supabase_test``.

    Args:
        initialize: When true, load and split the local documents (via
            ``process_data``) and create the vector store from them. When
            false, connect to the existing collection without inserting
            documents.

    Returns:
        A ``RetrievalQA`` chain using the ``"stuff"`` chain type, an
        ``OpenAI`` LLM and the vector store's retriever.

    Raises:
        ValueError: If a required ``PGVECTOR_*`` variable is unset or
            ``PGVECTOR_PORT`` is not an integer.
    """
    collection_name = "supabase_test"

    # Supabase PGVector store. The configuration is resolved first so a bad
    # environment fails before the embedding model is loaded.
    connection_string = _pgvector_connection_string()

    print("Getting base embeddings from HuggingFace...")
    embeddings = HuggingFaceEmbeddings()
    print("Embeddings loaded.")

    if initialize:
        texts = process_data()
        print("Creating vector store...")
        vector_store = PGVector.from_documents(
            documents=texts,
            embedding=embeddings,
            collection_name=collection_name,
            connection_string=connection_string,
        )
        print("Vector store created.")
    else:
        print("Fetching vector store...")
        vector_store = PGVector(
            connection_string=connection_string,
            collection_name=collection_name,
            embedding_function=embeddings,
        )

    qna_retriever = RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )
    return qna_retriever
def query(prompt, qna_retriever):
    """Run ``prompt`` through ``qna_retriever`` and print the answer.

    Args:
        prompt: The question text passed to ``qna_retriever.run``.
        qna_retriever: An object exposing a ``run(prompt)`` method, such as
            the chain returned by ``create_retriever``.
    """
    answer = qna_retriever.run(prompt)
    print("Answer: {}".format(answer))
if __name__ == '__main__':
    # Set to `True` on the first run, when the Supabase PGVector collection
    # has NOT been populated yet: that makes create_retriever() load, split
    # and embed the local documents. Leave it `False` once the store exists.
    INITIALIZE = False

    # Load the PGVECTOR_* settings (and any other variables) from a .env file.
    load_dotenv()
    qna_retriever = create_retriever(initialize=INITIALIZE)

    try:
        while True:
            prompt = input("Prompt: ")
            # An empty prompt ends the session.
            if prompt == "":
                break
            query(prompt, qna_retriever)
            # Anything other than a bare Enter also ends the session.
            cont = input("Press 'Enter' to prompt again.\n")
            if cont != "":
                break
    except EOFError:
        # stdin was closed (Ctrl-D or exhausted piped input): treat it like
        # a request to quit instead of crashing with a traceback.
        print()