Skip to content

Commit

Permalink
updated RAG notebook - Adithya S K
Browse files Browse the repository at this point in the history
  • Loading branch information
adithya-s-k committed Oct 22, 2024
1 parent bfc8814 commit 68973b8
Showing 1 changed file with 167 additions and 10 deletions.
177 changes: 167 additions & 10 deletions RAG/01_Basic_RAG/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"<h2>Basic RAG</h2>\n",
"</div>\n",
"\n",
"<a href=\"https://colab.research.google.com/github/adithya-s-k/AI-Engineering.academy/blob/main/RAG/01_Basic_RAG/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"<div align=\"center\">\n",
" <h3 ><a href=\"https://aiengineering.academy/\" target=\"_blank\">AI Engineering.academy</a></h3>\n",
" \n",
Expand Down Expand Up @@ -160,19 +162,19 @@
"GROK_API_KEY = os.getenv(\"GROQ_API_KEY\")\n",
"\n",
"# Setting up Base LLM\n",
"Settings.llm = OpenAI(\n",
" model=\"gpt-4o-mini\", temperature=0.1, max_tokens=1024, streaming=True\n",
")\n",
"# Settings.llm = OpenAI(\n",
"# model=\"gpt-4o-mini\", temperature=0.1, max_tokens=1024, streaming=True\n",
"# )\n",
"\n",
"# Settings.llm = Groq(model=\"llama3-70b-8192\" , api_key=GROK_API_KEY)\n",
"Settings.llm = Groq(model=\"llama-3.1-70b-versatile\" , api_key=GROK_API_KEY)\n",
"\n",
"# Set the embedding model\n",
"# Option 1: Use FastEmbed with BAAI/bge-base-en-v1.5 model (default)\n",
"# Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n",
"Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n",
"\n",
"# Option 2: Use OpenAI's embedding model (commented out)\n",
"# If you want to use OpenAI's embedding model, uncomment the following line:\n",
"Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)\n",
"# Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)\n",
"\n",
"# Qdrant configuration (commented out)\n",
"# If you're using Qdrant, uncomment and set these variables:\n",
Expand Down Expand Up @@ -220,7 +222,7 @@
"print(\"🔃 Loading Data\")\n",
"\n",
"from llama_index.core import Document\n",
"reader = SimpleDirectoryReader(\"../data/\" , recursive=True)\n",
"reader = SimpleDirectoryReader(\"./content/\" , recursive=True)\n",
"documents = reader.load_data(show_progress=True)"
]
},
Expand Down Expand Up @@ -277,11 +279,11 @@
" # otherwise set Qdrant instance address with:\n",
" # url=QDRANT_CLOUD_ENDPOINT,\n",
" # otherwise set Qdrant instance with host and port:\n",
" host=\"localhost\",\n",
" port=6333\n",
" # host=\"localhost\",\n",
" # port=6333\n",
" # set API KEY for Qdrant Cloud\n",
" # api_key=QDRANT_API_KEY,\n",
" # path=\"./db/\"\n",
" path=\"./db/\"\n",
")\n",
"\n",
"vector_store = QdrantVectorStore(client=client, collection_name=\"01_Basic_RAG\")"
Expand Down Expand Up @@ -552,6 +554,161 @@
"for message in history:\n",
" print(f\"{message.role}: {message.content}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradio Application"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"import tempfile\n",
"import time\n",
"from typing import List\n",
"\n",
"import gradio as gr\n",
"import qdrant_client\n",
"from llama_index.core import Document, Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex\n",
"from llama_index.core.base.llms.types import ChatMessage, MessageRole\n",
"from llama_index.embeddings.openai import OpenAIEmbedding\n",
"from llama_index.llms.openai import OpenAI\n",
"from llama_index.vector_stores.qdrant import QdrantVectorStore\n",
"\n",
"class RAGChatbot:\n",
"    \"\"\"Chat over user-uploaded documents with a LlamaIndex chat engine backed by Qdrant.\n",
"\n",
"    Attributes:\n",
"        client: Qdrant client configured for host \"localhost\", port 6333.\n",
"        vector_store: QdrantVectorStore of the most recent successful upload, else None.\n",
"        index: VectorStoreIndex built from the most recent successful upload, else None.\n",
"        chat_engine: Chat engine over `index`; None until documents have been processed.\n",
"        chat_history: Not read by this class; kept so existing callers can still access it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # NOTE(review): this expects a Qdrant server on localhost:6333 - confirm one is\n",
"        # reachable in the environment where the app is launched.\n",
"        self.client = qdrant_client.QdrantClient(host=\"localhost\", port=6333)\n",
"        self.vector_store = None\n",
"        self.index = None\n",
"        self.chat_engine = None\n",
"        self.chat_history = []\n",
"\n",
"    def process_uploaded_files(self, files) -> str:\n",
"        \"\"\"Ingest uploaded files into a new Qdrant collection and build the chat engine.\n",
"\n",
"        Args:\n",
"            files: Value of the Gradio file component - a list of file paths, or of\n",
"                objects exposing the path as `.name`. May be None or empty.\n",
"\n",
"        Returns:\n",
"            A human-readable status message (success or the error that occurred).\n",
"        \"\"\"\n",
"        if not files:\n",
"            return \"Please upload at least one document before processing.\"\n",
"\n",
"        try:\n",
"            # Stage the uploads in a scratch directory that is removed afterwards\n",
"            with tempfile.TemporaryDirectory() as temp_dir:\n",
"                for file in files:\n",
"                    # Depending on the Gradio version an upload is a path string or an\n",
"                    # object carrying the path in `.name`; accept both.\n",
"                    shutil.copy(getattr(file, \"name\", file), temp_dir)\n",
"\n",
"                documents = SimpleDirectoryReader(temp_dir).load_data()\n",
"\n",
"                # One collection per upload, named after the current Unix timestamp\n",
"                collection_name = f\"chat_collection_{int(time.time())}\"\n",
"                vector_store = QdrantVectorStore(\n",
"                    client=self.client,\n",
"                    collection_name=collection_name\n",
"                )\n",
"\n",
"                # The vector store must be handed over through a StorageContext; a bare\n",
"                # `vector_store=` keyword is not used by from_documents, which would\n",
"                # leave the embeddings in the default in-memory store instead of Qdrant.\n",
"                storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
"                index = VectorStoreIndex.from_documents(\n",
"                    documents,\n",
"                    storage_context=storage_context\n",
"                )\n",
"\n",
"                chat_engine = index.as_chat_engine(\n",
"                    streaming=True,\n",
"                    verbose=True\n",
"                )\n",
"\n",
"                # Publish the new state only after every step succeeded, so a failed\n",
"                # upload does not leave the instance half-updated.\n",
"                self.vector_store = vector_store\n",
"                self.index = index\n",
"                self.chat_engine = chat_engine\n",
"\n",
"                return f\"Successfully processed {len(documents)} documents. Ready to chat!\"\n",
"\n",
"        except Exception as e:\n",
"            # Report the failure in the status box instead of crashing the UI callback\n",
"            return f\"Error processing files: {str(e)}\"\n",
"\n",
"    def chat(self, message: str, history: List[List[str]]) -> str:\n",
"        \"\"\"Answer `message` using the indexed documents and the prior conversation.\n",
"\n",
"        Args:\n",
"            message: The new user message.\n",
"            history: Previous turns as [user_text, assistant_text] pairs; may be None\n",
"                or empty, and either side of a pair may be empty.\n",
"\n",
"        Returns:\n",
"            The assistant reply as text, or an explanatory message when no documents\n",
"            have been processed yet or the engine raised an error.\n",
"        \"\"\"\n",
"        if self.chat_engine is None:\n",
"            return \"Please upload documents first before starting the chat.\"\n",
"\n",
"        try:\n",
"            # Convert prior turns to ChatMessage objects, skipping empty sides\n",
"            chat_history = []\n",
"            for user_text, assistant_text in history or []:\n",
"                if user_text:\n",
"                    chat_history.append(ChatMessage(role=MessageRole.USER, content=user_text))\n",
"                if assistant_text:\n",
"                    chat_history.append(ChatMessage(role=MessageRole.ASSISTANT, content=assistant_text))\n",
"\n",
"            # `message` is passed separately and must not also be appended to\n",
"            # chat_history, otherwise the engine sees the current question twice.\n",
"            response = self.chat_engine.chat(message, chat_history=chat_history)\n",
"\n",
"            return str(response)\n",
"\n",
"        except Exception as e:\n",
"            return f\"Error generating response: {str(e)}\"\n",
"\n",
"def create_demo():\n",
"    \"\"\"Build the Gradio Blocks UI for uploading documents and chatting about them.\n",
"\n",
"    Returns:\n",
"        The assembled `gr.Blocks` app (not yet launched).\n",
"    \"\"\"\n",
"    # Initialize the chatbot\n",
"    chatbot = RAGChatbot()\n",
"\n",
"    def respond(message, history):\n",
"        \"\"\"Run one chat turn and return (cleared textbox value, updated conversation).\"\"\"\n",
"        history = history or []\n",
"        # Ignore blank submissions instead of sending them to the chat engine\n",
"        if not message or not message.strip():\n",
"            return \"\", history\n",
"        reply = chatbot.chat(message, history)\n",
"        # gr.Chatbot displays the whole conversation as [user, assistant] pairs, so the\n",
"        # new turn is appended to the history rather than returning the bare reply text.\n",
"        return \"\", history + [[message, reply]]\n",
"\n",
"    with gr.Blocks(theme=gr.themes.Soft()) as demo:\n",
"        gr.Markdown(\"# RAG Chatbot\")\n",
"        gr.Markdown(\"Upload your documents and start chatting!\")\n",
"\n",
"        with gr.Row():\n",
"            # Left column: document upload and ingestion status\n",
"            with gr.Column(scale=1):\n",
"                file_output = gr.File(\n",
"                    file_count=\"multiple\",\n",
"                    label=\"Upload Documents\",\n",
"                    file_types=[\".txt\", \".pdf\", \".docx\", \".md\"]\n",
"                )\n",
"                upload_button = gr.Button(\"Process Documents\")\n",
"                status_box = gr.Textbox(label=\"Status\", interactive=False)\n",
"\n",
"            # Right column: conversation view and message input\n",
"            with gr.Column(scale=2):\n",
"                chatbot_interface = gr.Chatbot(\n",
"                    label=\"Chat History\",\n",
"                    height=400,\n",
"                    bubble_full_width=False,\n",
"                )\n",
"                msg = gr.Textbox(\n",
"                    label=\"Type your message\",\n",
"                    placeholder=\"Ask me anything about the uploaded documents...\",\n",
"                    lines=2\n",
"                )\n",
"                clear = gr.Button(\"Clear\")\n",
"\n",
"        # Event handlers\n",
"        upload_button.click(\n",
"            fn=chatbot.process_uploaded_files,\n",
"            inputs=[file_output],\n",
"            outputs=[status_box],\n",
"        )\n",
"\n",
"        # Submitting a message clears the textbox and refreshes the conversation view\n",
"        msg.submit(\n",
"            fn=respond,\n",
"            inputs=[msg, chatbot_interface],\n",
"            outputs=[msg, chatbot_interface],\n",
"        )\n",
"\n",
"        # Returning None resets the conversation view\n",
"        clear.click(\n",
"            lambda: None,\n",
"            None,\n",
"            chatbot_interface,\n",
"            queue=False\n",
"        )\n",
"\n",
"    return demo\n",
"\n",
"# Entry point: assemble the Gradio app, then serve it with share=True.\n",
"if __name__ == \"__main__\":\n",
"    demo = create_demo()\n",
"    demo.launch(share=True)"
]
}
],
"metadata": {
Expand Down

0 comments on commit 68973b8

Please sign in to comment.