diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index c0519b5717f1..05cb74505680 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,3 +1,5 @@ +import hashlib +import os from typing import ( Any, Callable, @@ -16,6 +18,8 @@ Vector = Union[Sequence[float], Sequence[int]] ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does +HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8)) + class Document(TypedDict): """A Document is a record in the vector database. @@ -26,7 +30,7 @@ class Document(TypedDict): embedding: Vector, Optional | the vector representation of the content. """ - id: ItemID + id: Optional[ItemID] content: str metadata: Optional[Metadata] embedding: Optional[Vector] @@ -108,6 +112,19 @@ def delete_collection(self, collection_name: str) -> Any: """ ... + def generate_chunk_ids(chunks: List[str], hash_length: int = HASH_LENGTH) -> List[ItemID]: + """ + Generate chunk IDs to ensure non-duplicate uploads. + + Args: + chunks (list): A list of chunks (strings) to hash. + hash_length (int): The desired length of the hash. + + Returns: + list: A list of generated chunk IDs. + """ + return [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:hash_length] for chunk in chunks] + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: """ Insert documents into the collection of the vector database. 
diff --git a/autogen/agentchat/contrib/vectordb/mongodb.py b/autogen/agentchat/contrib/vectordb/mongodb.py index 2e0580fe826b..290bbbe1f958 100644 --- a/autogen/agentchat/contrib/vectordb/mongodb.py +++ b/autogen/agentchat/contrib/vectordb/mongodb.py @@ -123,8 +123,17 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum if query_result and query_result[0][0]["_id"] == doc["id"]: return sleep(_DELAY) - - raise TimeoutError(f"Document {self.index_name} is not ready!") + if ( + query_result + and float(query_result[0][1]) == 1.0 + and query_result[0][0].get("metadata") == doc.get("metadata") + ): + # Handles edge case where document is uploaded with a specific user-generated ID, then the identical content is uploaded with a hash generated ID. + logger.warning( + f"""Documents may be ready, the search has found identical content with a different ID and {"identical" if query_result[0][0].get("metadata") == doc.get("metadata") else "different"} metadata. Duplicate ID: {str(query_result[0][0]["_id"])}""" + ) + else: + raise TimeoutError(f"Document {self.index_name} is not ready!") def _get_embedding_size(self): return len(self.embedding_function(_SAMPLE_SENTENCE)[0]) @@ -275,33 +284,49 @@ def insert_docs( For large numbers of Documents, insertion is performed in batches. + Documents are recommended to not have an ID field, as the method will generate Hashed ID's for them. + Args: - docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`, which may contain an ID. Documents without ID's will have them generated. collection_name: str | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. batch_size: Number of documents to be inserted in each batch + kwargs: Additional keyword arguments. 
def insert_docs(
    self,
    docs: List[Document],
    collection_name: str = None,
    upsert: bool = False,
    batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
    **kwargs,
) -> None:
    """Insert documents into the collection of the vector database.

    For large numbers of Documents, insertion is performed in batches.
    Documents are recommended to omit the ``id`` field: content-hash IDs are
    generated for them, making re-uploads of identical content idempotent.

    Args:
        docs: A list of ``Document`` TypedDicts. Either every document carries
            an ``id`` or none does; mixing the two is rejected.
        collection_name: The name of the collection. Default is None.
        upsert: Whether to update the document if it exists. Default is False.
        batch_size: Number of documents to be inserted in each batch.
        kwargs: ``hash_length`` sets the length of hash-generated IDs;
            ``overwrite_ids=True`` replaces existing IDs with hashed values.

    Raises:
        ValueError: If any document lacks ``content``.
        AssertionError: If the batch mixes documents with and without IDs.
    """
    if not docs:
        logger.info("No documents to insert.")
        return
    if any(doc.get("content") is None for doc in docs):
        raise ValueError("The document content is required.")

    hash_length = kwargs.get("hash_length")
    overwrite_ids = kwargs.get("overwrite_ids", False)

    docs = deepcopy(docs)  # IDs may be rewritten below; never mutate caller data.
    collection = self.get_collection(collection_name)

    # NOTE(review): `assert` is stripped under `python -O`; kept as-is because
    # the test suite pins AssertionError with this exact message.
    assert (
        len({doc.get("id") is None for doc in docs}) == 1
    ), "Documents provided must all have ID's or all not have ID's"

    if docs[0].get("id") is None or overwrite_ids:
        # Fixed log text: with overwrite_ids=True the documents may well carry
        # IDs — the previous message claimed they did not.
        logger.info("Inserting documents with hash-generated IDs.")
        contents = [doc["content"] for doc in docs]
        # A falsy hash_length (None or 0) falls back to the default length,
        # matching the original conditional-call behavior.
        ids = (
            self.generate_chunk_ids(contents, hash_length=hash_length)
            if hash_length
            else self.generate_chunk_ids(contents)
        )
        docs = [{**doc, "id": doc_id} for doc, doc_id in zip(docs, ids)]

    if upsert:
        self.update_docs(docs, collection.name, upsert=True)
        return

    input_ids, result_ids = set(), set()
    id_batch, text_batch, metadata_batch = [], [], []
    size = 0
    for index, doc in enumerate(docs):  # enumerate replaces the manual i counter
        id_batch.append(doc["id"])
        text_batch.append(doc["content"])
        metadata = doc.get("metadata", {})  # assumes {} default for absent metadata — TODO confirm
        metadata_batch.append(metadata)
        # 47 MB stays safely under MongoDB's message-size limit; str(metadata)
        # approximates the metadata's wire size.
        size += len(doc["content"]) + len(str(metadata))
        if (index + 1) % batch_size == 0 or size >= 47_000_000:
            result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
            input_ids.update(id_batch)
            id_batch, text_batch, metadata_batch = [], [], []
            size = 0
    if text_batch:
        # Flush the final partial batch.
        result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
        input_ids.update(id_batch)
    if result_ids != input_ids:
        # presumably indicates a partially failed insert — TODO confirm against
        # _insert_batch's dedup semantics.
        logger.warning(
            "Possible data corruption. input_ids not in result_ids: %s",
            input_ids.difference(result_ids),
        )
