Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AstraDBGraphVectorStore testing #75

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class _Edge:


# NOTE: Conversion to string is necessary
# becasue AstraDB doesn't support matching on arrays of tuples
# because AstraDB doesn't support matching on arrays of tuples
def _tag_to_str(kind: str, tag: str) -> str:
return f"{kind}:{tag}"

Expand Down Expand Up @@ -134,12 +134,11 @@ def from_documents(
cls: type[AstraDBGraphVectorStore],
documents: Iterable[Document],
embedding: Embeddings,
ids: Iterable[str] | None = None,
**kwargs: Any,
) -> AstraDBGraphVectorStore:
"""Return GraphVectorStore initialized from documents and embeddings."""
store = cls(embedding, **kwargs)
store.add_documents(documents, ids=ids)
store.add_documents(documents)
return store

@override
Expand Down Expand Up @@ -248,11 +247,21 @@ def visit_targets(d: int, targets: Sequence[Document]) -> None:

return visited_docs

def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
if filter_dict is None:
return {}
def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
"""Prepare a query for use on DB based on metadata filter.

Encode an "abstract" filter clause on metadata into a query filter
condition aware of the collection schema choice.

return self.vectorstore.document_codec.encode_filter(filter_dict)
Args:
filter_dict: a metadata condition in the form {"field": "value"}
or related.

Returns:
the corresponding mapping ready for use in queries,
aware of the details of the schema used to encode the document on DB.
"""
return self.vectorstore.filter_to_query(filter_dict)

def _get_outgoing_tags(
self,
Expand Down Expand Up @@ -318,7 +327,7 @@ def get_adjacent(tags: set[str]) -> Iterable[_Edge]:
for tag in tags:
m_filter = (metadata_filter or {}).copy()
m_filter[self.link_from_metadata_key] = tag
metadata_parameter = self._filter_to_metadata(m_filter)
metadata_parameter = self.filter_to_query(m_filter)

hits = list(
self.astra_env.collection.find(
Expand Down Expand Up @@ -382,7 +391,7 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
helper.add_candidates(new_candidates)

def fetch_initial_candidates() -> None:
metadata_parameter = self._filter_to_metadata(metadata_filter).copy()
metadata_parameter = self.filter_to_query(metadata_filter).copy()
hits = list(
self.astra_env.collection.find(
filter=metadata_parameter,
Expand Down
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/utils/mmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
else:
x = np.array(x, dtype=np.float32)
y = np.array(y, dtype=np.float32)
z = 1 - simd.cdist(x, y, metric="cosine")
z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
if isinstance(z, float):
return np.array([z])
return z
Expand Down
61 changes: 54 additions & 7 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,20 @@ class AstraDBVectorStore(VectorStore):

""" # noqa: E501

def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
"""Prepare a query for use on DB based on metadata filter.

Encode an "abstract" filter clause on metadata into a query filter
condition aware of the collection schema choice.

Args:
filter_dict: a metadata condition in the form {"field": "value"}
or related.

Returns:
the corresponding mapping ready for use in queries,
aware of the details of the schema used to encode the document on DB.
"""
if filter_dict is None:
return {}

Expand Down Expand Up @@ -1319,7 +1332,7 @@ def _similarity_search_with_score_id_by_sort(
) -> list[tuple[Document, float, str]]:
"""Run ANN search with a provided sort clause."""
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)
hits_ite = self.astra_env.collection.find(
filter=metadata_parameter,
projection=self.document_codec.base_projection,
Expand Down Expand Up @@ -1515,7 +1528,7 @@ async def _asimilarity_search_with_score_id_by_sort(
) -> list[tuple[Document, float, str]]:
"""Run ANN search with a provided sort clause."""
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)
return [
(doc, sim, did)
async for (doc, sim, did) in (
Expand Down Expand Up @@ -1638,7 +1651,7 @@ def max_marginal_relevance_search_by_vector(
The list of Documents selected by maximal marginal relevance.
"""
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)

return self._run_mmr_query_by_sort(
sort={"$vector": embedding},
Expand Down Expand Up @@ -1677,7 +1690,7 @@ async def amax_marginal_relevance_search_by_vector(
The list of Documents selected by maximal marginal relevance.
"""
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)

return await self._arun_mmr_query_by_sort(
sort={"$vector": embedding},
Expand Down Expand Up @@ -1719,7 +1732,7 @@ def max_marginal_relevance_search(
# this case goes directly to the "_by_sort" method
# (and does its own filter normalization, as it cannot
# use the path for the with-embedding mmr querying)
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)
return self._run_mmr_query_by_sort(
sort={"$vectorize": query},
k=k,
Expand Down Expand Up @@ -1770,7 +1783,7 @@ async def amax_marginal_relevance_search(
# this case goes directly to the "_by_sort" method
# (and does its own filter normalization, as it cannot
# use the path for the with-embedding mmr querying)
metadata_parameter = self._filter_to_metadata(filter)
metadata_parameter = self.filter_to_query(filter)
return await self._arun_mmr_query_by_sort(
sort={"$vectorize": query},
k=k,
Expand Down Expand Up @@ -1930,10 +1943,27 @@ def from_documents(
"""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]

if "ids" in kwargs:
warnings.warn(
(
"Parameter `ids` to AstraDBVectorStore's `from_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)
ids = kwargs.pop("ids")
else:
_ids = [doc.id for doc in documents]
ids = _ids if any(the_id is not None for the_id in _ids) else None
return cls.from_texts(
texts,
embedding=embedding,
metadatas=metadatas,
ids=ids,
**kwargs,
)

Expand All @@ -1956,9 +1986,26 @@ async def afrom_documents(
"""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]

if "ids" in kwargs:
warnings.warn(
(
"Parameter `ids` to AstraDBVectorStore's `from_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)
ids = kwargs.pop("ids")
else:
_ids = [doc.id for doc in documents]
ids = _ids if any(the_id is not None for the_id in _ids) else None
return await cls.afrom_texts(
texts,
embedding=embedding,
metadatas=metadatas,
ids=ids,
**kwargs,
)
Loading
Loading