-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Update VectorStore Base class and Introduce more Integrations #649
Open
ishaansehgal99
wants to merge
4
commits into
main
Choose a base branch
from
Ishaan/abstract-stores
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import List
import os

import pymongo
from llama_index.vector_stores.azurecosmosmongo import (
    AzureCosmosDBMongoDBVectorSearch,
)

from ragengine.models import Document

from .base import BaseVectorStore


class AzureCosmosDBMongoDBVectorStoreHandler(BaseVectorStore):
    """Vector store handler backed by Azure Cosmos DB's MongoDB vCore vector search.

    Only index creation is backend-specific; all shared indexing, querying,
    and persistence logic lives in ``BaseVectorStore``.
    """

    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        # NOTE(review): if AZURE_COSMOSDB_MONGODB_URI is unset this is None and
        # MongoClient silently falls back to localhost — confirm that is intended.
        self.connection_string = os.environ.get("AZURE_COSMOSDB_MONGODB_URI")
        self.mongodb_client = pymongo.MongoClient(self.connection_string)

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a Cosmos DB-backed index named ``index_name`` and index ``documents``.

        Returns:
            List[str]: IDs of the documents that were indexed.
        """
        vector_store = AzureCosmosDBMongoDBVectorSearch(
            mongodb_client=self.mongodb_client,
            db_name="kaito_ragengine",  # fixed database; collections are per-index
            collection_name=index_name,
        )
        return self._create_index_common(index_name, documents, vector_store)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,155 @@ | ||
import hashlib
import logging
import os
from abc import ABC, abstractmethod
from typing import Dict, List

from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.storage.index_store import SimpleIndexStore

from ragengine.models import Document
from ragengine.embedding.base import BaseEmbeddingModel
from ragengine.inference.inference import Inference
from ragengine.config import PERSIST_DIR

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BaseVectorStore(ABC):
    """Template base class for vector-store backends.

    Implements the indexing, querying, listing, and persistence logic shared
    by every backend. Concrete subclasses only provide ``_create_new_index``
    to build their backend-specific ``vector_store`` and then delegate to
    ``_create_index_common``.
    """

    def __init__(self, embedding_manager: BaseEmbeddingModel):
        self.embedding_manager = embedding_manager
        self.embed_model = self.embedding_manager.model
        # index_name -> VectorStoreIndex; in-memory handle per namespace.
        self.index_map = {}
        # Global index metadata store (persisted alongside each index).
        self.index_store = SimpleIndexStore()
        self.llm = Inference()

    @staticmethod
    def generate_doc_id(text: str) -> str:
        """Generates a unique document ID based on the hash of the document text."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common indexing logic for all vector stores.

        Appends to ``index_name`` if it already exists, otherwise delegates
        index creation to the subclass.

        Returns:
            List[str]: IDs of the documents that were newly indexed.
        """
        if index_name in self.index_map:
            return self._append_documents_to_index(index_name, documents)
        return self._create_new_index(index_name, documents)

    def _append_documents_to_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common logic for appending documents to an existing index.

        Documents whose content hash is already present are skipped, making
        repeated indexing of the same text idempotent.
        """
        logger.info(f"Index {index_name} already exists. Appending documents to existing index.")
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            if not self.document_exists(index_name, doc_id):
                self.add_document_to_index(index_name, doc, doc_id)
                indexed_doc_ids.add(doc_id)
            else:
                logger.info(f"Document {doc_id} already exists in index {index_name}. Skipping.")

        # Persist only when something actually changed.
        if indexed_doc_ids:
            self._persist(index_name)
        return list(indexed_doc_ids)

    @abstractmethod
    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a new index - implementation specific to each vector store."""
        pass

    def _create_index_common(self, index_name: str, documents: List[Document], vector_store) -> List[str]:
        """Common logic for creating a new index with documents.

        Args:
            index_name: Name under which the new index is registered.
            documents: Documents to index; IDs are derived from content hashes.
            vector_store: Backend-specific store supplied by the subclass.

        Returns:
            List[str]: IDs of the documents that were indexed.
        """
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        llama_docs = []
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            llama_docs.append(LlamaDocument(id_=doc_id, text=doc.text, metadata=doc.metadata))
            indexed_doc_ids.add(doc_id)

        if llama_docs:
            index = VectorStoreIndex.from_documents(
                llama_docs,
                storage_context=storage_context,
                embed_model=self.embed_model,
            )
            index.set_index_id(index_name)
            self.index_map[index_name] = index
            self.index_store.add_index_struct(index.index_struct)
            self._persist(index_name)
        return list(indexed_doc_ids)

    def query(self, index_name: str, query: str, top_k: int, llm_params: dict):
        """Common query logic for all vector stores.

        Raises:
            ValueError: If ``index_name`` does not exist.
        """
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        self.llm.set_params(llm_params)

        query_engine = self.index_map[index_name].as_query_engine(
            llm=self.llm,
            similarity_top_k=top_k
        )
        query_result = query_engine.query(query)
        return {
            "response": query_result.response,
            "source_nodes": [
                {
                    "node_id": node.node_id,
                    "text": node.text,
                    "score": node.score,
                    "metadata": node.metadata
                }
                for node in query_result.source_nodes
            ],
            "metadata": query_result.metadata,
        }

    def add_document_to_index(self, index_name: str, document: Document, doc_id: str):
        """Common logic for adding a single document to an existing index.

        Raises:
            ValueError: If ``index_name`` does not exist.
        """
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=doc_id)
        self.index_map[index_name].insert(llama_doc)

    def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
        """Common logic for listing all documents.

        Returns a mapping of index name -> {ref_doc_id -> {"text", "hash"}}.
        """
        return {
            index_name: {
                doc_info.ref_doc_id: {
                    "text": doc_info.text,
                    "hash": doc_info.hash
                } for doc_info in vector_store_index.docstore.docs.values()
            }
            for index_name, vector_store_index in self.index_map.items()
        }

    def document_exists(self, index_name: str, doc_id: str) -> bool:
        """Common logic for checking document existence.

        Returns False (with a warning) when the index itself does not exist.
        """
        if index_name not in self.index_map:
            logger.warning(f"No such index: '{index_name}' exists in vector store.")
            return False
        return doc_id in self.index_map[index_name].ref_doc_info

    def _persist_all(self):
        """Persist the global index store and every registered index."""
        logger.info("Persisting all indexes.")
        self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
        for idx in self.index_store.index_structs():
            self._persist(idx.index_id)

    def _persist(self, index_name: str):
        """Persist a single index (plus the global store) to PERSIST_DIR.

        Failures are logged, not raised. NOTE(review): a failed persist can
        leave the on-disk store.json and the index storage inconsistent with
        the in-memory state — consider a rollback/retry strategy.
        """
        try:
            logger.info(f"Persisting index {index_name}.")
            self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
            assert index_name in self.index_map, f"No such index: '{index_name}' exists."
            storage_context = self.index_map[index_name].storage_context
            # Persist the specific index
            storage_context.persist(persist_dir=os.path.join(PERSIST_DIR, index_name))
            logger.info(f"Successfully persisted index {index_name}.")
        except Exception as e:
            logger.error(f"Failed to persist index {index_name}. Error: {str(e)}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from typing import List

import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore

from ragengine.models import Document

from .base import BaseVectorStore


class ChromaDBVectorStoreHandler(BaseVectorStore):
    """Vector store handler backed by an ephemeral (in-memory) ChromaDB client.

    Shared indexing/query/persistence behavior comes from ``BaseVectorStore``;
    this class only creates the Chroma-backed vector store per index.
    """

    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        # EphemeralClient keeps all collections in memory; nothing survives
        # a process restart.
        self.chroma_client = chromadb.EphemeralClient()

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a Chroma collection for ``index_name`` and index ``documents``.

        Returns:
            List[str]: IDs of the documents that were indexed.
        """
        chroma_collection = self.chroma_client.create_collection(index_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        return self._create_index_common(index_name, documents, vector_store)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,186 +1,16 @@ | ||
from typing import List

import faiss
from llama_index.vector_stores.faiss import FaissVectorStore

from ragengine.models import Document

from .base import BaseVectorStore


class FaissVectorStoreHandler(BaseVectorStore):
    """Vector store handler backed by an in-memory FAISS flat-L2 index.

    All shared indexing, querying, and persistence logic lives in
    ``BaseVectorStore``; this class only builds the FAISS-backed vector store.
    """

    def __init__(self, embedding_manager):
        super().__init__(embedding_manager)
        # FAISS needs the embedding dimensionality up front to size the index.
        self.dimension = self.embedding_manager.get_embedding_dimension()
        # TODO: Consider allowing a user-selected FAISS index type
        # (FlatL2 / FlatIP / IVFFlat / HNSW) via configuration.

    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a FAISS (IndexFlatL2) backed index and index ``documents``.

        Returns:
            List[str]: IDs of the documents that were indexed.
        """
        faiss_index = faiss.IndexFlatL2(self.dimension)  # exact L2 (Euclidean) search
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        return self._create_index_common(index_name, documents, vector_store)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking about a question: could the persist step fail here? And if so, should we roll something back?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's possible — I added a try/except in the persist function to help catch this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If persist fails, should we roll back the index_store? Otherwise there may be inconsistencies between the index_store and the persisted storage.