Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Hashing ID for VectorDB Classes (#4746) #4789

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion autogen/agentchat/contrib/vectordb/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import os
from typing import (
Any,
Callable,
Expand All @@ -16,6 +18,8 @@
Vector = Union[Sequence[float], Sequence[int]]
ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does

HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))


class Document(TypedDict):
"""A Document is a record in the vector database.
Expand All @@ -26,7 +30,7 @@ class Document(TypedDict):
embedding: Vector, Optional | the vector representation of the content.
"""

id: ItemID
id: Optional[ItemID]
content: str
metadata: Optional[Metadata]
embedding: Optional[Vector]
Expand Down Expand Up @@ -108,6 +112,19 @@ def delete_collection(self, collection_name: str) -> Any:
"""
...

def generate_chunk_ids(self, chunks: List[str], hash_length: Optional[int] = None) -> List["ItemID"]:
    """
    Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Each chunk is hashed with BLAKE2b and the hex digest truncated, so
    identical content always maps to the same ID.

    Args:
        chunks (list): A list of chunks (strings) to hash.
        hash_length (int): The desired length of the hash. Defaults to the
            module-level HASH_LENGTH (env var "HASH_LENGTH", default 8).

    Returns:
        list: A list of generated chunk IDs, one per input chunk.
    """
    # Bug fix: the original definition omitted `self` even though it sits
    # inside the class and is invoked as `self.generate_chunk_ids(...)`,
    # which would have bound the instance to `chunks`.
    if hash_length is None:
        # Late-bind the default so a changed HASH_LENGTH env setting is honored.
        hash_length = HASH_LENGTH
    return [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:hash_length] for chunk in chunks]

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
"""
Insert documents into the collection of the vector database.
Expand Down
71 changes: 49 additions & 22 deletions autogen/agentchat/contrib/vectordb/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,17 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum
if query_result and query_result[0][0]["_id"] == doc["id"]:
return
sleep(_DELAY)

raise TimeoutError(f"Document {self.index_name} is not ready!")
if (
query_result
and float(query_result[0][1]) == 1.0
and query_result[0][0].get("metadata") == doc.get("metadata")
):
# Handles edge case where document is uploaded with a specific user-generated ID, then the identical content is uploaded with a hash generated ID.
logger.warning(
f"""Documents may be ready, the search has found identical content with a different ID and {"identical" if query_result[0][0].get("metadata") == doc.get("metadata") else "different"} metadata. Duplicate ID: {str(query_result[0][0]["_id"])}"""
)
else:
raise TimeoutError(f"Document {self.index_name} is not ready!")

def _get_embedding_size(self):
    # Embed a fixed sample sentence once and measure the resulting vector's
    # length; presumably used to size the vector-search index dimensions —
    # TODO confirm against the index-creation caller.
    return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
