Add AstraDBGraphVectorStore testing #75

Merged · 19 commits · Sep 19, 2024
Changes from 16 commits
36 changes: 28 additions & 8 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
@@ -39,7 +39,7 @@ class _Edge:


 # NOTE: Conversion to string is necessary
-# becasue AstraDB doesn't support matching on arrays of tuples
+# because AstraDB doesn't support matching on arrays of tuples
 def _tag_to_str(kind: str, tag: str) -> str:
     return f"{kind}:{tag}"
@@ -139,7 +139,17 @@ def from_documents(
     ) -> AstraDBGraphVectorStore:
         """Return GraphVectorStore initialized from documents and embeddings."""
         store = cls(embedding, **kwargs)
-        store.add_documents(documents, ids=ids)
+        # `store.add_documents` ends up calling store.add_nodes, which
+        # discards the kwargs including ids. This is the place to normalize
+        # the documents' .id and the separate ids into one and the same:
+        _documents: Iterable[Document]
+        if ids is not None:
+            _documents = [document.copy() for document in documents]
+            for _doc_id, _document in zip(ids, _documents):
(Review thread on the `zip(ids, _documents)` line)

Collaborator:
There should be some check that ids and documents have the same size.

@hemidactylus (Collaborator, Author) on Sep 19, 2024:
Closing since ids is going away from this method.

Side note for your amusement. This check would not be entirely trivial if one does not want to list(...) the iterables. I had started to implement something like this before removing it all.

But the real caveat is that zip(ite1, ite2) treats the two iterables differently (i.e. it's implemented straightforwardly, pulling from the first argument first). It can consume one item too many from ite1, and that item goes nowhere. In the first snippet below, next(r2) returns 4; in the second, it raises StopIteration:

r1 = (i for i in range(4))
r2 = (i for i in range(5))

for p in zip(r1, r2):
    print(p)

next(r2)  # -> 4: zip stopped on r1, leaving one item in r2

r1 = (i for i in range(4))
r2 = (i for i in range(5))

for p in zip(r2, r1):
    print(p)

next(r2)  # -> StopIteration: zip silently consumed r2's last item

(almost obvious in hindsight, but still)
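
For reference, a minimal sketch of such a check that avoids materializing the iterables (on Python 3.10+, zip(ids, documents, strict=True) does this out of the box; the helper below is illustrative only, not part of the PR):

from itertools import zip_longest

_MISSING = object()

def zip_same_length(ids, documents):
    # Pair ids with documents, raising if the two iterables differ in length.
    for doc_id, document in zip_longest(ids, documents, fillvalue=_MISSING):
        if doc_id is _MISSING or document is _MISSING:
            raise ValueError("ids and documents must have the same length")
        yield doc_id, document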

+                _document.id = _doc_id
+        else:
+            _documents = documents
+        store.add_documents(_documents)
         return store

     @override
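
As a usage sketch of the normalization above, as of this revision (my_embeddings is a placeholder for any Embeddings implementation; other kwargs omitted):

from langchain_core.documents import Document

docs = [Document(page_content="graphs"), Document(page_content="vectors")]
store = AstraDBGraphVectorStore.from_documents(
    docs,
    embedding=my_embeddings,
    ids=["doc-1", "doc-2"],   # copied onto each document's .id before insertion
)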
@@ -248,11 +258,21 @@ def visit_targets(d: int, targets: Sequence[Document]) -> None:

         return visited_docs

-    def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
-        if filter_dict is None:
-            return {}
-
-        return self.vectorstore.document_codec.encode_filter(filter_dict)
+    def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
+        """Prepare a query for use on DB based on metadata filter.
+
+        Encode an "abstract" filter clause on metadata into a query filter
+        condition aware of the collection schema choice.
+
+        Args:
+            filter_dict: a metadata condition in the form {"field": "value"}
+                or related.
+
+        Returns:
+            the corresponding mapping ready for use in queries,
+            aware of the details of the schema used to encode the document on DB.
+        """
+        return self.vectorstore.filter_to_query(filter_dict)

     def _get_outgoing_tags(
         self,
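
With the method now public, a hedged usage sketch (the encoded shape depends on the collection's document codec, so the output shown is indicative only):

query_filter = graph_store.filter_to_query({"genre": "sci-fi"})
# e.g. {"metadata.genre": "sci-fi"} under a nested-metadata encoding scheme
hits = graph_store.astra_env.collection.find(filter=query_filter, limit=10)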
@@ -318,7 +338,7 @@ def get_adjacent(tags: set[str]) -> Iterable[_Edge]:
             for tag in tags:
                 m_filter = (metadata_filter or {}).copy()
                 m_filter[self.link_from_metadata_key] = tag
-                metadata_parameter = self._filter_to_metadata(m_filter)
+                metadata_parameter = self.filter_to_query(m_filter)

                 hits = list(
                     self.astra_env.collection.find(
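
To make the renamed call concrete, roughly the filter this loop builds for a single tag (values illustrative; _tag_to_str is the helper from the first hunk):

tag = _tag_to_str("kw", "astra")                  # -> "kw:astra"
m_filter = {"topic": "databases"}                 # example caller-supplied filter
m_filter[store.link_from_metadata_key] = tag      # match docs linked from this tag
metadata_parameter = store.filter_to_query(m_filter)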
@@ -382,7 +402,7 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
                 helper.add_candidates(new_candidates)

         def fetch_initial_candidates() -> None:
-            metadata_parameter = self._filter_to_metadata(metadata_filter).copy()
+            metadata_parameter = self.filter_to_query(metadata_filter).copy()
             hits = list(
                 self.astra_env.collection.find(
                     filter=metadata_parameter,
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/utils/mmr.py
@@ -50,7 +50,7 @@ def cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
     else:
         x = np.array(x, dtype=np.float32)
         y = np.array(y, dtype=np.float32)
-        z = 1 - simd.cdist(x, y, metric="cosine")
+        z = 1 - np.array(simd.cdist(x, y, metric="cosine"))
     if isinstance(z, float):
         return np.array([z])
     return z
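
The np.array(...) wrapping presumably coerces simsimd's return value (not a plain ndarray) so that z supports the subsequent checks. As a reference for the intended semantics, cosine similarity in plain numpy (a cross-check, not the PR's simd path):

import numpy as np

def cosine_similarity_reference(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # Rows of x against rows of y: normalize each row, then one matrix product.
    x_n = x / np.linalg.norm(x, axis=1, keepdims=True)
    y_n = y / np.linalg.norm(y, axis=1, keepdims=True)
    return x_n @ y_n.T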
41 changes: 34 additions & 7 deletions libs/astradb/langchain_astradb/vectorstores.py
@@ -300,7 +300,20 @@ class AstraDBVectorStore(VectorStore):

     """  # noqa: E501

-    def _filter_to_metadata(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
+    def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]:
+        """Prepare a query for use on DB based on metadata filter.
+
+        Encode an "abstract" filter clause on metadata into a query filter
+        condition aware of the collection schema choice.
+
+        Args:
+            filter_dict: a metadata condition in the form {"field": "value"}
+                or related.
+
+        Returns:
+            the corresponding mapping ready for use in queries,
+            aware of the details of the schema used to encode the document on DB.
+        """
         if filter_dict is None:
             return {}
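
Per the body above, a None filter encodes to the match-all query {}. A small sketch (store is an AstraDBVectorStore; encoded key shapes are codec-dependent):

assert store.filter_to_query(None) == {}
flt = store.filter_to_query({"genre": "sci-fi"})  # shape depends on the metadata encoding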

@@ -1319,7 +1332,7 @@ def _similarity_search_with_score_id_by_sort(
     ) -> list[tuple[Document, float, str]]:
         """Run ANN search with a provided sort clause."""
         self.astra_env.ensure_db_setup()
-        metadata_parameter = self._filter_to_metadata(filter)
+        metadata_parameter = self.filter_to_query(filter)
         hits_ite = self.astra_env.collection.find(
             filter=metadata_parameter,
             projection=self.document_codec.base_projection,
@@ -1515,7 +1528,7 @@ async def _asimilarity_search_with_score_id_by_sort(
     ) -> list[tuple[Document, float, str]]:
         """Run ANN search with a provided sort clause."""
         await self.astra_env.aensure_db_setup()
-        metadata_parameter = self._filter_to_metadata(filter)
+        metadata_parameter = self.filter_to_query(filter)
         return [
             (doc, sim, did)
             async for (doc, sim, did) in (
@@ -1638,7 +1651,7 @@ def max_marginal_relevance_search_by_vector(
             The list of Documents selected by maximal marginal relevance.
         """
         self.astra_env.ensure_db_setup()
-        metadata_parameter = self._filter_to_metadata(filter)
+        metadata_parameter = self.filter_to_query(filter)

         return self._run_mmr_query_by_sort(
             sort={"$vector": embedding},
@@ -1677,7 +1690,7 @@ async def amax_marginal_relevance_search_by_vector(
             The list of Documents selected by maximal marginal relevance.
         """
         await self.astra_env.aensure_db_setup()
-        metadata_parameter = self._filter_to_metadata(filter)
+        metadata_parameter = self.filter_to_query(filter)

         return await self._arun_mmr_query_by_sort(
             sort={"$vector": embedding},
@@ -1719,7 +1732,7 @@ def max_marginal_relevance_search(
             # this case goes directly to the "_by_sort" method
             # (and does its own filter normalization, as it cannot
             # use the path for the with-embedding mmr querying)
-            metadata_parameter = self._filter_to_metadata(filter)
+            metadata_parameter = self.filter_to_query(filter)
             return self._run_mmr_query_by_sort(
                 sort={"$vectorize": query},
                 k=k,
@@ -1770,7 +1783,7 @@ async def amax_marginal_relevance_search(
             # this case goes directly to the "_by_sort" method
             # (and does its own filter normalization, as it cannot
             # use the path for the with-embedding mmr querying)
-            metadata_parameter = self._filter_to_metadata(filter)
+            metadata_parameter = self.filter_to_query(filter)
             return await self._arun_mmr_query_by_sort(
                 sort={"$vectorize": query},
                 k=k,
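
Downstream of these renames, a user-facing MMR call with a metadata filter (values illustrative) flows through filter_to_query before hitting the collection:

docs = store.max_marginal_relevance_search(
    "query about graphs",
    k=4,
    fetch_k=20,
    filter={"genre": "sci-fi"},
)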
@@ -1930,6 +1943,13 @@ def from_documents(
         """
         texts = [d.page_content for d in documents]
         metadatas = [d.metadata for d in documents]
+        if "ids" not in kwargs:
+            ids = [doc.id for doc in documents]
+
+            # If there's at least one valid ID, we'll assume that IDs
+            # should be used.
+            if any(ids):
+                kwargs["ids"] = ids
         return cls.from_texts(
             texts,
             embedding=embedding,
@@ -1956,6 +1976,13 @@ async def afrom_documents(
         """
         texts = [d.page_content for d in documents]
         metadatas = [d.metadata for d in documents]
+        if "ids" not in kwargs:
+            ids = [doc.id for doc in documents]
+
+            # If there's at least one valid ID, we'll assume that IDs
+            # should be used.
+            if any(ids):
+                kwargs["ids"] = ids
         return await cls.afrom_texts(
             texts,
             embedding=embedding,
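
A sketch of the new pickup behavior in both methods (my_embeddings is a placeholder for any Embeddings instance): .id values found on the documents are forwarded as ids unless the caller already passed ids in kwargs:

from langchain_core.documents import Document

docs = [
    Document(page_content="alpha", id="id-alpha"),
    Document(page_content="beta", id="id-beta"),
]
# no ids= in kwargs -> kwargs["ids"] becomes ["id-alpha", "id-beta"]
store = AstraDBVectorStore.from_documents(docs, embedding=my_embeddings)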