feat: Update VectorStore Base class and Introduce more Integrations #649

Open · ishaansehgal99 wants to merge 4 commits into main
24 changes: 24 additions & 0 deletions ragengine/vector_store/azuremongodb_store.py
@@ -0,0 +1,24 @@
from typing import List
import os
from ragengine.models import Document

import pymongo
from llama_index.vector_stores.azurecosmosmongo import (
    AzureCosmosDBMongoDBVectorSearch,
)

from .base import BaseVectorStore

class AzureCosmosDBMongoDBVectorStoreHandler(BaseVectorStore):
    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        self.connection_string = os.environ.get("AZURE_COSMOSDB_MONGODB_URI")
        self.mongodb_client = pymongo.MongoClient(self.connection_string)

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        vector_store = AzureCosmosDBMongoDBVectorSearch(
            mongodb_client=self.mongodb_client,
            db_name="kaito_ragengine",
            collection_name=index_name,
        )
        return self._create_index_common(index_name, documents, vector_store)
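
A possible hardening for this handler, sketched here as a suggestion rather than part of this PR: pymongo.MongoClient(None) silently falls back to localhost:27017, so failing fast when AZURE_COSMOSDB_MONGODB_URI is unset would surface misconfiguration immediately.

    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        self.connection_string = os.environ.get("AZURE_COSMOSDB_MONGODB_URI")
        if not self.connection_string:
            # Without this guard, MongoClient(None) would target localhost:27017.
            raise ValueError("AZURE_COSMOSDB_MONGODB_URI environment variable is not set")
        self.mongodb_client = pymongo.MongoClient(self.connection_string)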
144 changes: 134 additions & 10 deletions ragengine/vector_store/base.py
@@ -1,31 +1,155 @@
import logging
from abc import ABC, abstractmethod
from typing import Dict, List
import hashlib
import os

from llama_index.core import Document as LlamaDocument
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core import (StorageContext, VectorStoreIndex)

from ragengine.models import Document
from ragengine.embedding.base import BaseEmbeddingModel
from ragengine.inference.inference import Inference
from ragengine.config import PERSIST_DIR

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BaseVectorStore(ABC):
    def __init__(self, embedding_manager: BaseEmbeddingModel):
        self.embedding_manager = embedding_manager
        self.embed_model = self.embedding_manager.model
        self.index_map = {}
        self.index_store = SimpleIndexStore()
        self.llm = Inference()

    @staticmethod
    def generate_doc_id(text: str) -> str:
        """Generates a unique document ID based on the hash of the document text."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common indexing logic for all vector stores."""
        if index_name in self.index_map:
            return self._append_documents_to_index(index_name, documents)
        else:
            return self._create_new_index(index_name, documents)

    def _append_documents_to_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common logic for appending documents to existing index."""
        logger.info(f"Index {index_name} already exists. Appending documents to existing index.")
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            if not self.document_exists(index_name, doc_id):
                self.add_document_to_index(index_name, doc, doc_id)
                indexed_doc_ids.add(doc_id)
            else:
                logger.info(f"Document {doc_id} already exists in index {index_name}. Skipping.")

        if indexed_doc_ids:
            self._persist(index_name)
        return list(indexed_doc_ids)

    @abstractmethod
    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a new index - implementation specific to each vector store."""
        pass

    def _create_index_common(self, index_name: str, documents: List[Document], vector_store) -> List[str]:
        """Common logic for creating a new index with documents."""
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        llama_docs = []
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            llama_doc = LlamaDocument(id_=doc_id, text=doc.text, metadata=doc.metadata)
            llama_docs.append(llama_doc)
            indexed_doc_ids.add(doc_id)

        if llama_docs:
            index = VectorStoreIndex.from_documents(
                llama_docs,
                storage_context=storage_context,
                embed_model=self.embed_model,
            )
            index.set_index_id(index_name)
            self.index_map[index_name] = index
            self.index_store.add_index_struct(index.index_struct)
            self._persist(index_name)

@zhuangqh (Member) · Oct 24, 2024
I am thinking about a question: will the persist fail here? Or should we roll back something?

@ishaansehgal99 (Collaborator, Author) · Oct 24, 2024
It's possible; I added a try-catch in the persist function to help catch this.

@zhuangqh (Member) · Oct 24, 2024
If persist fails, should we roll back the index_store? Otherwise there will be inconsistencies between the index_store and storage.

        return list(indexed_doc_ids)

    def query(self, index_name: str, query: str, top_k: int, llm_params: dict):
        """Common query logic for all vector stores."""
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        self.llm.set_params(llm_params)

        query_engine = self.index_map[index_name].as_query_engine(
            llm=self.llm,
            similarity_top_k=top_k
        )
        query_result = query_engine.query(query)
        return {
            "response": query_result.response,
            "source_nodes": [
                {
                    "node_id": node.node_id,
                    "text": node.text,
                    "score": node.score,
                    "metadata": node.metadata
                }
                for node in query_result.source_nodes
            ],
            "metadata": query_result.metadata,
        }

    def add_document_to_index(self, index_name: str, document: Document, doc_id: str):
        """Common logic for adding a single document."""
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=doc_id)
        self.index_map[index_name].insert(llama_doc)

    def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
        """Common logic for listing all documents."""
        return {
            index_name: {
                doc_info.ref_doc_id: {
                    "text": doc_info.text,
                    "hash": doc_info.hash
                } for _, doc_info in vector_store_index.docstore.docs.items()
            }
            for index_name, vector_store_index in self.index_map.items()
        }

    def document_exists(self, index_name: str, doc_id: str) -> bool:
        """Common logic for checking document existence."""
        if index_name not in self.index_map:
            logger.warning(f"No such index: '{index_name}' exists in vector store.")
            return False
        return doc_id in self.index_map[index_name].ref_doc_info

    def _persist_all(self):
        """Common persistence logic."""
        logger.info("Persisting all indexes.")
        self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
        for idx in self.index_store.index_structs():
            self._persist(idx.index_id)

    def _persist(self, index_name: str):
        """Common persistence logic for individual index."""
        try:
            logger.info(f"Persisting index {index_name}.")
            self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
            assert index_name in self.index_map, f"No such index: '{index_name}' exists."
            storage_context = self.index_map[index_name].storage_context
            # Persist the specific index
            storage_context.persist(persist_dir=os.path.join(PERSIST_DIR, index_name))
            logger.info(f"Successfully persisted index {index_name}.")
        except Exception as e:
            logger.error(f"Failed to persist index {index_name}. Error: {str(e)}")
16 changes: 16 additions & 0 deletions ragengine/vector_store/chromadb_store.py
@@ -0,0 +1,16 @@
from typing import List
from ragengine.models import Document

import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from .base import BaseVectorStore

class ChromaDBVectorStoreHandler(BaseVectorStore):
    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        self.chroma_client = chromadb.EphemeralClient()

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        chroma_collection = self.chroma_client.create_collection(index_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        return self._create_index_common(index_name, documents, vector_store)
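
Note that chromadb.EphemeralClient() keeps collections in memory only, so these indexes do not survive a process restart. A hypothetical variant using Chroma's on-disk client, sketched below (the persist path is an assumed example, not part of this PR):

import chromadb

class PersistentChromaDBVectorStoreHandler(ChromaDBVectorStoreHandler):
    """Hypothetical: same handler logic, backed by an on-disk Chroma client."""
    def __init__(self, embedding_manager, persist_path: str = "./chroma_data"):
        super().__init__(embedding_manager)
        # Replace the in-memory client created by the parent constructor.
        self.chroma_client = chromadb.PersistentClient(path=persist_path)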
180 changes: 5 additions & 175 deletions ragengine/vector_store/faiss_store.py
@@ -1,186 +1,16 @@
from typing import List
from ragengine.models import Document

import faiss
from llama_index.vector_stores.faiss import FaissVectorStore

from .base import BaseVectorStore


class FaissVectorStoreHandler(BaseVectorStore):
    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        self.dimension = self.embedding_manager.get_embedding_dimension()
        # TODO: Consider allowing user custom indexing method (would require configmap?) e.g.
        """
        # Choose the FAISS index type based on the provided index_method
        if index_method == 'FlatL2':
            faiss_index = faiss.IndexFlatL2(self.dimension)  # L2 (Euclidean distance) index
        elif index_method == 'FlatIP':
            faiss_index = faiss.IndexFlatIP(self.dimension)  # Inner product (cosine similarity) index
        elif index_method == 'IVFFlat':
            quantizer = faiss.IndexFlatL2(self.dimension)  # Quantizer for IVF
            faiss_index = faiss.IndexIVFFlat(quantizer, self.dimension, 100)  # IVF with flat quantization
        elif index_method == 'HNSW':
            faiss_index = faiss.IndexHNSWFlat(self.dimension, 32)  # HNSW index with 32 neighbors
        else:
            raise ValueError(f"Unknown index method: {index_method}")
        """

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        faiss_index = faiss.IndexFlatL2(self.dimension)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        return self._create_index_common(index_name, documents, vector_store)
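
For orientation, a minimal end-to-end sketch of the refactored flow. Construction of the embedding manager is elided since it depends on the deployment; any concrete BaseEmbeddingModel (exposing .model and get_embedding_dimension()) should work. Everything below is an assumed usage example, not code from this PR.

from ragengine.models import Document
from ragengine.vector_store.faiss_store import FaissVectorStoreHandler

embedding_manager = ...  # assumed: a concrete BaseEmbeddingModel instance

store = FaissVectorStoreHandler(embedding_manager)

# First call creates the index via _create_new_index; later calls append,
# deduplicating by the SHA-256 content hash from generate_doc_id.
doc_ids = store.index_documents("docs", [
    Document(text="KAITO RAGEngine indexes and retrieves documents.", metadata={"source": "example"}),
])

# Shared query path: sets per-request LLM params, retrieves the top_k nodes,
# and returns the response along with source nodes and metadata.
result = store.query("docs", "What does RAGEngine do?", top_k=1, llm_params={"temperature": 0.2})
print(result["response"])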