Add [a]delete_by_metadata and [a]update_metadata methods to vector store (#89)

* added [a]delete_by_metadata_filter + tests

* added [a]update_metadata + tests

* [a]delete_by_metadata refuses empty filters + adjusted tests
hemidactylus authored Oct 3, 2024
1 parent 8227c0b commit 3926ace
Showing 3 changed files with 416 additions and 0 deletions.
libs/astradb/langchain_astradb/vectorstores.py (162 additions, 0 deletions)
@@ -827,6 +827,64 @@ async def adelete(
)
return True

def delete_by_metadata_filter(
self,
filter: dict[str, Any], # noqa: A002
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
Returns:
An number expressing the amount of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)
self.astra_env.ensure_db_setup()
metadata_parameter = self.filter_to_query(filter)
del_result = self.astra_env.collection.delete_many(
filter=metadata_parameter,
)
return del_result.deleted_count or 0
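
For illustration, a minimal usage sketch of the new sync method (not part of the diff; the store instance `store` and the metadata key "tag" are assumptions):

# Hypothetical example: remove every document tagged as obsolete.
# `store` is assumed to be an already-configured AstraDBVectorStore.
n_deleted = store.delete_by_metadata_filter({"tag": "obsolete"})
print(f"removed {n_deleted} documents")
# An empty filter raises ValueError; emptying the whole store goes through clear().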

async def adelete_by_metadata_filter(
self,
filter: dict[str, Any], # noqa: A002
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
Returns:
An number expressing the amount of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)
await self.astra_env.aensure_db_setup()
metadata_parameter = self.filter_to_query(filter)
del_result = await self.astra_env.async_collection.delete_many(
filter=metadata_parameter,
)
return del_result.deleted_count or 0
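
The async variant is invoked the same way from a coroutine; a sketch under the same assumptions:

import asyncio

async def prune(store) -> None:
    # Hypothetical example: same filter semantics as the sync method.
    n_deleted = await store.adelete_by_metadata_filter({"tag": "obsolete"})
    print(f"removed {n_deleted} documents")

# asyncio.run(prune(store))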

def delete_collection(self) -> None:
"""Completely delete the collection from the database.
@@ -1166,6 +1224,110 @@ async def _replace_document(
raise ValueError(msg)
return inserted_ids

def update_metadata(
self,
id_to_metadata: dict[str, dict],
*,
overwrite_concurrency: int | None = None,
) -> int:
"""Add/overwrite the metadata of existing documents.
For each document to update, the new metadata dictionary is appended
to the existing metadata, overwriting individual keys that existed already.
Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating. Keys in this dictionary that
do not correspond to an existing document will be silently ignored.
The values of this map are metadata dictionaries for updating
the documents. Any pre-existing metadata will be merged with
these entries, which take precedence on a key-by-key basis.
overwrite_concurrency: number of threads to process the updates
Defaults to the vector-store overall setting if not provided.
Returns:
the number of documents successfully updated (i.e. found to exist,
since even an update with `{}` as the new metadata counts as successful.)
"""
self.astra_env.ensure_db_setup()

_max_workers = overwrite_concurrency or self.bulk_insert_overwrite_concurrency
with ThreadPoolExecutor(
max_workers=_max_workers,
) as executor:

def _update_document(
id_md_pair: tuple[str, dict],
) -> UpdateResult:
document_id, update_metadata = id_md_pair
encoded_metadata = self.filter_to_query(update_metadata)
return self.astra_env.collection.update_one(
{"_id": document_id},
{"$set": encoded_metadata},
)

update_results = list(
executor.map(
_update_document,
id_to_metadata.items(),
)
)

return sum(u_res.update_info["n"] for u_res in update_results)
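
A sketch of the merge semantics described in the docstring (the document ID "doc-1" and its metadata are assumptions):

# Hypothetical example: a document stored with metadata {"a": 1, "b": 2},
# updated with {"b": 20, "c": 3}, ends up with {"a": 1, "b": 20, "c": 3}.
n_updated = store.update_metadata(
    {"doc-1": {"b": 20, "c": 3}},
    overwrite_concurrency=4,  # optional; defaults to the store-wide setting
)
assert n_updated == 1  # IDs absent from the store would not be counted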

async def aupdate_metadata(
self,
id_to_metadata: dict[str, dict],
*,
overwrite_concurrency: int | None = None,
) -> int:
"""Add/overwrite the metadata of existing documents.
For each document to update, the new metadata dictionary is appended
to the existing metadata, overwriting individual keys that existed already.
Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating. Keys in this dictionary that
do not correspond to an existing document will be silently ignored.
The values of this map are metadata dictionaries for updating
the documents. Any pre-existing metadata will be merged with
these entries, which take precedence on a key-by-key basis.
overwrite_concurrency: number of threads to process the updates
Defaults to the vector-store overall setting if not provided.
Returns:
the number of documents successfully updated (i.e. found to exist,
since even an update with `{}` as the new metadata counts as successful.)
"""
await self.astra_env.aensure_db_setup()

sem = asyncio.Semaphore(
overwrite_concurrency or self.bulk_insert_overwrite_concurrency,
)

_async_collection = self.astra_env.async_collection

async def _update_document(
id_md_pair: tuple[str, dict],
) -> UpdateResult:
document_id, update_metadata = id_md_pair
encoded_metadata = self.filter_to_query(update_metadata)
async with sem:
return await _async_collection.update_one(
{"_id": document_id},
{"$set": encoded_metadata},
)

tasks = [
asyncio.create_task(_update_document(id_md_pair))
for id_md_pair in id_to_metadata.items()
]

update_results = await asyncio.gather(*tasks, return_exceptions=False)

return sum(u_res.update_info["n"] for u_res in update_results)
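
And an illustrative sketch of the async counterpart (the helper name and the IDs are assumptions):

async def mark_reviewed(store, ids: list[str]) -> int:
    # Hypothetical example: unknown IDs are silently skipped, so the return
    # value counts only the documents that were actually found and updated.
    return await store.aupdate_metadata(
        {doc_id: {"reviewed": True} for doc_id in ids},
        overwrite_concurrency=8,
    )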

@override
def similarity_search(
self,
libs/astradb/tests/integration_tests/test_vectorstore.py (174 additions, 0 deletions)
@@ -819,6 +819,180 @@ async def test_astradb_vectorstore_massive_insert_replace_async(
for doc, _, doc_id in full_results:
assert doc.page_content == expected_text_by_id[doc_id]

def test_astradb_vectorstore_delete_by_metadata_sync(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing delete_by_metadata_filter."""
full_size = 400
# one in `deletee_ratio` will be deleted
deletee_ratio = 3

documents = [
Document(
page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
)
for doc_i in range(full_size)
]
num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])

inserted_ids0 = vector_store_d2.add_documents(documents)
assert len(inserted_ids0) == len(documents)

d_result0 = vector_store_d2.delete_by_metadata_filter({"deletee": True})
assert d_result0 == num_deletees
count_on_store0 = len(
vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
)
assert count_on_store0 == full_size - num_deletees

with pytest.raises(ValueError, match="does not accept an empty"):
vector_store_d2.delete_by_metadata_filter({})
count_on_store1 = len(
vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
)
assert count_on_store1 == full_size - num_deletees

async def test_astradb_vectorstore_delete_by_metadata_async(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing delete_by_metadata_filter, async version."""
full_size = 400
# one in `deletee_ratio` will be deleted
deletee_ratio = 3

documents = [
Document(
page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
)
for doc_i in range(full_size)
]
num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])

inserted_ids0 = await vector_store_d2.aadd_documents(documents)
assert len(inserted_ids0) == len(documents)

d_result0 = await vector_store_d2.adelete_by_metadata_filter({"deletee": True})
assert d_result0 == num_deletees
count_on_store0 = len(
await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1)
)
assert count_on_store0 == full_size - num_deletees

with pytest.raises(ValueError, match="does not accept an empty"):
await vector_store_d2.adelete_by_metadata_filter({})
count_on_store1 = len(
await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1)
)
assert count_on_store1 == full_size - num_deletees

def test_astradb_vectorstore_update_metadata_sync(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing update_metadata."""
# this should not exceed the max number of hits from ANN search
full_size = 20
# one in `updatee_ratio` will be updated
updatee_ratio = 2
# set this to lower than full_size // updatee_ratio to test everything.
update_concurrency = 7

def doc_sorter(doc: Document) -> str:
return doc.id or ""

orig_documents0 = [
Document(
page_content="[1,1]",
metadata={
"to_update": doc_i % updatee_ratio == 0,
"inert_field": "I",
"updatee_field": "0",
},
id=f"um_doc_{doc_i}",
)
for doc_i in range(full_size)
]
orig_documents = sorted(orig_documents0, key=doc_sorter)

inserted_ids0 = vector_store_d2.add_documents(orig_documents)
assert len(inserted_ids0) == len(orig_documents)

update_map = {
f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
for doc_i in range(full_size)
if doc_i % updatee_ratio == 0
}
u_result0 = vector_store_d2.update_metadata(
update_map,
overwrite_concurrency=update_concurrency,
)
assert u_result0 == len(update_map)

all_documents = sorted(
vector_store_d2.similarity_search("[1,1]", k=full_size),
key=doc_sorter,
)
assert len(all_documents) == len(orig_documents)
for doc, orig_doc in zip(all_documents, orig_documents):
assert doc.id == orig_doc.id
if doc.id in update_map:
assert doc.metadata == orig_doc.metadata | update_map[doc.id]

async def test_astradb_vectorstore_update_metadata_async(
self,
vector_store_d2: AstraDBVectorStore,
) -> None:
"""Testing update_metadata, async version."""
# this should not exceed the max number of hits from ANN search
full_size = 20
# one in `updatee_ratio` will be updated
updatee_ratio = 2
# set this to lower than full_size // updatee_ratio to test everything.
update_concurrency = 7

def doc_sorter(doc: Document) -> str:
return doc.id or ""

orig_documents0 = [
Document(
page_content="[1,1]",
metadata={
"to_update": doc_i % updatee_ratio == 0,
"inert_field": "I",
"updatee_field": "0",
},
id=f"um_doc_{doc_i}",
)
for doc_i in range(full_size)
]
orig_documents = sorted(orig_documents0, key=doc_sorter)

inserted_ids0 = await vector_store_d2.aadd_documents(orig_documents)
assert len(inserted_ids0) == len(orig_documents)

update_map = {
f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
for doc_i in range(full_size)
if doc_i % updatee_ratio == 0
}
u_result0 = await vector_store_d2.aupdate_metadata(
update_map,
overwrite_concurrency=update_concurrency,
)
assert u_result0 == len(update_map)

all_documents = sorted(
await vector_store_d2.asimilarity_search("[1,1]", k=full_size),
key=doc_sorter,
)
assert len(all_documents) == len(orig_documents)
for doc, orig_doc in zip(all_documents, orig_documents):
assert doc.id == orig_doc.id
if doc.id in update_map:
assert doc.metadata == orig_doc.metadata | update_map[doc.id]

def test_astradb_vectorstore_mmr_sync(
self,
vector_store_d2: AstraDBVectorStore,