Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAG-GUI #90

Merged
merged 48 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
4ae503c
Create app.py
Kunal-1669 Mar 25, 2024
47b6e44
Merge branch 'project-BasicRAG' into project-RAG-GUI
arjbingly Mar 25, 2024
32c377b
Merge branch 'project-RAG-GUI' of https://github.com/arjbingly/Capsto…
arjbingly Mar 25, 2024
a5b01cc
linting and formatting cookbooks
arjbingly Mar 25, 2024
f4c61a9
Revert "linting and formatting cookbooks"
arjbingly Mar 25, 2024
127dfaf
Update app.py
Kunal-1669 Mar 26, 2024
56913bd
Merge branch 'project-RAG-GUI' of https://github.com/arjbingly/Capsto…
Kunal-1669 Mar 26, 2024
e727f71
Update app.py
Kunal-1669 Mar 30, 2024
7003b7b
Create ~$Capstone5 Presentation v1.pptx
Kunal-1669 Mar 30, 2024
e9d7a5a
Merge remote-tracking branch 'origin/main' into project-RAG-GUI
Kunal-1669 Apr 2, 2024
0d64abe
Merge branch 'project-RAG-GUI' of https://github.com/arjbingly/Capsto…
arjbingly Apr 5, 2024
a336195
Merge remote-tracking branch 'origin/main' into project-RAG-GUI
arjbingly Apr 5, 2024
64ebf68
Add parser dependencies
arjbingly Apr 5, 2024
f8a8c52
Config.ini changes
arjbingly Apr 5, 2024
a9a9551
Ingest Cookbook async and docs
arjbingly Apr 5, 2024
d58d1e5
Async bug retriever
arjbingly Apr 5, 2024
b7117f1
Relocate RAG-GUI to cookbooks
arjbingly Apr 5, 2024
9f7e178
Update app.py
Kunal-1669 Apr 9, 2024
d27eba4
Merge remote-tracking branch 'origin/project-RAG-GUI' into project-RA…
Kunal-1669 Apr 9, 2024
a508768
Basic RAG read_only defaults
arjbingly Apr 9, 2024
9424f45
Update app.py
Kunal-1669 Apr 9, 2024
bee53d7
Update app.py
Kunal-1669 Apr 10, 2024
84d6211
Merge remote-tracking branch 'origin/project-RAG-GUI' into project-RA…
Kunal-1669 Apr 11, 2024
9bcc8af
Update app
Kunal-1669 Apr 15, 2024
3e3a3cc
Update app to show sources
arjbingly Apr 15, 2024
c82d516
Update app
arjbingly Apr 15, 2024
352f849
Update App
arjbingly Apr 15, 2024
b21f8ad
changed UI, added content retrieval
sanchitvj Apr 15, 2024
a744b23
Update app.py
Kunal-1669 Apr 16, 2024
fb27ea7
Update app.py
Kunal-1669 Apr 16, 2024
789d253
Update app.py
Kunal-1669 Apr 16, 2024
8d76ae0
added streaming for LLMs
sanchitvj Apr 17, 2024
ff66172
Update app
arjbingly Apr 17, 2024
df78856
Add spinner decorator
arjbingly Apr 17, 2024
57eda19
Update app - Add chat style UI
arjbingly Apr 17, 2024
39a489a
Update app: cosmetic
arjbingly Apr 17, 2024
054b372
loading retriever in Class
sanchitvj Apr 17, 2024
d6470dd
added gemma
sanchitvj Apr 17, 2024
0e2c17f
Update App: show sources
arjbingly Apr 17, 2024
f114115
Merge branch 'project-RAG-GUI' of https://github.com/arjbingly/Capsto…
arjbingly Apr 17, 2024
b5d56eb
ruff check and formatted
sanchitvj Apr 17, 2024
55a9ce7
Create branch_Jenkinsfile
sanchitvj Apr 17, 2024
944ede5
Update branch_Jenkinsfile
sanchitvj Apr 18, 2024
a0c88f4
modifying tests
sanchitvj Apr 18, 2024
1309463
fixed failing tests.
sanchitvj Apr 18, 2024
376575d
updated data
sanchitvj Apr 18, 2024
85a9ef6
fixed typing
sanchitvj Apr 18, 2024
1593f0d
Merge branch 'main' into project-RAG-GUI
sanchitvj Apr 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/branch_Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pipeline {
withPythonEnv(PYTHONPATH){
sh 'pip install ruff'
catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE'){
sh 'ruff check . --exclude .pyenv-var-lib-jenkins-workspace-capstone_5-.venv-bin --output-format junit -o ruff-report.xml'
sh 'ruff check . --exclude *venv* --output-format junit -o ruff-report.xml'
sh 'ruff format .'
}
}
Expand Down
178 changes: 178 additions & 0 deletions cookbook/RAG-GUI/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""A cookbook demonstrating how to run RAG app on streamlit."""

import os
import sys
from functools import wraps
from pathlib import Path

import streamlit as st
from grag.components.multivec_retriever import Retriever
from grag.components.utils import get_config
from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG

# NOTE(review): this path tweak executes after the grag imports above, so it
# cannot have influenced them — confirm whether anything else still needs it.
sys.path.insert(1, str(Path(os.getcwd()).parents[1]))

# Configure the browser tab title and the entries in the app's hamburger menu.
st.set_page_config(page_title="GRAG",
                   menu_items={
                       "Get Help": "https://github.com/arjbingly/Capstone_5",
                       "About": "This is a simple GUI for GRAG"
                   })


def spinner(text):
    """Decorator factory that shows a Streamlit loading spinner around a call.

    While the wrapped function executes, ``st.spinner`` displays *text* so the
    user knows an operation is in progress.

    Args:
        text (str): The message to display next to the spinner.

    Returns:
        function: A decorator that wraps a function in the spinner context.
    """

    def _spinner(func):
        """Wrap ``func`` so the spinner is shown for the duration of its call.

        Args:
            func (function): The function to wrap.

        Returns:
            function: The wrapped function, with the spinner displayed during
            its execution and the original metadata preserved.
        """

        @wraps(func)  # keep func's __name__/__doc__ for debugging and Streamlit
        def wrapper_func(*args, **kwargs):
            """Execute the wrapped function inside the spinner context.

            Args:
                *args: Positional arguments forwarded to the wrapped function.
                **kwargs: Keyword arguments forwarded to the wrapped function.

            Returns:
                Whatever ``func`` returns. (The original implementation
                silently discarded the return value.)
            """
            with st.spinner(text=text):
                return func(*args, **kwargs)

        return wrapper_func

    return _spinner


@st.cache_data
def load_config():
    """Read the project configuration, cached across Streamlit reruns."""
    config = get_config()
    return config


conf = load_config()


class RAGApp:
    """Streamlit application class for a Retrieval-Augmented Generation (RAG) GUI.

    Attributes:
        app: The main application or server instance hosting the RAG model
            (in practice, the ``streamlit`` module).
        conf: Configuration settings or parameters for the application.
    """

    def __init__(self, app, conf):
        """Initializes the RAGApp with a given application and configuration.

        Args:
            app: The main application or framework instance that this class will interact with.
            conf: A configuration object or dictionary containing settings for the application.
        """
        self.app = app
        self.conf = conf  # NOTE(review): stored but not read by any method — confirm if needed

    def render_sidebar(self):
        """Renders the sidebar with model selection and generation parameters.

        Each widget registers its value in ``st.session_state`` under the given
        key; ``load_rag`` reads those keys when the "Load Model" button fires.
        """
        with st.sidebar:
            st.title('GRAG')
            st.subheader('Models and parameters')
            st.sidebar.selectbox('Choose a model',
                                 ['Llama-2-13b-chat', 'Llama-2-7b-chat',
                                  'Mixtral-8x7B-Instruct-v0.1', 'gemma-7b-it'],
                                 key='selected_model')
            st.sidebar.slider('Temperature',
                              min_value=0.1,
                              max_value=1.0,
                              value=0.1,
                              step=0.1,
                              key='temperature')
            st.sidebar.slider('Top-k',
                              min_value=1,
                              max_value=5,
                              value=3,
                              step=1,
                              key='top_k')
            st.button('Load Model', on_click=self.load_rag)
            st.checkbox('Show sources', key='show_sources')

    @spinner(text='Loading model...')
    def load_rag(self):
        """Loads the RAG model selected in the sidebar into ``st.session_state['rag']``.

        Any previously loaded model is dropped first so new settings take
        effect. Model-specific GPU-layer and quantization overrides are applied
        for the Mixtral and Gemma models.
        """
        # Drop the old model so changed settings are picked up on reload.
        if 'rag' in st.session_state:
            del st.session_state['rag']

        llm_kwargs = {"temperature": st.session_state['temperature'], }
        if st.session_state['selected_model'] == "Mixtral-8x7B-Instruct-v0.1":
            llm_kwargs['n_gpu_layers'] = 16
            llm_kwargs['quantization'] = 'Q4_K_M'
        elif st.session_state['selected_model'] == "gemma-7b-it":
            llm_kwargs['n_gpu_layers'] = 18
            llm_kwargs['quantization'] = 'f16'

        retriever_kwargs = {
            "client_kwargs": {"read_only": True, },
            "top_k": st.session_state['top_k']
        }
        # read_only avoids taking a write lock on the vector store.
        client = DeepLakeClient(collection_name="usc", read_only=True)
        retriever = Retriever(vectordb=client)

        st.session_state['rag'] = BasicRAG(model_name=st.session_state['selected_model'], stream=True,
                                           llm_kwargs=llm_kwargs, retriever=retriever,
                                           retriever_kwargs=retriever_kwargs)
        st.success(
            f"""Model Loaded !!!

Model Name: {st.session_state['selected_model']}
Temperature: {st.session_state['temperature']}
Top-k : {st.session_state['top_k']}"""
        )

    def clear_cache(self):
        """Clears the cached data within the application."""
        st.cache_data.clear()

    def render_main(self):
        """Renders the main chat interface for interaction with the loaded RAG model."""
        st.title(":us: US Constitution Expert! :mortar_board:")
        if 'rag' not in st.session_state:
            st.warning("You have not loaded any model")
        else:
            user_input = st.chat_input("Ask me anything about the US Constitution.")

            if user_input:
                with st.chat_message("user"):
                    st.write(user_input)
                with st.chat_message("assistant"):
                    _ = st.write_stream(
                        st.session_state['rag'](user_input)[0]
                    )
                    # .get guards against the checkbox key not being registered
                    # yet: render() draws the main pane before the sidebar, so a
                    # plain index would raise KeyError on the first script run.
                    if st.session_state.get('show_sources', False):
                        retrieved_docs = st.session_state['rag'].retriever.get_chunk(user_input)
                        for index, doc in enumerate(retrieved_docs):
                            with st.expander(f"Source {index + 1}"):
                                st.markdown(f"**{index + 1}. {doc.metadata['source']}**")
                                # if st.session_state['show_content']:
                                st.text(f"**{doc.page_content}**")

    def render(self):
        """Orchestrates the rendering of both main and sidebar components of the application."""
        self.render_main()
        self.render_sidebar()


if __name__ == "__main__":
    # Build the app around the streamlit module and render the full UI.
    RAGApp(st, conf).render()
Binary file added full_report/~$Capstone5 Presentation v1.pptx
Binary file not shown.
16 changes: 16 additions & 0 deletions projects/Basic-RAG/BasicRAG_ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""A cookbook demonstrating how to ingest pdf files for use with BasicRAG."""

