-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathllmrag1.py
64 lines (52 loc) · 1.93 KB
/
llmrag1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pandas as pd
import faiss
import numpy as np
import openai
from openai import OpenAI # for calling the OpenAI API
import os
# --- Configuration and data loading -------------------------------------
# Load clinical records from CSV; a 'text' column is expected downstream
# (used to build embeddings). NOTE(review): hard-coded local path — consider
# making this configurable.
data = pd.read_csv('/Users/bolttoday/Downloads/cdata1.csv')

# Read the API key once and reuse it (the original read the env var twice).
apikey = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=apikey or "<your OpenAI API key if not set as env var>")
# Legacy module-level configuration, kept for any old-style call sites.
openai.api_key = apikey
# Function to generate embeddings using OpenAI
def generate_embeddings(texts):
    """Return an (n, d) float32 numpy array of embeddings for *texts*.

    Calls the embeddings endpoint through the module-level ``client``:
    the legacy ``openai.Embedding.create`` used previously was removed in
    openai>=1.0 and raises at runtime. The result is cast to float32
    because FAISS indexes reject float64 input.
    """
    response = client.embeddings.create(
        input=texts,
        model="text-embedding-ada-002",  # example embedding model
    )
    # One embedding per input text, in input order.
    embeddings = [item.embedding for item in response.data]
    return np.array(embeddings, dtype=np.float32)
# --- Build the FAISS index over the clinical texts ----------------------
# Embed every row's 'text' column.
texts = data['text'].tolist()
embeddings = generate_embeddings(texts)

# FAISS requires C-contiguous float32 input; cast defensively in case the
# embedding helper returns float64 (np.array of Python floats does).
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

# Exact L2 (Euclidean) nearest-neighbour index over the full corpus.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# Function to perform retrieval
def retrieve(query, k=3):
    """Return the *k* clinical records most similar to *query*.

    Embeds the query, searches the module-level FAISS ``index``, and
    returns the matching rows of ``data`` as a list of dicts.
    """
    query_embedding = generate_embeddings([query])
    # FAISS search requires C-contiguous float32; float64 raises a TypeError.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
    distances, indices = index.search(query_embedding, k)  # (1, k) each
    return data.iloc[indices[0]].to_dict(orient='records')
# Function to generate a response using OpenAI GPT
def generate_response(retrieved_texts, query):
    """Answer *query* grounded in the retrieved records.

    ``retrieved_texts`` is a list of dicts, each carrying a 'text' key
    (as produced by ``retrieve``). Returns the model's answer, stripped.
    """
    context = "\n".join([f"Context {i+1}: {text['text']}" for i, text in enumerate(retrieved_texts)])
    prompt = f"{context}\n\nQuestion: {query}\nAnswer:"
    # The legacy Completions API and the text-davinci-003 model were
    # retired; use the Chat Completions endpoint via the module-level
    # client instead.
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=150,
    )
    return response.choices[0].message.content.strip()
# Main function to handle the RAG process
def rag_model(query):
    """Run the full retrieve-then-generate pipeline for *query*."""
    # Fetch supporting records, then have the LLM answer from them.
    return generate_response(retrieve(query), query)
# Example usage: ask one question end-to-end.
# NOTE(review): runs at import time and makes live OpenAI API calls;
# consider wrapping in an `if __name__ == "__main__":` guard.
query = "What should be monitored in a patient with hypertension?"
response = rag_model(query)
print("Response:", response)