insert_result = collection.insert_many(to_insert) # type: ignore - return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs + # TODO Remove this. Replace by log like update_docs + return insert_result.inserted_ids def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None: """Update documents, including their embeddings, in the Collection. @@ -375,11 +399,14 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg Uses deepcopy to avoid changing docs. Args: - docs: List[Document] | A list of documents. + docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated. collection_name: str | The name of the collection. Default is None. kwargs: Any | Use upsert=True` to insert documents whose ids are not present in collection. """ - + provided_doc_count = len(docs) + docs = [doc for doc in docs if doc.get("id") is not None] + if len(docs) != provided_doc_count: + logger.info(f"{provided_doc_count - len(docs)} will not be updated, as they did not contain an ID") n_docs = len(docs) logger.info(f"Preparing to embed and update {n_docs=}") # Compute the embeddings diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py index 2c5194a9f73f..dbbbb93dbdd7 100644 --- a/autogen/agentchat/contrib/vectordb/qdrant.py +++ b/autogen/agentchat/contrib/vectordb/qdrant.py @@ -1,6 +1,7 @@ import abc +import hashlib import logging -import os +import uuid from typing import Callable, List, Optional, Sequence, Tuple, Union from .base import Document, ItemID, QueryResults, VectorDB @@ -155,6 +156,18 @@ def delete_collection(self, collection_name: str) -> None: """ return self.client.delete_collection(collection_name) + def generate_chunk_ids(chunks: List[str]) -> List[ItemID]: + """ + Generate chunk IDs to ensure non-duplicate uploads. + + Args: + chunks (list): A list of chunks (strings) to hash. 
+ + Returns: + list: A list of generated chunk IDs. + """ + return [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks] + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. diff --git a/test/agentchat/contrib/vectordb/test_mongodb.py b/test/agentchat/contrib/vectordb/test_mongodb.py index 3ae1ed572591..a09f45295c56 100644 --- a/test/agentchat/contrib/vectordb/test_mongodb.py +++ b/test/agentchat/contrib/vectordb/test_mongodb.py @@ -107,6 +107,28 @@ def example_documents() -> List[Document]: ] +@pytest.fixture +def id_less_example_documents() -> List[Document]: + """No ID for Hashing Input Test""" + return [ + Document(content="Stars are Big.", metadata={"a": 1}), + Document(content="Atoms are Small.", metadata={"b": 1}), + Document(content="Clouds are White.", metadata={"c": 1}), + Document(content="Grass is Green.", metadata={"d": 1, "e": 2}), + ] + + +@pytest.fixture +def id_mix_example_documents() -> List[Document]: + """No ID for Hashing Input Test""" + return [ + Document(id="123", content="Stars are Big.", metadata={"a": 1}), + Document(content="Atoms are Small.", metadata={"b": 1}), + Document(id="321", content="Clouds are White.", metadata={"c": 1}), + Document(content="Grass is Green.", metadata={"d": 1, "e": 2}), + ] + + @pytest.fixture def db_with_indexed_clxn(collection_name): """VectorDB with a collection created immediately""" @@ -212,6 +234,39 @@ def test_insert_docs(db, collection_name, example_documents): assert len(found[0]["embedding"]) == 384 +def test_insert_docs_no_id(db, collection_name, id_less_example_documents): + # Test that there's an active collection + with pytest.raises(ValueError) as exc: + db.insert_docs(id_less_example_documents) + assert "No collection is specified" in str(exc.value) + + # Create a collection + db.delete_collection(collection_name) + collection = 
def test_insert_docs_mix_id(db, collection_name, id_mix_example_documents):
    """insert_docs must reject batches that mix explicit and missing IDs."""
    # Without a collection name the call fails before any ID handling happens.
    with pytest.raises(ValueError) as excinfo:
        db.insert_docs(id_mix_example_documents)
    assert "No collection is specified" in str(excinfo.value)

    # A batch mixing explicit and absent IDs is refused, even when upserting.
    with pytest.raises(AssertionError, match="Documents provided must all have ID's or all not have ID's"):
        db.insert_docs(id_mix_example_documents, collection_name, upsert=True)