adding provenance factor to KG-RAG
karthiksoman committed Dec 4, 2023
1 parent b15479c commit 8570673
Showing 3 changed files with 113 additions and 30 deletions.
11 changes: 10 additions & 1 deletion config.yaml
@@ -41,5 +41,14 @@ MCQ_PATH : '/data/somank/KG_RAG/data/test_questions_two_hop_mcq_from_monarch_and
TRUE_FALSE_PATH : '/data/somank/KG_RAG/data/test_questions_one_hop_true_false_v2.csv'
ONE_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/one_hop_graph_traversal_questions_v2.csv'
TWO_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/two_hop_graph_traversal_questions.csv'
#

# SPOKE-API params
BASE_URI : 'https://spoke.rbvi.ucsf.edu'
cutoff_Compound_max_phase : 3
cutoff_Protein_source : ['SwissProt']
cutoff_DaG_diseases_sources : ['knowledge', 'experiments']
cutoff_DaG_textmining : 3
cutoff_CtD_phase : 3
cutoff_PiP_confidence : 0.7
cutoff_ACTeG_level : ['Low', 'Medium', 'High']
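
For context, a minimal sketch (assuming PyYAML) of how these values could end up in the `config_data` dictionary that `kg_rag/utility.py` references; the repository's actual loader may differ:

```python
import yaml

# Read the repo's config.yaml into the config_data dict used below.
with open("config.yaml", "r") as f:
    config_data = yaml.safe_load(f)

# The SPOKE cutoffs then arrive as plain Python values:
print(config_data["cutoff_Compound_max_phase"])   # 3
print(config_data["cutoff_Protein_source"])       # ['SwissProt']
```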

125 changes: 98 additions & 27 deletions kg_rag/utility.py
@@ -37,6 +37,90 @@
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_spoke_api_resp(base_uri, end_point, params=None):
uri = base_uri + end_point
if params:
return requests.get(uri, params=params)
else:
return requests.get(uri)

@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
def get_context_using_spoke_api(node_value):
type_end_point = "/api/v1/types"
result = get_spoke_api_resp(config_data['BASE_URI'], type_end_point)
data_spoke_types = result.json()
node_types = list(data_spoke_types["nodes"].keys())
edge_types = list(data_spoke_types["edges"].keys())
node_types_to_remove = ["DatabaseTimestamp", "Version"]
filtered_node_types = [node_type for node_type in node_types if node_type not in node_types_to_remove]
api_params = {
'node_filters' : filtered_node_types,
'edge_filters': edge_types,
'cutoff_Compound_max_phase': config_data['cutoff_Compound_max_phase'],
'cutoff_Protein_source': config_data['cutoff_Protein_source'],
'cutoff_DaG_diseases_sources': config_data['cutoff_DaG_diseases_sources'],
'cutoff_DaG_textmining': config_data['cutoff_DaG_textmining'],
'cutoff_CtD_phase': config_data['cutoff_CtD_phase'],
'cutoff_PiP_confidence': config_data['cutoff_PiP_confidence'],
'cutoff_ACTeG_level': config_data['cutoff_ACTeG_level']
}
node_type = "Disease"
attribute = "name"
nbr_end_point = "/api/v1/neighborhood/{}/{}/{}".format(node_type, attribute, node_value)
result = get_spoke_api_resp(config_data['BASE_URI'], nbr_end_point, params=api_params)
node_context = result.json()
nbr_nodes = []
nbr_edges = []
for item in node_context:
if "_" not in item["data"]["neo4j_type"]:
try:
if item["data"]["neo4j_type"] == "Protein":
nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["description"]))
else:
nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["name"]))
except:
nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["identifier"]))
elif "_" in item["data"]["neo4j_type"]:
try:
provenance = ", ".join(item["data"]["properties"]["sources"])
except:
try:
provenance = item["data"]["properties"]["source"]
if isinstance(provenance, list):
provenance = ", ".join(provenance)
except:
try:
preprint_list = ast.literal_eval(item["data"]["properties"]["preprint_list"])
if len(preprint_list) > 0:
provenance = ", ".join(preprint_list)
else:
pmid_list = ast.literal_eval(item["data"]["properties"]["pmid_list"])
                            pmid_list = ["pubmedId:" + pmid for pmid in pmid_list]  # build a list (not a map object) so len() works
                            if len(pmid_list) > 0:
                                provenance = ", ".join(pmid_list)
else:
provenance = "Based on data from Institute For Systems Biology (ISB)"
except:
provenance = "SPOKE-KG"
nbr_edges.append((item["data"]["source"], item["data"]["neo4j_type"], item["data"]["target"], provenance))
nbr_nodes_df = pd.DataFrame(nbr_nodes, columns=["node_type", "node_id", "node_name"])
nbr_edges_df = pd.DataFrame(nbr_edges, columns=["source", "edge_type", "target", "provenance"])
merge_1 = pd.merge(nbr_edges_df, nbr_nodes_df, left_on="source", right_on="node_id").drop("node_id", axis=1)
merge_1.loc[:,"node_name"] = merge_1.node_type + " " + merge_1.node_name
merge_1.drop(["source", "node_type"], axis=1, inplace=True)
merge_1 = merge_1.rename(columns={"node_name":"source"})
merge_2 = pd.merge(merge_1, nbr_nodes_df, left_on="target", right_on="node_id").drop("node_id", axis=1)
merge_2.loc[:,"node_name"] = merge_2.node_type + " " + merge_2.node_name
merge_2.drop(["target", "node_type"], axis=1, inplace=True)
merge_2 = merge_2.rename(columns={"node_name":"target"})
merge_2 = merge_2[["source", "edge_type", "target", "provenance"]]
merge_2.loc[:, "predicate"] = merge_2.edge_type.apply(lambda x:x.split("_")[0])
merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is from " + merge_2.provenance + "."
context = merge_2['context'].str.cat(sep=' ')
return context
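
A hypothetical call; "multiple sclerosis" is a placeholder for any Disease node name present in SPOKE, and the sample output below is illustrative of the sentence template assembled above:

```python
context = get_context_using_spoke_api("multiple sclerosis")

# Each sentence follows the template built from merge_2, e.g.:
# "Disease multiple sclerosis associates Gene GENE1 and Provenance of
#  this association is from pubmedId:12345678."
print(context[:300])
```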



def get_prompt(instruction, new_system_prompt):
system_prompt = B_SYS + new_system_prompt + E_SYS
prompt_template = B_INST + system_prompt + instruction + E_INST
@@ -78,28 +78,6 @@ def llama_model(model_name, branch_name, cache_dir, temperature=0, top_p=1, max_



def create_mcq(df, source_column, target_column, node_type, predicate):
disease_pairs = df[source_column].unique()
disease_pairs = [(disease1, disease2) for disease1 in disease_pairs for disease2 in disease_pairs if disease1 != disease2]

new_data = []

#For each source pair, find a common target and 4 negative samples
for disease1, disease2 in disease_pairs:
common_gene = set(df[df[source_column] == disease1][target_column]).intersection(set(df[df[source_column] == disease2][target_column]))
common_gene = list(common_gene)[0] if common_gene else None
# Get 4 random negative samples
negative_samples = df[(df[source_column] != disease1) & (df[source_column] != disease2)][target_column].sample(4).tolist()
new_data.append(((disease1, disease2), common_gene, negative_samples))

new_df = pd.DataFrame(new_data, columns=["disease_pair", "correct_node", "negative_samples"])
new_df.dropna(subset = ["correct_node"], inplace=True)
new_df.loc[:, "disease_1"] = new_df["disease_pair"].apply(lambda x: x[0])
new_df.loc[:, "disease_2"] = new_df["disease_pair"].apply(lambda x: x[1])
new_df.negative_samples = new_df.negative_samples.apply(lambda x:", ".join(x[0:4]))
new_df.loc[:, "text"] = "Out of the given list, which " + node_type + " " + predicate + " " + new_df.disease_1 + " and " + new_df.disease_2 + ". Given list is: " + new_df.correct_node + ", " + new_df.negative_samples
return new_df

@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
print('Calling OpenAI...')
@@ -167,7 +229,7 @@ def load_chroma(vector_db_path, sentence_embedding_model):
embedding_function = load_sentence_transformer(sentence_embedding_model)
return Chroma(persist_directory=vector_db_path, embedding_function=embedding_function)

def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold):
def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold, api=True):
entities = disease_entity_extractor_v2(question)
node_hits = []
if entities:
@@ -178,7 +240,10 @@ def retrieve_context(question, vectorstore, embedding_function, node_context_df,
question_embedding = embedding_function.embed_query(question)
node_context_extracted = ""
for node_name in node_hits:
node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
if not api:
node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
else:
node_context = get_context_using_spoke_api(node_name)
node_context_list = node_context.split(". ")
node_context_embeddings = embedding_function.embed_documents(node_context_list)
similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings]
@@ -198,7 +263,10 @@ def retrieve_context(question, vectorstore, embedding_function, node_context_df,
node_context_extracted = ""
for node in node_hits:
node_name = node[0].page_content
node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
if not api:
node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
else:
node_context = get_context_using_spoke_api(node_name)
node_context_list = node_context.split(". ")
node_context_embeddings = embedding_function.embed_documents(node_context_list)
similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings]
@@ -213,7 +281,7 @@ def retrieve_context(question, vectorstore, embedding_function, node_context_df,
return node_context_extracted
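
With the new flag, callers choose between the precomputed context table and a live SPOKE lookup. A hypothetical invocation (the numeric values are placeholders, not the repo's configured defaults):

```python
context = retrieve_context(
    "What genes are associated with multiple sclerosis?",
    vectorstore,
    embedding_function,
    node_context_df,              # consulted only when api=False
    context_volume=150,
    context_sim_threshold=0.5,
    context_sim_min_threshold=0.25,
    api=True,                     # fetch context live from the SPOKE API
)
```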


def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type):
def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, api=True):
input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo")
print("Processing ...")
entities = disease_entity_extractor_v2(question)
@@ -233,7 +301,10 @@ def interactive(question, vectorstore, node_context_df, embedding_function_for_c
input("Press enter for Step 3 - Context extraction from SPOKE")
node_context = []
for node_name in node_hits:
node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0])
if not api:
node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0])
else:
node_context.append(get_context_using_spoke_api(node_name))
print("Extracted Context is : ")
print(". ".join(node_context))
print(" ")
7 changes: 5 additions & 2 deletions system_prompts.yaml
@@ -29,8 +29,11 @@ PROMPT_BASED_TEXT_GENERATION: |
# KG-RAG Based Text Generation
KG_RAG_BASED_TEXT_GENERATION: |
You are an expert biomedical researcher. For answering the Question at the end, you need to first read the Context provided.
Then give your final answer by considering the context.
You are an expert biomedical researcher. For answering the Question at the end, you need to first read the Context provided. Then give your final answer with the Provenance information by considering the context. Refer to Example 1 and Example 2 to see how to cite the Provenance.
Example 1:
The genes associated with DiseaseX are Gene1, Gene2 [Provenance: <cite the Provenance>].
Example 2:
The genes associated with DiseaseX are Gene1, Gene2 [Provenance: <cite the Provenance>]. The disease localizes in organ1, organ2, organ3 [Provenance: <cite the Provenance>].
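
A minimal sketch of feeding the updated system prompt into the template built by `get_prompt()` from `kg_rag/utility.py`; the `Context:`/`Question:` instruction format shown here is an assumption, and the repo's actual wiring may differ:

```python
import yaml

with open("system_prompts.yaml", "r") as f:
    system_prompts = yaml.safe_load(f)

# Compose the final prompt from retrieved context and the user question.
instruction = "Context: " + context + "\nQuestion: " + question
prompt = get_prompt(instruction, system_prompts["KG_RAG_BASED_TEXT_GENERATION"])
```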
# True/False Question
TRUE_FALSE_QUESTION: |
