This commit brings a slew of changes:
 - New enhanced UI with clear demarcation between KnowledgeGraphRAG and Vanilla RAG.
 - Updated README.
 - Better embedding model.
 - More papers added to the Knowledge Graph.
Ritwik Saha committed Aug 27, 2024
1 parent 0746995 commit 48a56e6
Showing 29 changed files with 1,034 additions and 388 deletions.
9 changes: 0 additions & 9 deletions .streamlit/config.toml

This file was deleted.

1 change: 1 addition & 0 deletions 0_session-install-dependencies/check-dependencies.py
@@ -1,4 +1,5 @@
import sys

from utils.check_dependency import check_unauthenticated_access_to_app_enabled

if check_unauthenticated_access_to_app_enabled() == False:
5 changes: 4 additions & 1 deletion 0_session-install-dependencies/install-dependencies.py
@@ -1,4 +1,5 @@
import subprocess

from utils.check_dependency import check_gpu_enabled

print(subprocess.run(["sh 0_session-install-dependencies/setup.sh"], shell=True))
@@ -14,4 +15,6 @@
)
)

print(
    subprocess.run(["pip install numpy==1.25.0"], shell=True)
)  # downgrade numpy to 1.25.0 as a safeguard again.
5 changes: 4 additions & 1 deletion 1_download-models/download.py
@@ -1,5 +1,8 @@
import utils.constants as const
from utils.huggingface_utils import (
    cache_and_load_embedding_model,
    quantise_and_save_local_model,
)

# This just caches the embedding model for future use
cache_and_load_embedding_model()
62 changes: 42 additions & 20 deletions 2_build-knowledge-graph/ingest-arxiv-data.py
@@ -1,21 +1,29 @@
import os

from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.graphs import Neo4jGraph
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
)
from langchain.vectorstores.neo4j_vector import Neo4jVector

import utils.constants as const
from utils.arxiv_utils import create_paper_object_from_arxiv_id
from utils.data_utils import (
    create_cypher_batch_query_to_create_citation_relationship,
    create_cypher_batch_query_to_insert_arxiv_papers,
    create_indices_queries,
    create_query_for_category_insertion,
)
from utils.huggingface_utils import cache_and_load_embedding_model
from utils.neo4j_utils import (
    get_neo4j_credentails,
    is_neo4j_server_up,
    reset_neo4j_server,
    wait_for_neo4j_server,
)

load_dotenv()

@@ -38,10 +46,10 @@
    url=get_neo4j_credentails()["uri"],
    username=get_neo4j_credentails()["username"],
    password=get_neo4j_credentails()["password"],
    index_name="category_embedding_index",
    node_label="Category",
    text_node_properties=["title", "description"],
    embedding_node_property="embedding",
)

for q in create_indices_queries():
@@ -50,7 +58,11 @@
arxiv_ids_set = set(const.seed_arxiv_paper_ids)
arxiv_ids_set.update(
    [
        cited_paper
        for cited_papers in [
            create_paper_object_from_arxiv_id(seed_paper_id).cited_arxiv_papers
            for seed_paper_id in const.seed_arxiv_paper_ids
        ]
        for cited_paper in cited_papers
    ]
)
@@ -90,11 +102,15 @@
graph.query(query)
print(f"Created citation relationships for paper {paper.arxiv_id}")

raw_docs = [
    Document(page_content=p.full_text, metadata={"arxiv_id": p.arxiv_id})
    for p in papers_to_insert
]
# Define chunking strategy
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base",
    chunk_size=1000,
    chunk_overlap=20,
    disallowed_special=(),
)
# Chunk the document
@@ -131,23 +147,29 @@
document_batch.clear()

# Link the chunks to the papers
graph.query(
    """
MATCH (p:Paper), (c:Chunk)
WHERE p.id = c.arxiv_id
MERGE (p)-[:CONTAINS_TEXT]->(c)
"""
)

# Delete orphan chunks with no papers
graph.query(
    """
MATCH (c:Chunk)
WHERE NOT (c)<-[:CONTAINS_TEXT]-()
DETACH DELETE c
"""
)

# Get the number of chunks finally present in the DB
chunk_count = graph.query(
    """
MATCH (c:Chunk)
RETURN COUNT(c) as chunk_count
"""
)[0]["chunk_count"]

print(f"Number of chunks inserted into the knowledge graph: {chunk_count}")
20 changes: 17 additions & 3 deletions 3_launch_application/streamlit_app.py
@@ -1,4 +1,5 @@
import os

import streamlit as st

from utils.neo4j_utils import is_neo4j_server_up, reset_neo4j_server
@@ -9,11 +10,24 @@

cwd = os.getcwd()

st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
pg = st.navigation(
    [
        st.Page(
            cwd + "/pgs/rag_app_page.py",
            title="Q/A for AI/ML research papers",
            icon=":material/description:",
        ),
        st.Page(
            cwd + "/pgs/model_selection_page.py",
            title="Model Selection",
            icon=":material/tv_options_edit_channels:",
        ),
        st.Page(
            cwd + "/pgs/knowledge_graph_visualisation_page.py",
            title="Knowledge Graph",
            icon=":material/hub:",
        ),
    ]
)
pg.run()
40 changes: 25 additions & 15 deletions README.md
@@ -1,4 +1,4 @@
# Knowledge Graph powered RAG AMP

This repository demonstrates how to power a RAG (Retrieval-Augmented Generation) application with a knowledge graph (supported by graph DBs such as [Neo4j](https://neo4j.com/)) to capture relationships and context that are not easily accessible when a plain vector database backs the RAG pipeline.

@@ -22,7 +22,7 @@ The AMP is designed to run on and expects [Meta-Llama-3.1-8B-Instruct](https://h

**HF_TOKEN**: The AMP relies on a [Hugging Face token](https://huggingface.co/docs/hub/en/security-tokens) to pull the [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model from Hugging Face.

***Note***: Please make sure that the account associated with the Hugging Face token has access to the above-mentioned model; gaining access requires filling out a short request form.
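
The token is supplied through the **HF_TOKEN** variable above. Purely as an illustration (this snippet is not part of the AMP code), a session could verify that the token works like this:

```python
import os

from huggingface_hub import login, whoami

# Authenticate the current session with the project's HF_TOKEN.
login(token=os.environ["HF_TOKEN"])

# Confirm the token is valid and tied to the expected account.
print(whoami()["name"])
```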

## AMP Concepts

@@ -98,7 +98,7 @@ We leverage KG in two ways in order to make this RAG system better than plain(va
1. We aim to enhance the quality of the retrieved context by choosing chunks from relatively "high-quality" papers (a reranking sketch follows this list).
2. We provide additional information about the papers used to answer a given question, something that would be much harder to surface from a traditional vector database alone.
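
As an illustration of the first point only, a reranking step could look roughly like the sketch below. Here `retrieved_chunks` (chunks carrying an `arxiv_id` in their metadata) and `citation_counts` (a per-paper citation count read from the graph) are hypothetical inputs, not the AMP's exact code:

```python
from langchain.docstore.document import Document


def rerank_chunks_by_paper_quality(
    retrieved_chunks: list[Document],
    citation_counts: dict[str, int],
    top_k: int = 5,
) -> list[Document]:
    """Prefer chunks that come from highly cited ("higher quality") papers."""
    return sorted(
        retrieved_chunks,
        key=lambda chunk: citation_counts.get(chunk.metadata.get("arxiv_id"), 0),
        reverse=True,
    )[:top_k]
```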

### Knowledge Graph RAG

Since we have a small but closely related set of AI/ML papers, there are many "citation" relationships between them. We define a paper to be of **"Higher Quality"** if it has more citations; the citation count for a specific paper can be computed from the knowledge graph we have built, as sketched below.
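
A minimal sketch of how such a citation count could be read back from the graph, assuming the same `Neo4jGraph`/credentials helpers used elsewhere in the repo and a hypothetical `CITES` relationship between `Paper` nodes (the AMP's actual relationship type may differ):

```python
from langchain.graphs import Neo4jGraph

from utils.neo4j_utils import get_neo4j_credentails

# Connect to the same Neo4j instance the AMP uses.
graph = Neo4jGraph(
    url=get_neo4j_credentails()["uri"],
    username=get_neo4j_credentails()["username"],
    password=get_neo4j_credentails()["password"],
)

# Count incoming citations per paper; ":CITES" is an assumed relationship name.
citation_counts = graph.query(
    """
    MATCH (p:Paper)
    OPTIONAL MATCH (p)<-[:CITES]-(citing:Paper)
    RETURN p.id AS arxiv_id, COUNT(citing) AS citations
    ORDER BY citations DESC
    """
)
```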

@@ -121,20 +121,30 @@ We instruct the LLM to provide us the [arXiv IDs](https://info.arxiv.org/help/ar

<span class="caption">The image shows "Top Authors" & "Related Papers" for the "Attention Is All You Need" paper.</span>

## Knowledge Graph Construction

- We start off with a few predefined (seed) AI/ML papers, as listed in [constants.py](./utils/constants.py).
- The arXiv papers cited by the seed papers are extracted by converting the paper PDFs to text and matching a regex pattern to pull out the arXiv IDs they mention (see the sketch below). This step results in ~600 unique papers in our knowledge base.
- For each of these papers, we download the PDF, extract chunks from the PDF text, calculate vector embeddings, and save them in the graph database.
- The citation information is then captured as relationships between "Paper" nodes. The "Chunk" nodes are also linked to the papers they come from.
<img src="./assets/paper_with_chunks_and_authors.png" width="50%" height="50%" />

<span class="caption">The image shows the paper #[2403.20183](https://arxiv.org/abs/2403.20183v3) with citations, authors, and the chunks (represented as <span style="color:yellow">yellow nodes</span>) as part of the Knowledge Graph.</span>
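
A rough sketch of the citation-extraction step described above. The actual PDF-to-text tooling and pattern live in the `utils/arxiv_utils.py` helpers; the regex here is an illustrative approximation, not the AMP's exact one:

```python
import re

# New-style arXiv IDs look like 2403.20183, optionally with a version suffix (v3).
ARXIV_ID_PATTERN = re.compile(r"arXiv:\s*(\d{4}\.\d{4,5})(?:v\d+)?", re.IGNORECASE)


def extract_cited_arxiv_ids(paper_text: str) -> set[str]:
    """Return the unique arXiv IDs mentioned in a paper's extracted text."""
    return set(ARXIV_ID_PATTERN.findall(paper_text))


# e.g. "... as shown in arXiv:2403.20183v3 ..." -> {"2403.20183"}
```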

## AMP Flow

1. On the first page, we can ask the application any AI/ML-related question, and it will try to answer from the application's existing knowledge base. It produces answers using both Knowledge Graph powered context retrieval and plain vector-search retrieval, and outputs:
   - The context used for both Knowledge Graph RAG and Vanilla RAG.
   - The answers synthesized by the LLM from each context.
   - Related papers and top authors for the papers used to construct the answer, in the case of Knowledge Graph RAG.
   - A graphical representation of the related papers and top authors.
   <img src="./assets/main_page_1.png" width="50%" height="50%" /> <img src="./assets/main_page_2.png" width="49%" height="49%" />

2. Although the application uses an in-session 4-bit quantised flavor of [Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (already cached in the project volume as part of the AMP steps), a remotely hosted Llama 3/3.1 model can also power the application. The second page gives the option to switch to a remote LLM by providing an OpenAI API-compatible endpoint, along with the model ID and an authorization token (a minimal sketch of calling such an endpoint follows this list). Various providers offer free-tier API usage, such as [OpenRouter](https://openrouter.ai/models/meta-llama/llama-3.1-8b-instruct:free). Alternatively, you can also use Llama 3.1 hosted on the **Cloudera AI Inference** service, which has been [tech previewed](https://blog.cloudera.com/cloudera-introduces-ai-inference-service-with-nvidia-nim/).
   ![LLM selection page](./assets/llm_selection_page.gif)

3. The third page gives a list of all papers contained in the knowledge base. We can select any of these papers and graphically visualize the first- and second-order "cited by" relationships from other papers.
   ![Knowledge Graph page](./assets/knowledge_graph_page_screenshot.png)
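
For illustration, any OpenAI API-compatible endpoint can be exercised roughly as below. The endpoint URL, model ID, and key are placeholders, and the AMP's own client wiring may differ:

```python
from openai import OpenAI

# Placeholders: use the endpoint, model ID, and key entered on the Model Selection page.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key="YOUR_API_KEY",
)

response = client.chat.completions.create(
    model="meta-llama/llama-3.1-8b-instruct:free",
    messages=[{"role": "user", "content": "What is Retrieval Augmented Generation?"}],
)
print(response.choices[0].message.content)
```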

## AMP Requirements

@@ -147,4 +157,4 @@ We instruct the LLM to provide us the [arXiv IDs](https://info.arxiv.org/help/ar
- Tested with Nvidia Tesla T4 GPU (AWS: [g4dn series](https://aws.amazon.com/ec2/instance-types/g4/), Azure: [Standard_NC4as_T4_v3](https://learn.microsoft.com/en-us/azure/virtual-machines/nct4-v3-series))

### CML Runtime
- PBJ-Workbench - Python3.10 - Nvidia GPU - 2023.05
Binary file added assets/AMP_thumbnail.png
Binary file added assets/knowledge_graph_page_screenshot.png
Binary file added assets/main_page_1.png
Binary file added assets/main_page_2.png
Binary file added assets/paper_with_chunks_and_authors.png
Binary file removed assets/rag_page_gif.gif
14 changes: 10 additions & 4 deletions pgs/commons.py
@@ -1,13 +1,14 @@
import gc
import time
from enum import Enum

import streamlit as st
import torch
from langchain_core.language_models.llms import BaseLLM

import utils.constants as const
from utils.huggingface_utils import cache_and_load_embedding_model, load_local_model


@st.cache_resource(show_spinner=False)
def get_cached_local_model() -> BaseLLM:
@@ -16,13 +17,16 @@ def get_cached_local_model() -> BaseLLM:
    progress_bar.progress(30, "Emptying CUDA cache.")
    with torch.no_grad():
        torch.cuda.empty_cache()
    progress_bar.progress(
        60, f"Loading local {const.local_model_to_be_quantised} model."
    )
    local_llm = load_local_model()
    progress_bar.progress(99, "Model loaded successfully.")
    time.sleep(1.0)
    progress_bar.empty()
    return local_llm


@st.cache_resource(show_spinner=False)
def get_cached_embedding_model():
    progress_bar = st.progress(20, f"Loading {const.embed_model_name} embedding model.")
@@ -31,11 +35,14 @@ def get_cached_embedding_model():
    progress_bar.empty()
    return embedding


class StateVariables(Enum):
    REMOTE_MODEL_ENDPOINT = "remote_model_endpoint"
    REMOTE_MODEL_ID = "remote_model_id"
    REMOTE_MODEL_API_KEY = "remote_model_api_key"
    IS_REMOTE_LLM = "is_remote_llm"
    QUESTION_FROM_DROPDOWN = "question_from_dropdown"


example_questions = [
"What is the difference between GPT-3 and GPT-4?",
@@ -60,4 +67,3 @@ class StateVariables(Enum):
**Please click on the nodes to redirect to the arXiv link of the paper.**
"""
response_container_height = 750