Expand Down Expand Up @@ -275,33 +284,49 @@ def insert_docs(

For large numbers of Documents, insertion is performed in batches.

Documents are recommended to not have an ID field, as the method will generate Hashed ID's for them.

Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`, which may contain an ID. Documents without ID's will have them generated.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
batch_size: Number of documents to be inserted in each batch
kwargs: Additional keyword arguments. Use `hash_length` to set the length of the hash generated ID's, use `overwrite_ids` to overwrite existing ID's with Hashed Values.
"""
hash_length = kwargs.get("hash_length")
overwrite_ids = kwargs.get("overwrite_ids", False)

if any(doc.get("content") is None for doc in docs):
raise ValueError("The document content is required.")

if not docs:
logger.info("No documents to insert.")
return

docs = deepcopy(docs)
collection = self.get_collection(collection_name)

assert (
len({doc.get("id") is None for doc in docs}) == 1
), "Documents provided must all have ID's or all not have ID's"

if docs[0].get("id") is None or overwrite_ids:
logger.info("No id field in the documents. The documents will be inserted with Hash generated IDs.")
content = [doc["content"] for doc in docs]
ids = (
self.generate_chunk_ids(content, hash_length=hash_length)
if hash_length
else self.generate_chunk_ids(content)
)
docs = [{**doc, "id": id} for doc, id in zip(docs, ids)]

if upsert:
self.update_docs(docs, collection.name, upsert=True)

else:
# Sanity checking the first document
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")

input_ids = set()
result_ids = set()
id_batch = []
text_batch = []
metadata_batch = []
size = 0
i = 0
input_ids, result_ids = set(), set()
id_batch, text_batch, metadata_batch = [], [], []
size, i = 0, 0
for doc in docs:
id = doc["id"]
text = doc["content"]
Expand All @@ -314,9 +339,7 @@ def insert_docs(
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
input_ids.update(id_batch)
id_batch = []
text_batch = []
metadata_batch = []
id_batch, text_batch, metadata_batch = [], [], []
size = 0
i += 1
if text_batch:
Expand Down Expand Up @@ -365,7 +388,8 @@ def _insert_batch(
]
# insert the documents in MongoDB Atlas
insert_result = collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
# TODO Remove this. Replace by log like update_docs
return insert_result.inserted_ids

def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
"""Update documents, including their embeddings, in the Collection.
Expand All @@ -375,11 +399,14 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
Uses deepcopy to avoid changing docs.

Args:
docs: List[Document] | A list of documents.
docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated.
collection_name: str | The name of the collection. Default is None.
kwargs: Any | Use upsert=True` to insert documents whose ids are not present in collection.
"""

provided_doc_count = len(docs)
docs = [doc for doc in docs if doc.get("id") is not None]
if len(docs) != provided_doc_count:
logger.info(f"{provided_doc_count - len(docs)} will not be updated, as they did not contain an ID")
n_docs = len(docs)
logger.info(f"Preparing to embed and update {n_docs=}")
# Compute the embeddings
Expand Down
15 changes: 14 additions & 1 deletion autogen/agentchat/contrib/vectordb/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import hashlib
import logging
import os
import uuid
from typing import Callable, List, Optional, Sequence, Tuple, Union

from .base import Document, ItemID, QueryResults, VectorDB
Expand Down Expand Up @@ -155,6 +156,18 @@ def delete_collection(self, collection_name: str) -> None:
"""
return self.client.delete_collection(collection_name)

def generate_chunk_ids(self, chunks: List[str]) -> List["ItemID"]:
    """
    Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Qdrant point IDs must be UUIDs (or unsigned ints), so each chunk's MD5
    digest is reinterpreted as a UUID string; identical content always maps
    to the same ID.

    Args:
        chunks (list): A list of chunks (strings) to hash.

    Returns:
        list: A list of generated chunk IDs (UUID strings), one per chunk.
    """
    # Bug fix: the original definition omitted `self` even though it sits
    # inside the class; called as `self.generate_chunk_ids(...)` it would
    # have bound the instance to `chunks`.
    # NOTE(review): MD5 is fine as a dedup key but is not collision-resistant
    # against adversarial input.
    return [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
"""
Insert documents into the collection of the vector database.
Expand Down
55 changes: 55 additions & 0 deletions test/agentchat/contrib/vectordb/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ def example_documents() -> List[Document]:
]


@pytest.fixture
def id_less_example_documents() -> List[Document]:
    """Documents without IDs, used to exercise hash-based ID generation on insert."""
    pairs = [
        ("Stars are Big.", {"a": 1}),
        ("Atoms are Small.", {"b": 1}),
        ("Clouds are White.", {"c": 1}),
        ("Grass is Green.", {"d": 1, "e": 2}),
    ]
    return [Document(content=text, metadata=meta) for text, meta in pairs]


@pytest.fixture
def id_mix_example_documents() -> List[Document]:
    """Mix of documents with and without IDs; insert_docs must reject such input."""
    return [
        Document(id="123", content="Stars are Big.", metadata={"a": 1}),
        Document(content="Atoms are Small.", metadata={"b": 1}),
        Document(id="321", content="Clouds are White.", metadata={"c": 1}),
        Document(content="Grass is Green.", metadata={"d": 1, "e": 2}),
    ]


@pytest.fixture
def db_with_indexed_clxn(collection_name):
"""VectorDB with a collection created immediately"""
Expand Down Expand Up @@ -212,6 +234,39 @@ def test_insert_docs(db, collection_name, example_documents):
assert len(found[0]["embedding"]) == 384


def test_insert_docs_no_id(db, collection_name, id_less_example_documents):
    """End-to-end check that ID-less documents are inserted with hash-generated IDs."""
    # Test that there's an active collection: inserting with no collection
    # specified must raise before any ID generation happens.
    with pytest.raises(ValueError) as exc:
        db.insert_docs(id_less_example_documents)
    assert "No collection is specified" in str(exc.value)

    # Create a collection (delete first so the run starts from a clean slate)
    db.delete_collection(collection_name)
    collection = db.create_collection(collection_name)

    # Insert example documents
    db.insert_docs(id_less_example_documents, collection_name=collection_name)
    found = list(collection.find({}))
    assert len(found) == len(id_less_example_documents)
    # Check that documents have correct fields, including "_id" and "embedding" but not "id"
    assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
    # Check ids: stored _id values must equal the hashes of the document contents
    hash_values = set(db.generate_chunk_ids([content.get("content") for content in id_less_example_documents]))
    assert {doc["_id"] for doc in found} == hash_values
    # Check embedding lengths match the model's output dimension
    assert len(found[0]["embedding"]) == 384


def test_insert_docs_mix_id(db, collection_name, id_mix_example_documents):
    """insert_docs must refuse batches that mix ID-bearing and ID-less documents."""
    # Test that there's an active collection
    with pytest.raises(ValueError) as exc:
        db.insert_docs(id_mix_example_documents)
    assert "No collection is specified" in str(exc.value)
    # Test that insert_docs does not accept mixed ID inserts
    with pytest.raises(AssertionError, match="Documents provided must all have ID's or all not have ID's"):
        db.insert_docs(id_mix_example_documents, collection_name, upsert=True)


def test_update_docs(db_with_indexed_clxn, example_documents):
db, collection = db_with_indexed_clxn
# Use update_docs to insert new documents
Expand Down
Loading