diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py
index 8b1576a..84b2083 100644
--- a/libs/astradb/langchain_astradb/vectorstores.py
+++ b/libs/astradb/langchain_astradb/vectorstores.py
@@ -827,6 +827,64 @@ async def adelete(
         )
         return True
 
+    def delete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],  # noqa: A002
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way; it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+
+        Returns:
+            The number of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `delete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `clear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+        self.astra_env.ensure_db_setup()
+        metadata_parameter = self.filter_to_query(filter)
+        del_result = self.astra_env.collection.delete_many(
+            filter=metadata_parameter,
+        )
+        return del_result.deleted_count or 0
+
+    async def adelete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],  # noqa: A002
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way; it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+
+        Returns:
+            The number of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `adelete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `aclear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+        await self.astra_env.aensure_db_setup()
+        metadata_parameter = self.filter_to_query(filter)
+        del_result = await self.astra_env.async_collection.delete_many(
+            filter=metadata_parameter,
+        )
+        return del_result.deleted_count or 0
+
     def delete_collection(self) -> None:
         """Completely delete the collection from the database.
 
@@ -1166,6 +1224,110 @@ async def _replace_document(
             raise ValueError(msg)
         return inserted_ids
 
+    def update_metadata(
+        self,
+        id_to_metadata: dict[str, dict],
+        *,
+        overwrite_concurrency: int | None = None,
+    ) -> int:
+        """Add/overwrite the metadata of existing documents.
+
+        For each document to update, the new metadata dictionary is merged into
+        the existing metadata, overwriting individual keys that existed already.
+
+        Args:
+            id_to_metadata: map from the Document IDs to modify to the
+                new metadata for updating. Keys in this dictionary that
+                do not correspond to an existing document will be silently ignored.
+                The values of this map are metadata dictionaries for updating
+                the documents. Any pre-existing metadata will be merged with
+                these entries, which take precedence on a key-by-key basis.
+            overwrite_concurrency: number of threads to process the updates.
+                Defaults to the vector-store overall setting if not provided.
+
+        Returns:
+            The number of documents successfully updated (i.e. found to exist,
+            since even an update with `{}` as the new metadata counts as successful).
+ """ + self.astra_env.ensure_db_setup() + + _max_workers = overwrite_concurrency or self.bulk_insert_overwrite_concurrency + with ThreadPoolExecutor( + max_workers=_max_workers, + ) as executor: + + def _update_document( + id_md_pair: tuple[str, dict], + ) -> UpdateResult: + document_id, update_metadata = id_md_pair + encoded_metadata = self.filter_to_query(update_metadata) + return self.astra_env.collection.update_one( + {"_id": document_id}, + {"$set": encoded_metadata}, + ) + + update_results = list( + executor.map( + _update_document, + id_to_metadata.items(), + ) + ) + + return sum(u_res.update_info["n"] for u_res in update_results) + + async def aupdate_metadata( + self, + id_to_metadata: dict[str, dict], + *, + overwrite_concurrency: int | None = None, + ) -> int: + """Add/overwrite the metadata of existing documents. + + For each document to update, the new metadata dictionary is appended + to the existing metadata, overwriting individual keys that existed already. + + Args: + id_to_metadata: map from the Document IDs to modify to the + new metadata for updating. Keys in this dictionary that + do not correspond to an existing document will be silently ignored. + The values of this map are metadata dictionaries for updating + the documents. Any pre-existing metadata will be merged with + these entries, which take precedence on a key-by-key basis. + overwrite_concurrency: number of threads to process the updates + Defaults to the vector-store overall setting if not provided. + + Returns: + the number of documents successfully updated (i.e. found to exist, + since even an update with `{}` as the new metadata counts as successful.) + """ + await self.astra_env.aensure_db_setup() + + sem = asyncio.Semaphore( + overwrite_concurrency or self.bulk_insert_overwrite_concurrency, + ) + + _async_collection = self.astra_env.async_collection + + async def _update_document( + id_md_pair: tuple[str, dict], + ) -> UpdateResult: + document_id, update_metadata = id_md_pair + encoded_metadata = self.filter_to_query(update_metadata) + async with sem: + return await _async_collection.update_one( + {"_id": document_id}, + {"$set": encoded_metadata}, + ) + + tasks = [ + asyncio.create_task(_update_document(id_md_pair)) + for id_md_pair in id_to_metadata.items() + ] + + update_results = await asyncio.gather(*tasks, return_exceptions=False) + + return sum(u_res.update_info["n"] for u_res in update_results) + @override def similarity_search( self, diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index 721d033..9b5b2b2 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -819,6 +819,180 @@ async def test_astradb_vectorstore_massive_insert_replace_async( for doc, _, doc_id in full_results: assert doc.page_content == expected_text_by_id[doc_id] + def test_astradb_vectorstore_delete_by_metadata_sync( + self, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """Testing delete_by_metadata_filter.""" + full_size = 400 + # one in ... 
diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py
index 721d033..9b5b2b2 100644
--- a/libs/astradb/tests/integration_tests/test_vectorstore.py
+++ b/libs/astradb/tests/integration_tests/test_vectorstore.py
@@ -819,6 +819,180 @@ async def test_astradb_vectorstore_massive_insert_replace_async(
         for doc, _, doc_id in full_results:
             assert doc.page_content == expected_text_by_id[doc_id]
 
+    def test_astradb_vectorstore_delete_by_metadata_sync(
+        self,
+        vector_store_d2: AstraDBVectorStore,
+    ) -> None:
+        """Testing delete_by_metadata_filter."""
+        full_size = 400
+        # one in ... will be deleted
+        deletee_ratio = 3
+
+        documents = [
+            Document(
+                page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
+            )
+            for doc_i in range(full_size)
+        ]
+        num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])
+
+        inserted_ids0 = vector_store_d2.add_documents(documents)
+        assert len(inserted_ids0) == len(documents)
+
+        d_result0 = vector_store_d2.delete_by_metadata_filter({"deletee": True})
+        assert d_result0 == num_deletees
+        count_on_store0 = len(
+            vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
+        )
+        assert count_on_store0 == full_size - num_deletees
+
+        with pytest.raises(ValueError, match="does not accept an empty"):
+            vector_store_d2.delete_by_metadata_filter({})
+        count_on_store1 = len(
+            vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
+        )
+        assert count_on_store1 == full_size - num_deletees
+
+    async def test_astradb_vectorstore_delete_by_metadata_async(
+        self,
+        vector_store_d2: AstraDBVectorStore,
+    ) -> None:
+        """Testing delete_by_metadata_filter, async version."""
+        full_size = 400
+        # one in ... will be deleted
+        deletee_ratio = 3
+
+        documents = [
+            Document(
+                page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
+            )
+            for doc_i in range(full_size)
+        ]
+        num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])
+
+        inserted_ids0 = await vector_store_d2.aadd_documents(documents)
+        assert len(inserted_ids0) == len(documents)
+
+        d_result0 = await vector_store_d2.adelete_by_metadata_filter({"deletee": True})
+        assert d_result0 == num_deletees
+        count_on_store0 = len(
+            await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1)
+        )
+        assert count_on_store0 == full_size - num_deletees
+
+        with pytest.raises(ValueError, match="does not accept an empty"):
+            await vector_store_d2.adelete_by_metadata_filter({})
+        count_on_store1 = len(
+            await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1)
+        )
+        assert count_on_store1 == full_size - num_deletees
+
+    def test_astradb_vectorstore_update_metadata_sync(
+        self,
+        vector_store_d2: AstraDBVectorStore,
+    ) -> None:
+        """Testing update_metadata."""
+        # this should not exceed the max number of hits from ANN search
+        full_size = 20
+        # one in ... will be updated
+        updatee_ratio = 2
+        # set this to lower than full_size // updatee_ratio to test everything.
+        update_concurrency = 7
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        orig_documents0 = [
+            Document(
+                page_content="[1,1]",
+                metadata={
+                    "to_update": doc_i % updatee_ratio == 0,
+                    "inert_field": "I",
+                    "updatee_field": "0",
+                },
+                id=f"um_doc_{doc_i}",
+            )
+            for doc_i in range(full_size)
+        ]
+        orig_documents = sorted(orig_documents0, key=doc_sorter)
+
+        inserted_ids0 = vector_store_d2.add_documents(orig_documents)
+        assert len(inserted_ids0) == len(orig_documents)
+
+        update_map = {
+            f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
+            for doc_i in range(full_size)
+            if doc_i % updatee_ratio == 0
+        }
+        u_result0 = vector_store_d2.update_metadata(
+            update_map,
+            overwrite_concurrency=update_concurrency,
+        )
+        assert u_result0 == len(update_map)
+
+        all_documents = sorted(
+            vector_store_d2.similarity_search("[1,1]", k=full_size),
+            key=doc_sorter,
+        )
+        assert len(all_documents) == len(orig_documents)
+        for doc, orig_doc in zip(all_documents, orig_documents):
+            assert doc.id == orig_doc.id
+            if doc.id in update_map:
+                assert doc.metadata == orig_doc.metadata | update_map[doc.id]
+
+    async def test_astradb_vectorstore_update_metadata_async(
+        self,
+        vector_store_d2: AstraDBVectorStore,
+    ) -> None:
+        """Testing update_metadata, async version."""
+        # this should not exceed the max number of hits from ANN search
+        full_size = 20
+        # one in ... will be updated
+        updatee_ratio = 2
+        # set this to lower than full_size // updatee_ratio to test everything.
+        update_concurrency = 7
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        orig_documents0 = [
+            Document(
+                page_content="[1,1]",
+                metadata={
+                    "to_update": doc_i % updatee_ratio == 0,
+                    "inert_field": "I",
+                    "updatee_field": "0",
+                },
+                id=f"um_doc_{doc_i}",
+            )
+            for doc_i in range(full_size)
+        ]
+        orig_documents = sorted(orig_documents0, key=doc_sorter)
+
+        inserted_ids0 = await vector_store_d2.aadd_documents(orig_documents)
+        assert len(inserted_ids0) == len(orig_documents)
+
+        update_map = {
+            f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
+            for doc_i in range(full_size)
+            if doc_i % updatee_ratio == 0
+        }
+        u_result0 = await vector_store_d2.aupdate_metadata(
+            update_map,
+            overwrite_concurrency=update_concurrency,
+        )
+        assert u_result0 == len(update_map)
+
+        all_documents = sorted(
+            await vector_store_d2.asimilarity_search("[1,1]", k=full_size),
+            key=doc_sorter,
+        )
+        assert len(all_documents) == len(orig_documents)
+        for doc, orig_doc in zip(all_documents, orig_documents):
+            assert doc.id == orig_doc.id
+            if doc.id in update_map:
+                assert doc.metadata == orig_doc.metadata | update_map[doc.id]
+
     def test_astradb_vectorstore_mmr_sync(
         self,
         vector_store_d2: AstraDBVectorStore,
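The assertions in the two update tests above rely on the metadata merge behaving like a per-key dict union, with the update taking precedence. Mirrored in plain Python (no database needed), the expected outcome for one of the test documents looks like this:

```python
# The expected merge, expressed with Python's dict-union operator on the same
# metadata shapes the tests use; the update wins on overlapping keys.
orig_metadata = {"to_update": True, "inert_field": "I", "updatee_field": "0"}
update = {"updatee_field": "1", "to_update": False}

merged = orig_metadata | update
assert merged == {"to_update": False, "inert_field": "I", "updatee_field": "1"}
```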
diff --git a/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py b/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py
index 60caf47..94f84f8 100644
--- a/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py
+++ b/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py
@@ -92,6 +92,26 @@ def test_autodetect_flat_novectorize_crud(
         results2 = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"})
         assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)]
 
+        # delete by metadata
+        del_by_md = ad_store.delete_by_metadata_filter(filter={"q2": "Q2"})
+        assert del_by_md is not None
+        assert del_by_md == 1
+        results2n = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"})
+        assert results2n == []
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        # update metadata
+        ad_store.update_metadata(
+            {"1": {"m1": "A", "mZ": "Z"}, "2": {"m1": "B", "mZ": "Z"}}
+        )
+        matches_z = ad_store.similarity_search("[-1,-1]", k=3, filter={"mZ": "Z"})
+        assert len(matches_z) == 2
+        s_matches_z = sorted(matches_z, key=doc_sorter)
+        assert s_matches_z[0].metadata == {"m1": "A", "m2": "x", "mZ": "Z"}
+        assert s_matches_z[1].metadata == {"m1": "B", "m2": "y", "mZ": "Z"}
+
     def test_autodetect_default_novectorize_crud(
         self,
         astra_db_credentials: AstraDBCredentials,
@@ -148,6 +168,26 @@ def test_autodetect_default_novectorize_crud(
         results2 = ad_store.similarity_search("[9,10]", k=3, filter={"q2": "Q2"})
         assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)]
 
+        # delete by metadata
+        del_by_md = ad_store.delete_by_metadata_filter(filter={"q2": "Q2"})
+        assert del_by_md is not None
+        assert del_by_md == 1
+        results2n = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"})
+        assert results2n == []
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        # update metadata
+        ad_store.update_metadata(
+            {"1": {"m1": "A", "mZ": "Z"}, "2": {"m1": "B", "mZ": "Z"}}
+        )
+        matches_z = ad_store.similarity_search("[-1,-1]", k=3, filter={"mZ": "Z"})
+        assert len(matches_z) == 2
+        s_matches_z = sorted(matches_z, key=doc_sorter)
+        assert s_matches_z[0].metadata == {"m1": "A", "m2": "x", "mZ": "Z"}
+        assert s_matches_z[1].metadata == {"m1": "B", "m2": "y", "mZ": "Z"}
+
     def test_autodetect_flat_vectorize_crud(
         self,
         astra_db_credentials: AstraDBCredentials,
@@ -208,6 +248,26 @@ def test_autodetect_flat_vectorize_crud(
         results2 = ad_store.similarity_search("query", k=3, filter={"q2": "Q2"})
         assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)]
 
+        # delete by metadata
+        del_by_md = ad_store.delete_by_metadata_filter(filter={"q2": "Q2"})
+        assert del_by_md is not None
+        assert del_by_md == 1
+        results2n = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"})
+        assert results2n == []
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        # update metadata
+        ad_store.update_metadata(
+            {"1": {"m1": "A", "mZ": "Z"}, "2": {"m1": "B", "mZ": "Z"}}
+        )
+        matches_z = ad_store.similarity_search("[-1,-1]", k=3, filter={"mZ": "Z"})
+        assert len(matches_z) == 2
+        s_matches_z = sorted(matches_z, key=doc_sorter)
+        assert s_matches_z[0].metadata == {"m1": "A", "m2": "x", "mZ": "Z"}
+        assert s_matches_z[1].metadata == {"m1": "B", "m2": "y", "mZ": "Z"}
+
     def test_autodetect_default_vectorize_crud(
         self,
         *,
@@ -266,6 +326,26 @@ def test_autodetect_default_vectorize_crud(
         results2 = ad_store.similarity_search("query", k=3, filter={"q2": "Q2"})
         assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)]
 
+        # delete by metadata
+        del_by_md = ad_store.delete_by_metadata_filter(filter={"q2": "Q2"})
+        assert del_by_md is not None
+        assert del_by_md == 1
+        results2n = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"})
+        assert results2n == []
+
+        def doc_sorter(doc: Document) -> str:
+            return doc.id or ""
+
+        # update metadata
+        ad_store.update_metadata(
+            {"1": {"m1": "A", "mZ": "Z"}, "2": {"m1": "B", "mZ": "Z"}}
+        )
+        matches_z = ad_store.similarity_search("[-1,-1]", k=3, filter={"mZ": "Z"})
+        assert len(matches_z) == 2
+        s_matches_z = sorted(matches_z, key=doc_sorter)
+        assert s_matches_z[0].metadata == {"m1": "A", "m2": "x", "mZ": "Z"}
+        assert s_matches_z[1].metadata == {"m1": "B", "m2": "y", "mZ": "Z"}
+
     def test_failed_docs_autodetect_flat_novectorize_crud(
         self,
         astra_db_credentials: AstraDBCredentials,
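For completeness, the async counterparts can be driven the same way. This is again a sketch with placeholder connection details (same caveats as the sync example above), showing how `overwrite_concurrency` caps the number of in-flight updates:

```python
# Sketch only: async variants of the calls shown earlier; connection values
# are placeholders, not part of this diff.
import asyncio

from langchain_astradb import AstraDBVectorStore
from langchain_core.embeddings import DeterministicFakeEmbedding


async def main() -> None:
    store = AstraDBVectorStore(
        embedding=DeterministicFakeEmbedding(size=2),  # placeholder embedding
        collection_name="example_collection",  # placeholder
        token="AstraCS:...",  # placeholder credential
        api_endpoint="https://<db-id>-<region>.apps.astra.datastax.com",  # placeholder
    )
    # At most 4 update_one calls are in flight at once: aupdate_metadata
    # bounds concurrency with the asyncio.Semaphore introduced in this diff.
    n_updated = await store.aupdate_metadata(
        {"doc_1": {"reviewed": True}, "doc_2": {"reviewed": True}},
        overwrite_concurrency=4,
    )
    # Returns the number of deleted documents, as with the sync version.
    n_deleted = await store.adelete_by_metadata_filter({"reviewed": False})
    print(n_updated, n_deleted)


asyncio.run(main())
```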