From 68973b8e20b6ab5cb381af169186bae155ac986d Mon Sep 17 00:00:00 2001 From: Adithya S K Date: Tue, 22 Oct 2024 15:54:00 +0530 Subject: [PATCH] updated RAG notebook - Adithya S K --- RAG/01_Basic_RAG/notebook.ipynb | 177 ++++++++++++++++++++++++++++++-- 1 file changed, 167 insertions(+), 10 deletions(-) diff --git a/RAG/01_Basic_RAG/notebook.ipynb b/RAG/01_Basic_RAG/notebook.ipynb index a6d029b..058fc53 100644 --- a/RAG/01_Basic_RAG/notebook.ipynb +++ b/RAG/01_Basic_RAG/notebook.ipynb @@ -8,6 +8,8 @@ "

Basic RAG

\n", "\n", "\n", + "\"Open\n", + "\n", "
\n", "

AI Engineering.academy

\n", " \n", @@ -160,19 +162,19 @@ "GROK_API_KEY = os.getenv(\"GROQ_API_KEY\")\n", "\n", "# Setting up Base LLM\n", - "Settings.llm = OpenAI(\n", - " model=\"gpt-4o-mini\", temperature=0.1, max_tokens=1024, streaming=True\n", - ")\n", + "# Settings.llm = OpenAI(\n", + "# model=\"gpt-4o-mini\", temperature=0.1, max_tokens=1024, streaming=True\n", + "# )\n", "\n", - "# Settings.llm = Groq(model=\"llama3-70b-8192\" , api_key=GROK_API_KEY)\n", + "Settings.llm = Groq(model=\"llama-3.1-70b-versatile\" , api_key=GROK_API_KEY)\n", "\n", "# Set the embedding model\n", "# Option 1: Use FastEmbed with BAAI/bge-base-en-v1.5 model (default)\n", - "# Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n", + "Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n", "\n", "# Option 2: Use OpenAI's embedding model (commented out)\n", "# If you want to use OpenAI's embedding model, uncomment the following line:\n", - "Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)\n", + "# Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)\n", "\n", "# Qdrant configuration (commented out)\n", "# If you're using Qdrant, uncomment and set these variables:\n", @@ -220,7 +222,7 @@ "print(\"🔃 Loading Data\")\n", "\n", "from llama_index.core import Document\n", - "reader = SimpleDirectoryReader(\"../data/\" , recursive=True)\n", + "reader = SimpleDirectoryReader(\"./content/\" , recursive=True)\n", "documents = reader.load_data(show_progress=True)" ] }, @@ -277,11 +279,11 @@ " # otherwise set Qdrant instance address with:\n", " # url=QDRANT_CLOUD_ENDPOINT,\n", " # otherwise set Qdrant instance with host and port:\n", - " host=\"localhost\",\n", - " port=6333\n", + " # host=\"localhost\",\n", + " # port=6333\n", " # set API KEY for Qdrant Cloud\n", " # api_key=QDRANT_API_KEY,\n", - " # path=\"./db/\"\n", + " path=\"./db/\"\n", ")\n", "\n", "vector_store = QdrantVectorStore(client=client, collection_name=\"01_Basic_RAG\")" @@ -552,6 +554,161 @@ "for message in history:\n", " print(f\"{message.role}: {message.content}\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradio Applicaiton" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gradio as gr\n", + "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings\n", + "from llama_index.vector_stores.qdrant import QdrantVectorStore\n", + "from llama_index.embeddings.openai import OpenAIEmbedding\n", + "from llama_index.llms.openai import OpenAI\n", + "import qdrant_client\n", + "import os\n", + "import tempfile\n", + "import shutil\n", + "from typing import List\n", + "from llama_index.core.base.llms.types import ChatMessage, MessageRole\n", + "\n", + "class RAGChatbot:\n", + " def __init__(self):\n", + " self.client = qdrant_client.QdrantClient(host=\"localhost\", port=6333)\n", + " self.vector_store = None\n", + " self.index = None\n", + " self.chat_engine = None\n", + " self.chat_history = []\n", + " \n", + "\n", + " def process_uploaded_files(self, files) -> str:\n", + " try:\n", + " # Create a temporary directory for processing\n", + " with tempfile.TemporaryDirectory() as temp_dir:\n", + " # Save uploaded files to temporary directory\n", + " for file in files:\n", + " shutil.copy(file.name, temp_dir)\n", + " \n", + " # Load documents\n", + " reader = SimpleDirectoryReader(temp_dir)\n", + " documents = reader.load_data()\n", + " \n", + " # Create new collection name based on timestamp\n", + " import time\n", + " collection_name = f\"chat_collection_{int(time.time())}\"\n", + " \n", + " # Initialize vector store and index\n", + " self.vector_store = QdrantVectorStore(\n", + " client=self.client,\n", + " collection_name=collection_name\n", + " )\n", + " \n", + " # Create the index and ingest documents\n", + " self.index = VectorStoreIndex.from_documents(\n", + " documents,\n", + " vector_store=self.vector_store\n", + " )\n", + " \n", + " # Initialize chat engine\n", + " self.chat_engine = self.index.as_chat_engine(\n", + " streaming=True,\n", + " verbose=True\n", + " )\n", + " \n", + " return f\"Successfully processed {len(documents)} documents. Ready to chat!\"\n", + " \n", + " except Exception as e:\n", + " return f\"Error processing files: {str(e)}\"\n", + "\n", + " def chat(self, message: str, history: List[List[str]]) -> str:\n", + " if self.chat_engine is None:\n", + " return \"Please upload documents first before starting the chat.\"\n", + " \n", + " try:\n", + " # Convert history to ChatMessage format\n", + " chat_history = []\n", + " for h in history:\n", + " chat_history.extend([\n", + " ChatMessage(role=MessageRole.USER, content=h[0]),\n", + " ChatMessage(role=MessageRole.ASSISTANT, content=h[1])\n", + " ])\n", + " \n", + " # Add current message to history\n", + " chat_history.append(ChatMessage(role=MessageRole.USER, content=message))\n", + " \n", + " # Get response from chat engine\n", + " response = self.chat_engine.chat(message, chat_history=chat_history)\n", + " \n", + " return str(response)\n", + " \n", + " except Exception as e:\n", + " return f\"Error generating response: {str(e)}\"\n", + "\n", + "def create_demo():\n", + " # Initialize the chatbot\n", + " chatbot = RAGChatbot()\n", + " \n", + " with gr.Blocks(theme=gr.themes.Soft()) as demo:\n", + " gr.Markdown(\"# RAG Chatbot\")\n", + " gr.Markdown(\"Upload your documents and start chatting!\")\n", + " \n", + " with gr.Row():\n", + " with gr.Column(scale=1):\n", + " file_output = gr.File(\n", + " file_count=\"multiple\",\n", + " label=\"Upload Documents\",\n", + " file_types=[\".txt\", \".pdf\", \".docx\", \".md\"]\n", + " )\n", + " upload_button = gr.Button(\"Process Documents\")\n", + " status_box = gr.Textbox(label=\"Status\", interactive=False)\n", + " \n", + " with gr.Column(scale=2):\n", + " chatbot_interface = gr.Chatbot(\n", + " label=\"Chat History\",\n", + " height=400,\n", + " bubble_full_width=False,\n", + " )\n", + " msg = gr.Textbox(\n", + " label=\"Type your message\",\n", + " placeholder=\"Ask me anything about the uploaded documents...\",\n", + " lines=2\n", + " )\n", + " clear = gr.Button(\"Clear\")\n", + " \n", + " # Event handlers\n", + " upload_button.click(\n", + " fn=chatbot.process_uploaded_files,\n", + " inputs=[file_output],\n", + " outputs=[status_box],\n", + " )\n", + " \n", + " msg.submit(\n", + " fn=chatbot.chat,\n", + " inputs=[msg, chatbot_interface],\n", + " outputs=[chatbot_interface],\n", + " )\n", + " \n", + " clear.click(\n", + " lambda: None,\n", + " None,\n", + " chatbot_interface,\n", + " queue=False\n", + " )\n", + " \n", + " return demo\n", + "\n", + "if __name__ == \"__main__\":\n", + " demo = create_demo()\n", + " demo.launch(share=True)" + ] } ], "metadata": {