diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index f9f00e5..d796915 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -39,7 +39,7 @@ class _Edge: # NOTE: Conversion to string is necessary -# becasue AstraDB doesn't support matching on arrays of tuples +# because AstraDB doesn't support matching on arrays of tuples def _tag_to_str(kind: str, tag: str) -> str: return f"{kind}:{tag}" @@ -134,12 +134,11 @@ def from_documents( cls: type[AstraDBGraphVectorStore], documents: Iterable[Document], embedding: Embeddings, - ids: Iterable[str] | None = None, **kwargs: Any, ) -> AstraDBGraphVectorStore: """Return GraphVectorStore initialized from documents and embeddings.""" store = cls(embedding, **kwargs) - store.add_documents(documents, ids=ids) + store.add_documents(documents) return store @override @@ -248,11 +247,21 @@ def visit_targets(d: int, targets: Sequence[Document]) -> None: return visited_docs - def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]: - if filter_dict is None: - return {} + def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]: + """Prepare a query for use on DB based on metadata filter. + + Encode an "abstract" filter clause on metadata into a query filter + condition aware of the collection schema choice. - return self.vectorstore.document_codec.encode_filter(filter_dict) + Args: + filter_dict: a metadata condition in the form {"field": "value"} + or related. + + Returns: + the corresponding mapping ready for use in queries, + aware of the details of the schema used to encode the document on DB. + """ + return self.vectorstore.filter_to_query(filter_dict) def _get_outgoing_tags( self, @@ -318,7 +327,7 @@ def get_adjacent(tags: set[str]) -> Iterable[_Edge]: for tag in tags: m_filter = (metadata_filter or {}).copy() m_filter[self.link_from_metadata_key] = tag - metadata_parameter = self._filter_to_metadata(m_filter) + metadata_parameter = self.filter_to_query(m_filter) hits = list( self.astra_env.collection.find( @@ -382,7 +391,7 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: helper.add_candidates(new_candidates) def fetch_initial_candidates() -> None: - metadata_parameter = self._filter_to_metadata(metadata_filter).copy() + metadata_parameter = self.filter_to_query(metadata_filter).copy() hits = list( self.astra_env.collection.find( filter=metadata_parameter, diff --git a/libs/astradb/langchain_astradb/utils/mmr.py b/libs/astradb/langchain_astradb/utils/mmr.py index 93babd1..3ff58a0 100644 --- a/libs/astradb/langchain_astradb/utils/mmr.py +++ b/libs/astradb/langchain_astradb/utils/mmr.py @@ -50,7 +50,7 @@ def cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: else: x = np.array(x, dtype=np.float32) y = np.array(y, dtype=np.float32) - z = 1 - simd.cdist(x, y, metric="cosine") + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) if isinstance(z, float): return np.array([z]) return z diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index cfb2cf6..faea7c7 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -300,7 +300,20 @@ class AstraDBVectorStore(VectorStore): """ # noqa: E501 - def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]: + def filter_to_query(self, filter_dict: dict[str, Any] | None) 
-> dict[str, Any]: + """Prepare a query for use on DB based on metadata filter. + + Encode an "abstract" filter clause on metadata into a query filter + condition aware of the collection schema choice. + + Args: + filter_dict: a metadata condition in the form {"field": "value"} + or related. + + Returns: + the corresponding mapping ready for use in queries, + aware of the details of the schema used to encode the document on DB. + """ if filter_dict is None: return {} @@ -1319,7 +1332,7 @@ def _similarity_search_with_score_id_by_sort( ) -> list[tuple[Document, float, str]]: """Run ANN search with a provided sort clause.""" self.astra_env.ensure_db_setup() - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) hits_ite = self.astra_env.collection.find( filter=metadata_parameter, projection=self.document_codec.base_projection, @@ -1515,7 +1528,7 @@ async def _asimilarity_search_with_score_id_by_sort( ) -> list[tuple[Document, float, str]]: """Run ANN search with a provided sort clause.""" await self.astra_env.aensure_db_setup() - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) return [ (doc, sim, did) async for (doc, sim, did) in ( @@ -1638,7 +1651,7 @@ def max_marginal_relevance_search_by_vector( The list of Documents selected by maximal marginal relevance. """ self.astra_env.ensure_db_setup() - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) return self._run_mmr_query_by_sort( sort={"$vector": embedding}, @@ -1677,7 +1690,7 @@ async def amax_marginal_relevance_search_by_vector( The list of Documents selected by maximal marginal relevance. """ await self.astra_env.aensure_db_setup() - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) return await self._arun_mmr_query_by_sort( sort={"$vector": embedding}, @@ -1719,7 +1732,7 @@ def max_marginal_relevance_search( # this case goes directly to the "_by_sort" method # (and does its own filter normalization, as it cannot # use the path for the with-embedding mmr querying) - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) return self._run_mmr_query_by_sort( sort={"$vectorize": query}, k=k, @@ -1770,7 +1783,7 @@ async def amax_marginal_relevance_search( # this case goes directly to the "_by_sort" method # (and does its own filter normalization, as it cannot # use the path for the with-embedding mmr querying) - metadata_parameter = self._filter_to_metadata(filter) + metadata_parameter = self.filter_to_query(filter) return await self._arun_mmr_query_by_sort( sort={"$vectorize": query}, k=k, @@ -1930,10 +1943,27 @@ def from_documents( """ texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] + + if "ids" in kwargs: + warnings.warn( + ( + "Parameter `ids` to AstraDBVectorStore's `from_documents` " + "method is deprecated. Please set the supplied documents' " + "`.id` attribute instead. The id attribute of Document " + "is ignored as long as the `ids` parameter is passed." 
+ ), + DeprecationWarning, + stacklevel=2, + ) + ids = kwargs.pop("ids") + else: + _ids = [doc.id for doc in documents] + ids = _ids if any(the_id is not None for the_id in _ids) else None return cls.from_texts( texts, embedding=embedding, metadatas=metadatas, + ids=ids, **kwargs, ) @@ -1956,9 +1986,26 @@ async def afrom_documents( """ texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] + + if "ids" in kwargs: + warnings.warn( + ( + "Parameter `ids` to AstraDBVectorStore's `afrom_documents` " + "method is deprecated. Please set the supplied documents' " + "`.id` attribute instead. The id attribute of Document " + "is ignored as long as the `ids` parameter is passed." + ), + DeprecationWarning, + stacklevel=2, + ) + ids = kwargs.pop("ids") + else: + _ids = [doc.id for doc in documents] + ids = _ids if any(the_id is not None for the_id in _ids) else None return await cls.afrom_texts( texts, embedding=embedding, metadatas=metadatas, + ids=ids, **kwargs, ) diff --git a/libs/astradb/pyproject.toml b/libs/astradb/pyproject.toml index 9e7e07e..94cdd2e 100644 --- a/libs/astradb/pyproject.toml +++ b/libs/astradb/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-astradb" -version = "0.4.1" +version = "0.4.2" description = "An integration package connecting Astra DB and LangChain" authors = [] readme = "README.md" diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py new file mode 100644 index 0000000..570ce9c --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -0,0 +1,399 @@ +"""Tests for the Astra DB graph vector store class `AstraDBGraphVectorStore`. + +Refer to `test_vectorstores.py` for the requirements to run. +""" + +# ruff: noqa: FIX002 TD002 TD003 + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable + +import pytest +from astrapy import DataAPIClient +from astrapy.authentication import StaticTokenProvider +from langchain_core.documents import Document +from langchain_core.graph_vectorstores.base import Node +from langchain_core.graph_vectorstores.links import Link, add_links + +from langchain_astradb.graph_vectorstores import AstraDBGraphVectorStore +from langchain_astradb.utils.astradb import SetupMode +from tests.conftest import ParserEmbeddings + +from .conftest import AstraDBCredentials, _has_env_vars + +if TYPE_CHECKING: + from astrapy import Collection + from langchain_core.embeddings import Embeddings + +# Faster testing (no actual collection deletions). Off by default (=full tests) +SKIP_COLLECTION_DELETE = ( + int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0 +) + +GVS_NOVECTORIZE_COLLECTION = "lc_gvs_novectorize" +# for testing with autodetect +CUSTOM_CONTENT_KEY = "xcontent" +LONG_TEXT = "This is the textual content field in the doc." 
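+ +# NOTE (an assumption about the shared helper, stated for readability of the +# tests below): ParserEmbeddings is taken to literally parse a document's text +# as its vector, e.g. ParserEmbeddings(dimension=2).embed_query("[1, 2]") +# giving [1.0, 2.0], so each entry's position in vector space is steered +# simply by writing coordinates into page_content.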
+ + +@pytest.fixture(scope="session") +def provisioned_novectorize_collection( + astra_db_credentials: AstraDBCredentials, +) -> Iterable[Collection]: + """Provision a general-purpose collection for the no-vectorize tests.""" + client = DataAPIClient(environment=astra_db_credentials["environment"]) + database = client.get_database( + astra_db_credentials["api_endpoint"], + token=StaticTokenProvider(astra_db_credentials["token"]), + namespace=astra_db_credentials["namespace"], + ) + collection = database.create_collection( + GVS_NOVECTORIZE_COLLECTION, + dimension=2, + check_exists=False, + metric="euclidean", + ) + yield collection + + if not SKIP_COLLECTION_DELETE: + collection.drop() + + +@pytest.fixture +def novectorize_empty_collection( + provisioned_novectorize_collection: Collection, +) -> Iterable[Collection]: + provisioned_novectorize_collection.delete_many({}) + yield provisioned_novectorize_collection + + provisioned_novectorize_collection.delete_many({}) + + +@pytest.fixture +def embedding() -> Embeddings: + return ParserEmbeddings(dimension=2) + + +@pytest.fixture +def novectorize_empty_graph_store( + novectorize_empty_collection: Collection, # noqa: ARG001 + astra_db_credentials: AstraDBCredentials, + embedding: Embeddings, +) -> AstraDBGraphVectorStore: + return AstraDBGraphVectorStore( + embedding=embedding, + collection_name=GVS_NOVECTORIZE_COLLECTION, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + ) + + +@pytest.fixture +def novectorize_autodetect_full_graph_store( + astra_db_credentials: AstraDBCredentials, + novectorize_empty_collection: Collection, + embedding: Embeddings, + graph_docs: list[Document], +) -> AstraDBGraphVectorStore: + """ + Pre-populate the collection and let the vector store's autodetection + work on it, then create and return a GraphVectorStore, additionally + filled with the same (graph-)entries as `novectorize_full_graph_store`. + """ + novectorize_empty_collection.insert_many( + [ + { + CUSTOM_CONTENT_KEY: LONG_TEXT, + "$vector": [100, 0], + "mds": "S", + "mdi": 100, + }, + { + CUSTOM_CONTENT_KEY: LONG_TEXT, + "$vector": [100, 1], + "mds": "T", + "mdi": 101, + }, + { + CUSTOM_CONTENT_KEY: LONG_TEXT, + "$vector": [100, 2], + "mds": "U", + "mdi": 102, + }, + ] + ) + gstore = AstraDBGraphVectorStore( + embedding=embedding, + collection_name=GVS_NOVECTORIZE_COLLECTION, + link_to_metadata_key="x_link_to_x", + link_from_metadata_key="x_link_from_x", + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + content_field="*", + autodetect_collection=True, + ) + gstore.add_documents(graph_docs) + return gstore + + +def assert_all_flat_docs(collection: Collection) -> None: + """ + Check that, after the graph insertions, all docs in the store + still obey the underlying autodetected doc schema on DB. + """ + for doc in collection.find({}, projection={"*": True}): + assert all(not isinstance(v, dict) for v in doc.values()) + assert CUSTOM_CONTENT_KEY in doc + assert isinstance(doc["$vector"], list)
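+ + +# NOTE (illustrative assumption, based on _tag_to_str in +# graph_vectorstores.py): links are expected to land on the stored documents +# as flat, list-valued metadata fields of "kind:tag" strings, roughly: +# Link.outgoing(kind="at_example", tag="tag_l") -> link-to key gains "at_example:tag_l" +# Link.incoming(kind="af_example", tag="tag_l") -> link-from key gains "af_example:tag_l" +# This is why assert_all_flat_docs can demand dict-free documents, and why +# test_gvs_add_nodes below checks for list-valued metadata entries.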
+ + +@pytest.fixture +def graph_docs() -> list[Document]: + """ + Documents for pre-populating a graph vector store, + with entries placed in a deliberate layout. + + Space of the entries (under Euclidean similarity): + + A0 (*) + .... AL AR <.... + : | : + : | ^ : + v | . v + | : + TR | : BL + T0 --------------x-------------- B0 + TL | : BR + | : + | . + | . + | + FL FR + F0 + + the query point is at (*). + the A are bidirectionally linked with the B + the A have outgoing links to the T + the A have incoming links from the F + Links pair matching suffixes: L with L, 0 with 0 and R with R. + """ + + docs_a = [ + Document(page_content="[-1, 9]", metadata={"label": "AL"}), + Document(page_content="[0, 10]", metadata={"label": "A0"}), + Document(page_content="[1, 9]", metadata={"label": "AR"}), + ] + docs_b = [ + Document(page_content="[9, 1]", metadata={"label": "BL"}), + Document(page_content="[10, 0]", metadata={"label": "B0"}), + Document(page_content="[9, -1]", metadata={"label": "BR"}), + ] + docs_f = [ + Document(page_content="[1, -9]", metadata={"label": "FL"}), + Document(page_content="[0, -10]", metadata={"label": "F0"}), + Document(page_content="[-1, -9]", metadata={"label": "FR"}), + ] + docs_t = [ + Document(page_content="[-9, -1]", metadata={"label": "TL"}), + Document(page_content="[-10, 0]", metadata={"label": "T0"}), + Document(page_content="[-9, 1]", metadata={"label": "TR"}), + ] + for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): + add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}")) + for doc_b, suffix in zip(docs_b, ["l", "0", "r"]): + add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + for doc_t, suffix in zip(docs_t, ["l", "0", "r"]): + add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}")) + for doc_f, suffix in zip(docs_f, ["l", "0", "r"]): + add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}")) + return docs_a + docs_b + docs_f + docs_t + + +@pytest.fixture +def novectorize_full_graph_store( + novectorize_empty_graph_store: AstraDBGraphVectorStore, + graph_docs: list[Document], +) -> AstraDBGraphVectorStore: + novectorize_empty_graph_store.add_documents(graph_docs) + return novectorize_empty_graph_store + + +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") +class TestAstraDBGraphVectorStore: + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("novectorize_full_graph_store", False), + ("novectorize_autodetect_full_graph_store", True), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_similarity_search( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """Simple (non-graph) similarity search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + ss_response = store.similarity_search(query="[2, 10]", k=2) + ss_labels = [doc.metadata["label"] for doc in ss_response] + assert ss_labels == ["AR", "A0"] + ss_by_v_response = store.similarity_search_by_vector(embedding=[2, 10], k=2) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + if is_autodetected: + assert_all_flat_docs(store.vectorstore.astra_env.collection) + + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("novectorize_full_graph_store", False), + ("novectorize_autodetect_full_graph_store", True), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_traversal_search( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """Graph traversal search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + ts_response = store.traversal_search(query="[2, 10]", k=2, depth=2) + # this is a set, as some of the internals of trav.search are set-driven + # so ordering is not deterministic: + ts_labels = {doc.metadata["label"] for doc in ts_response} + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + if is_autodetected: + assert_all_flat_docs(store.vectorstore.astra_env.collection) + + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("novectorize_full_graph_store", False), + ("novectorize_autodetect_full_graph_store", True), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_mmr_traversal_search( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """MMR Graph traversal search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_response = store.mmr_traversal_search( + query="[2, 10]", + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ) + # TODO: can this rightfully be a list (or must it be a set)? + mt_labels = {doc.metadata["label"] for doc in mt_response} + assert mt_labels == {"AR", "BR"} + if is_autodetected: + assert_all_flat_docs(store.vectorstore.astra_env.collection) + + def test_gvs_from_texts( + self, + *, + astra_db_credentials: AstraDBCredentials, + novectorize_empty_collection: Collection, # noqa: ARG002 + embedding: Embeddings, + ) -> None: + g_store = AstraDBGraphVectorStore.from_texts( + texts=["[1, 2]"], + embedding=embedding, + metadatas=[{"md": 1}], + ids=["x_id"], + collection_name=GVS_NOVECTORIZE_COLLECTION, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + content_field=CUSTOM_CONTENT_KEY, + setup_mode=SetupMode.OFF, + ) + hits = g_store.similarity_search("[2, 1]", k=2) + assert len(hits) == 1 + assert hits[0].page_content == "[1, 2]" + assert hits[0].id == "x_id" + # there may be more re:graph structure. 
+ assert hits[0].metadata["md"] == 1 + + def test_gvs_from_documents_containing_ids( + self, + *, + astra_db_credentials: AstraDBCredentials, + novectorize_empty_collection: Collection, # noqa: ARG002 + embedding: Embeddings, + ) -> None: + the_document = Document( + page_content="[1, 2]", + metadata={"md": 1}, + id="x_id", + ) + g_store = AstraDBGraphVectorStore.from_documents( + documents=[the_document], + embedding=embedding, + collection_name=GVS_NOVECTORIZE_COLLECTION, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + content_field=CUSTOM_CONTENT_KEY, + setup_mode=SetupMode.OFF, + ) + hits = g_store.similarity_search("[2, 1]", k=2) + assert len(hits) == 1 + assert hits[0].page_content == "[1, 2]" + assert hits[0].id == "x_id" + # there may be more re:graph structure. + assert hits[0].metadata["md"] == 1 + + def test_gvs_add_nodes( + self, + *, + novectorize_empty_graph_store: AstraDBGraphVectorStore, + ) -> None: + links0 = [ + Link(kind="kA", direction="out", tag="tA"), + Link(kind="kB", direction="bidir", tag="tB"), + ] + links1 = [ + Link(kind="kC", direction="in", tag="tC"), + ] + nodes = [ + Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0), + Node(text="[0, 1]", metadata={"m": 1}, links=links1), + ] + novectorize_empty_graph_store.add_nodes(nodes) + hits = novectorize_empty_graph_store.similarity_search_by_vector([0, 3]) + assert len(hits) == 2 + assert hits[0].id == "id0" + assert hits[0].page_content == "[0, 2]" + md0 = hits[0].metadata + assert md0["m"] == 0 + assert any(isinstance(v, list) for k, v in md0.items() if k != "m") + assert hits[1].id != "id0" + assert hits[1].page_content == "[0, 1]" + md1 = hits[1].metadata + assert md1["m"] == 1 + assert any(isinstance(v, list) for k, v in md1.items() if k != "m") diff --git a/libs/astradb/tests/integration_tests/test_vectorstores.py b/libs/astradb/tests/integration_tests/test_vectorstores.py index 6ddbc1c..3690105 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstores.py +++ b/libs/astradb/tests/integration_tests/test_vectorstores.py @@ -394,10 +394,10 @@ async def test_astradb_vectorstore_pre_delete_collection_async( finally: await v_store.adelete_collection() - def test_astradb_vectorstore_from_x_sync( + def test_astradb_vectorstore_from_texts_sync( self, astra_db_credentials: AstraDBCredentials ) -> None: - """from_texts and from_documents methods.""" + """from_texts methods.""" emb = SomeEmbeddings(dimension=2) # prepare empty collection AstraDBVectorStore( @@ -472,8 +472,12 @@ def test_astradb_vectorstore_from_x_sync( else: v_store.clear() - # from_documents - v_store_2 = AstraDBVectorStore.from_documents( + def test_astradb_vectorstore_from_documents_without_ids_sync( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """from_documents methods.""" + emb = SomeEmbeddings(dimension=2) + v_store = AstraDBVectorStore.from_documents( [ Document(page_content="Hee"), Document(page_content="Hoi"), @@ -485,18 +489,119 @@ def test_astradb_vectorstore_from_x_sync( namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], ) + + try: + hits = v_store.similarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + finally: + if not SKIP_COLLECTION_DELETE: + v_store.delete_collection() + else: + v_store.clear() + + def 
test_astradb_vectorstore_from_documents_separate_ids_sync( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """from_documents methods.""" + emb = SomeEmbeddings(dimension=2) + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = AstraDBVectorStore.from_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi"), + ], + embedding=emb, + ids=["idx0", "idx1"], + collection_name=COLLECTION_NAME_DIM2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + + try: + hits = v_store.similarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" + finally: + if not SKIP_COLLECTION_DELETE: + v_store.delete_collection() + else: + v_store.clear() + + def test_astradb_vectorstore_from_documents_containing_ids_sync( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """from_documents methods.""" + emb = SomeEmbeddings(dimension=2) + v_store = AstraDBVectorStore.from_documents( + [ + Document(page_content="Hee", id="idx0"), + Document(page_content="Hoi", id="idx1"), + ], + embedding=emb, + collection_name=COLLECTION_NAME_DIM2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + try: + hits = v_store.similarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" + finally: + if not SKIP_COLLECTION_DELETE: + v_store.delete_collection() + else: + v_store.clear() + + def test_astradb_vectorstore_from_documents_pass_ids_twice_sync( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """from_documents methods.""" + emb = SomeEmbeddings(dimension=2) + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = AstraDBVectorStore.from_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi", id="idy1"), + ], + ids=["idx0", "idx1"], + embedding=emb, + collection_name=COLLECTION_NAME_DIM2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + try: - assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi" + hits = v_store.similarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" finally: if not SKIP_COLLECTION_DELETE: - v_store_2.delete_collection() + v_store.delete_collection() else: - v_store_2.clear() + v_store.clear() - def test_astradb_vectorstore_from_x_vectorize_sync( + def test_astradb_vectorstore_from_texts_vectorize_sync( self, astra_db_credentials: AstraDBCredentials ) -> None: - """from_texts and from_documents methods with vectorize.""" + """from_texts methods with vectorize.""" AstraDBVectorStore( collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, collection_embedding_api_key=os.environ["OPENAI_API_KEY"], @@ -523,29 +628,42 @@ def test_astradb_vectorstore_from_x_vectorize_sync( finally: 
v_store.delete_collection() - # from_documents - v_store_2 = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) + def test_astradb_vectorstore_from_documents_separate_ids_vectorize_sync( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """from_documents methods with vectorize.""" + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = AstraDBVectorStore.from_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi"), + ], + ids=["idx0", "idx1"], + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=os.environ["OPENAI_API_KEY"], + collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + try: - assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi" + hits = v_store.similarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" finally: - v_store_2.delete_collection() + v_store.delete_collection() - async def test_astradb_vectorstore_from_x_async( + async def test_astradb_vectorstore_from_texts_async( self, astra_db_credentials: AstraDBCredentials ) -> None: - """from_texts and from_documents methods.""" + """from_texts methods.""" emb = SomeEmbeddings(dimension=2) # prepare empty collection await AstraDBVectorStore( @@ -620,8 +738,12 @@ async def test_astradb_vectorstore_from_x_async( else: await v_store.aclear() - # from_documents - v_store_2 = await AstraDBVectorStore.afrom_documents( + async def test_astradb_vectorstore_from_documents_without_ids_async( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """afrom_documents methods.""" + emb = SomeEmbeddings(dimension=2) + v_store = await AstraDBVectorStore.afrom_documents( [ Document(page_content="Hee"), Document(page_content="Hoi"), @@ -633,20 +755,119 @@ async def test_astradb_vectorstore_from_x_async( namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], ) + try: - assert (await v_store_2.asimilarity_search("Hoi", k=1))[ - 0 - ].page_content == "Hoi" + hits = await v_store.asimilarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + finally: + if not SKIP_COLLECTION_DELETE: + await v_store.adelete_collection() + else: + await v_store.aclear() + + async def test_astradb_vectorstore_from_documents_separate_ids_async( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """afrom_documents methods.""" + emb = SomeEmbeddings(dimension=2) + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi"), + ], + embedding=emb, + ids=["idx0", "idx1"], + collection_name=COLLECTION_NAME_DIM2, + 
token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + + try: + hits = await v_store.asimilarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" + finally: + if not SKIP_COLLECTION_DELETE: + await v_store.adelete_collection() + else: + await v_store.aclear() + + async def test_astradb_vectorstore_from_documents_containing_ids_async( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """afrom_documents methods.""" + emb = SomeEmbeddings(dimension=2) + v_store = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content="Hee", id="idx0"), + Document(page_content="Hoi", id="idx1"), + ], + embedding=emb, + collection_name=COLLECTION_NAME_DIM2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + try: + hits = await v_store.asimilarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" finally: if not SKIP_COLLECTION_DELETE: - await v_store_2.adelete_collection() + await v_store.adelete_collection() + else: + await v_store.aclear() + + async def test_astradb_vectorstore_from_documents_pass_ids_twice_async( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """afrom_documents methods.""" + emb = SomeEmbeddings(dimension=2) + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi", id="idy0"), + ], + ids=["idx0", "idx1"], + embedding=emb, + collection_name=COLLECTION_NAME_DIM2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + + try: + hits = await v_store.asimilarity_search("Hoi", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "Hoi" + assert hits[0].id == "idx1" + finally: + if not SKIP_COLLECTION_DELETE: + await v_store.adelete_collection() else: - await v_store_2.aclear() + await v_store.aclear() - async def test_astradb_vectorstore_from_x_vectorize_async( + async def test_astradb_vectorstore_from_texts_vectorize_async( self, astra_db_credentials: AstraDBCredentials ) -> None: - """from_texts and from_documents methods with vectorize.""" + """from_texts methods with vectorize.""" # from_text with vectorize v_store = await AstraDBVectorStore.afrom_texts( texts=["Haa", "Huu"], @@ -665,26 +886,37 @@ async def test_astradb_vectorstore_from_x_vectorize_async( finally: await v_store.adelete_collection() - # from_documents with vectorize - v_store_2 = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="HeeH"), - Document(page_content="HooH"), - ], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - 
namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) + async def test_astradb_vectorstore_from_documents_separate_ids_vectorize_async( + self, astra_db_credentials: AstraDBCredentials + ) -> None: + """afrom_documents methods with vectorize.""" + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content="HeeH"), + Document(page_content="HooH"), + ], + ids=["idx0", "idx1"], + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=os.environ["OPENAI_API_KEY"], + collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + try: - assert (await v_store_2.asimilarity_search("HeeH", k=1))[ - 0 - ].page_content == "HeeH" + hits = await v_store.asimilarity_search("HeeH", k=1) + assert len(hits) == 1 + assert hits[0].page_content == "HeeH" + assert hits[0].id == "idx0" finally: - await v_store_2.adelete_collection() + await v_store.adelete_collection() @pytest.mark.parametrize( "vector_store", diff --git a/libs/astradb/tests/unit_tests/test_mmr_helper.py b/libs/astradb/tests/unit_tests/test_mmr_helper.py new file mode 100644 index 0000000..02167c5 --- /dev/null +++ b/libs/astradb/tests/unit_tests/test_mmr_helper.py @@ -0,0 +1,67 @@ +from langchain_core.documents import Document + +from langchain_astradb.utils.mmr_traversal import MmrHelper + +IDS = { + "-1", + "-2", + "-3", + "-4", + "-5", + "+1", + "+2", + "+3", + "+4", + "+5", +} + + +class TestMmrHelper: + def test_mmr_helper_functional(self) -> None: + helper = MmrHelper(k=3, query_embedding=[6, 5], lambda_mult=0.5) + + assert len(list(helper.candidate_ids())) == 0 + + helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) + helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) + helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) + helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) + helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + + assert len(list(helper.candidate_ids())) == 5 + + helper.add_candidates({"+1": (Document(page_content="+1"), [5, 3])}) + helper.add_candidates({"+2": (Document(page_content="+2"), [5, 3])}) + helper.add_candidates({"+3": (Document(page_content="+3"), [6, 2])}) + helper.add_candidates({"+4": (Document(page_content="+4"), [6, 1])}) + helper.add_candidates({"+5": (Document(page_content="+5"), [6, 0])}) + + assert len(list(helper.candidate_ids())) == 10 + + for idx in range(3): + best_id = helper.pop_best() + assert best_id in IDS + assert len(list(helper.candidate_ids())) == 9 - idx + assert best_id not in helper.candidate_ids() + + def test_mmr_helper_max_diversity(self) -> None: + helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=0) + helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) + helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) + helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) + helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) + helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + + best = {helper.pop_best(), helper.pop_best()} + 
assert best == {"-1", "-5"} + + def test_mmr_helper_max_similarity(self) -> None: + helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=1) + helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) + helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) + helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) + helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) + helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + + best = {helper.pop_best(), helper.pop_best()} + assert best == {"-1", "-2"} diff --git a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py index 4acb71e..4e5d8c5 100644 --- a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py +++ b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py @@ -21,6 +21,8 @@ DOCUMENT_ID = "the_id" LC_DOCUMENT = Document(id=DOCUMENT_ID, page_content=CONTENT, metadata=METADATA) LC_FILTER = {"a0": 0, "$or": [{"b1": 1}, {"b2": 2}]} +ID_FILTER = {"_id": DOCUMENT_ID} +VECTOR_SORT = {"$vector": VECTOR} ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE = { "_id": DOCUMENT_ID, @@ -28,11 +30,15 @@ "metadata": METADATA, "$vector": VECTOR, } -ASTRA_DEFAULT_DOCUMENT_VECTORIZE = { +ASTRA_DEFAULT_DOCUMENT_VECTORIZE: dict[str, Any] = { "_id": DOCUMENT_ID, "$vectorize": CONTENT, "metadata": METADATA, } +ASTRA_DEFAULT_DOCUMENT_VECTORIZE_READ = { + "$vector": VECTOR, + **ASTRA_DEFAULT_DOCUMENT_VECTORIZE, +} ASTRA_DEFAULT_FILTER = { "metadata.a0": 0, "$or": [{"metadata.b1": 1}, {"metadata.b2": 2}], @@ -122,6 +128,28 @@ def test_default_novectorize_filtering(self) -> None: encoded_flt = codec.encode_filter(LC_FILTER) assert encoded_flt == ASTRA_DEFAULT_FILTER + def test_default_novectorize_vector_decoding(self) -> None: + """Test vector-decoding for default, no-vectorize.""" + codec = _DefaultVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE) == VECTOR + assert codec.decode_vector({}) is None + + def test_default_novectorize_id_encoding(self) -> None: + """Test id-encoding for default, no-vectorize.""" + codec = _DefaultVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.encode_id(DOCUMENT_ID) == ID_FILTER + + def test_default_novectorize_vectorsort_encoding(self) -> None: + """Test vector-sort-encoding for default, no-vectorize.""" + codec = _DefaultVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT + def test_default_vectorize_encoding(self) -> None: """Test encoding for default, vectorize.""" codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) @@ -172,6 +200,22 @@ def test_default_vectorize_filtering(self) -> None: encoded_flt = codec.encode_filter(LC_FILTER) assert encoded_flt == ASTRA_DEFAULT_FILTER + def test_default_vectorize_vector_decoding(self) -> None: + """Test vector-decoding for default, vectorize.""" + codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_VECTORIZE_READ) == VECTOR + assert codec.decode_vector({}) is None + + def test_default_vectorize_id_encoding(self) -> None: + """Test id-encoding for default, vectorize.""" + codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.encode_id(DOCUMENT_ID) == ID_FILTER + + def test_default_vectorize_vectorsort_encoding(self) -> None: + 
"""Test vector-sort-encoding for default, vectorize.""" + codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT + def test_flat_novectorize_encoding(self) -> None: """Test encoding for flat, no-vectorize.""" codec = _FlatVSDocumentCodec( @@ -234,6 +278,28 @@ def test_flat_novectorize_filtering(self) -> None: encoded_flt = codec.encode_filter(LC_FILTER) assert encoded_flt == ASTRA_FLAT_FILTER + def test_flat_novectorize_vector_decoding(self) -> None: + """Test vector-decoding for flat, no-vectorize.""" + codec = _FlatVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_NOVECTORIZE) == VECTOR + assert codec.decode_vector({}) is None + + def test_flat_novectorize_id_encoding(self) -> None: + """Test id-encoding for flat, no-vectorize.""" + codec = _FlatVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.encode_id(DOCUMENT_ID) == ID_FILTER + + def test_flat_novectorize_vectorsort_encoding(self) -> None: + """Test vector-sort-encoding for flat, no-vectorize.""" + codec = _FlatVSDocumentCodec( + content_field="content_x", ignore_invalid_documents=False + ) + assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT + def test_flat_vectorize_encoding(self) -> None: """Test encoding for flat, vectorize.""" codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) @@ -283,3 +349,19 @@ def test_flat_vectorize_filtering(self) -> None: codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) encoded_flt = codec.encode_filter(LC_FILTER) assert encoded_flt == ASTRA_FLAT_FILTER + + def test_flat_vectorize_vector_decoding(self) -> None: + """Test vector-decoding for flat, vectorize.""" + codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_VECTORIZE_READ) == VECTOR + assert codec.decode_vector({}) is None + + def test_flat_vectorize_id_encoding(self) -> None: + """Test id-encoding for flat, vectorize.""" + codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.encode_id(DOCUMENT_ID) == ID_FILTER + + def test_flat_vectorize_vectorsort_encoding(self) -> None: + """Test vector-sort-encoding for flat, vectorize.""" + codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) + assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT