-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhybrid_rag.py
77 lines (62 loc) · 2.38 KB
/
hybrid_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import cohere
from dotenv import load_dotenv
from ell import ell
from openai import OpenAI
import prompts
from graph_rag import GraphRAG
from vector_rag import VectorRAG
# Run python-dotenv's loader first so that any variables it adds to the
# environment are visible when the API keys are read below.
load_dotenv()

# Chat model name handed to the ell prompt decorator in HybridRAG.
MODEL_NAME = "gpt-4o-mini"
# Passed as `seed=` to the same decorator.
SEED = 42

# API keys read from the environment; each is None when its variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
class HybridRAG:
    """Answer a question from graph and vector retrieval results combined.

    ``run`` gathers candidate documents from a ``GraphRAG`` and a ``VectorRAG``
    instance, reranks them with Cohere, and passes the reranked text to the
    ``hybrid_rag`` prompt as context.
    """

    def __init__(
        self,
        graph_db_path="./test_kuzudb",
        vector_db_path="./test_lancedb",
    ):
        """Create the two retrievers and the Cohere client.

        Args:
            graph_db_path: Path passed to ``GraphRAG``.
            vector_db_path: Path passed to ``VectorRAG``.
        """
        self.graph_rag = GraphRAG(graph_db_path)
        self.vector_rag = VectorRAG(vector_db_path)
        self.co = cohere.ClientV2(COHERE_API_KEY)

    # Prompt function: returns the system and user messages built from the
    # templates in `prompts`; the ell decorator is what issues the model call.
    # NOTE: documented with comments on purpose. ell can use a prompt
    # function's docstring as a system prompt, so none is added here.
    @ell.simple(
        model=MODEL_NAME,
        temperature=0.3,
        client=OpenAI(api_key=OPENAI_API_KEY),
        seed=SEED,
    )
    def hybrid_rag(self, question: str, context: str) -> str:
        return [
            ell.system(prompts.RAG_SYSTEM_PROMPT),
            ell.user(prompts.RAG_USER_PROMPT.format(question=question, context=context)),
        ]

    def run(self, question: str, top_n: int = 20) -> str:
        """Retrieve, rerank and answer.

        Args:
            question: The natural-language question to answer.
            top_n: Maximum number of reranked documents to keep as context.
                Defaults to 20, the value that was previously hard-coded.

        Returns:
            Whatever the ``hybrid_rag`` prompt returns for the question and the
            reranked context.
        """
        # Vector side: embed the question, then keep the "text" field of each hit.
        question_embedding = self.vector_rag.embed(question)
        vector_hits = self.vector_rag.query(question_embedding)
        vector_docs = [hit["text"] for hit in vector_hits]

        # Graph side: generate a Cypher query for the question and run it.
        cypher = self.graph_rag.generate_cypher(question)
        graph_docs = self.graph_rag.query(question, cypher)

        # The graph result is treated as a single candidate document, placed
        # ahead of the vector hits. Everything is stringified for the reranker.
        docs = [str(doc) for doc in [graph_docs, *vector_docs]]

        reranked = self.co.rerank(
            model="rerank-english-v3.0",
            query=question,
            documents=docs,
            top_n=top_n,
            return_documents=True,
        )

        # Give the prompt the reranked *text*, not the response object: formatting
        # the raw response into the prompt would embed its repr (ids, scores,
        # metadata). Each result's `index` points back into `docs`, and the
        # results are iterated in the order the reranker returned them.
        combined_context = "\n\n".join(docs[result.index] for result in reranked.results)
        return self.hybrid_rag(question, combined_context)
if __name__ == "__main__":
    # Demo run: ask a fixed list of questions and print each answer.
    hybrid_rag = HybridRAG(
        graph_db_path="./test_kuzudb",
        vector_db_path="./test_lancedb",
    )

    questions = (
        "Who are the founders of BlackRock? Return the names as a numbered list.",
        "Where did Larry Fink graduate from?",
        "When was Susan Wagner born?",
        "How did Larry Fink and Rob Kapito meet?",
    )

    for number, question in enumerate(questions, start=1):
        response = hybrid_rag.run(question)
        # Every entry after the first is preceded by a "---" separator line.
        separator = "" if number == 1 else "---\n"
        print(f"{separator}Q{number}: {question}\n\n{response}")