from pathlib import Path

from grag.components.multivec_retriever import Retriever
from grag.components.vectordb.deeplake_client import DeepLakeClient

# from grag.rag.basic_rag import BasicRAG

client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)

dir_path = Path(__file__).parent / "some_dir"

retriever.ingest(dir_path)
# rag = BasicRAG(doc_chain="refine")
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ dependencies = [
"rouge-score>=0.1.2",
"deeplake>=3.8.27",
"bitsandbytes>=0.43.0",
"accelerate>=0.28.0"
"accelerate>=0.28.0",
"poppler-utils>=0.1.0",
"tesseract>=0.1.3"
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ chunk_overlap : 400

[multivec_retriever]
# store_path: data/docs
store_path : ${data:data_path}/docs
store_path : ${data:data_path}/doc_store
# namespace: UUID(8c9040b0-b5cd-4d7c-bc2e-737da1b24ebf)
namespace : 8c9040b0b5cd4d7cbc2e737da1b24ebf
id_key : doc_id
Expand Down
4 changes: 3 additions & 1 deletion src/grag/components/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Optional, Union


import torch
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
base_dir: str = llm_conf["base_dir"],
quantization: str = llm_conf["quantization"],
pipeline: str = llm_conf["pipeline"],
callbacks=None,
):
"""Initialize the LLM class using the given parameters."""
self.base_dir = Path(base_dir)
Expand All @@ -67,7 +69,7 @@ def __init__(
if std_out:
self.callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
else:
self.callback_manager = None # type: ignore
self.callback_manager = callbacks # type: ignore

@property
def model_name(self):
Expand Down
41 changes: 20 additions & 21 deletions src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
- Retriever
"""

import asyncio
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -44,13 +43,13 @@ class Retriever:
"""

def __init__(
self,
vectordb: Optional[VectorDB] = None,
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=int(multivec_retriever_conf["top_k"]),
client_kwargs: Optional[Dict[str, Any]] = None,
self,
vectordb: Optional[VectorDB] = None,
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=int(multivec_retriever_conf["top_k"]),
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.

Expand Down Expand Up @@ -157,7 +156,7 @@ async def aadd_docs(self, docs: List[Document]):
"""
chunks = self.split_docs(docs)
doc_ids = self.gen_doc_ids(docs)
await asyncio.run(self.vectordb.aadd_docs(chunks))
await self.vectordb.aadd_docs(chunks)
self.retriever.docstore.mset(list(zip(doc_ids, docs)))

def get_chunk(self, query: str, with_score=False, top_k=None):
Expand Down Expand Up @@ -237,12 +236,12 @@ def get_docs_from_chunks(self, chunks: List[Document], one_to_one=False):
return [d for d in docs if d is not None]

def ingest(
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
):
"""Ingests the files in directory.

Expand Down Expand Up @@ -279,12 +278,12 @@ def ingest(
print(f"DRY RUN: found - {filepath.relative_to(dir_path)}")

async def aingest(
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
):
"""Asynchronously ingests the files in directory.

Expand Down
3 changes: 2 additions & 1 deletion src/grag/prompts/matcher.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
"Llama-2-7b": "Llama-2",
"Llama-2-13b": "Llama-2",
"Llama-2-70b": "Llama-2",
"Mixtral-8x7B-Instruct-v0.1": "Mixtral"
"Mixtral-8x7B-Instruct-v0.1": "Mixtral",
"gemma-7b-it": "Llama-2"
}
Loading
Loading