Skip to content

Commit

Permalink
Support dicts in lists in metadata (#148)
Browse files Browse the repository at this point in the history
* Support dicts in lists in metadata

This allows cases like:

```
metadata = {
    "entities": [
        {"entity": "Bell", "type": "PEOPLE"},
        ...
    ]
}
```

Note: For stores that perform shredding this works by JSON encoding the
entire item `{"entity": "Bell", "type": "PEOPLE"}` into the key. This
means that equality on the items of `entities` are supported, by digging
into fields won't be.

* lint/fmt

* lint
  • Loading branch information
bjchambers authored Feb 19, 2025
1 parent b99bddc commit 468bde6
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 63 deletions.
4 changes: 2 additions & 2 deletions data/animals.jsonl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{"id": "aardvark", "text": "the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["burrowing", "nocturnal", "ants", "savanna"], "habitat": "savanna"}}
{"id": "albatross", "text": "the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.", "metadata": {"type": "bird", "number_of_legs": 2, "keywords": ["seabird", "wingspan", "ocean"], "habitat": "marine"}}
{"id": "aardvark", "text": "the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["burrowing", "nocturnal", "ants", "savanna"], "habitat": "savanna", "tags": [{"a": 5, "b": 7}, {"a": 8, "b": 10}]}}
{"id": "albatross", "text": "the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.", "metadata": {"type": "bird", "number_of_legs": 2, "keywords": ["seabird", "wingspan", "ocean"], "habitat": "marine", "tags": [{"a": 5, "b": 8}, {"a": 8, "b": 10}]}}
{"id": "alligator", "text": "alligators are large reptiles with powerful jaws and are commonly found in freshwater wetlands.", "metadata": {"type": "reptile", "number_of_legs": 4, "keywords": ["reptile", "jaws", "wetlands"], "diet": "carnivorous", "nested": { "a": 5 }}}
{"id": "alpaca", "text": "alpacas are domesticated mammals valued for their soft wool and friendly demeanor.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["wool", "domesticated", "friendly"], "origin": "south america", "nested": { "a": 5 }}}
{"id": "ant", "text": "ants are social insects that live in colonies and are known for their teamwork and strength.", "metadata": {"type": "insect", "number_of_legs": 6, "keywords": ["social", "colonies", "strength", "pollinator"], "diet": "omnivorous", "nested": { "a": 6 }}}
Expand Down
1 change: 1 addition & 0 deletions packages/graph-retriever/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"numpy>=1.26.4",
"typing-extensions>=4.12.2",
"pytest>=8.3.4",
"immutabledict>=4.2.1",
]

[project.urls]
Expand Down
14 changes: 7 additions & 7 deletions packages/graph-retriever/src/graph_retriever/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from collections.abc import Iterable, Sequence
from typing import Any

from immutabledict import immutabledict

from graph_retriever.content import Content
from graph_retriever.edges import Edge, IdEdge, MetadataEdge
from graph_retriever.utils.run_in_executor import run_in_executor
Expand Down Expand Up @@ -355,8 +357,8 @@ async def aadjacent(

def _metadata_filter(
self,
edge: Edge,
base_filter: dict[str, Any] | None = None,
edge: Edge | None = None,
) -> dict[str, Any]:
"""
Return a filter for the `base_filter` and incoming edges from `edge`.
Expand All @@ -376,10 +378,8 @@ def _metadata_filter(
:
The metadata dictionary to use for the given filter.
"""
metadata_filter = {**(base_filter or {})}
assert isinstance(edge, MetadataEdge)
if edge is None:
metadata_filter
else:
metadata_filter[edge.incoming_field] = edge.value
return metadata_filter
value = edge.value
if isinstance(value, immutabledict):
value = dict(value)
return {edge.incoming_field: value, **(base_filter or {})}
13 changes: 13 additions & 0 deletions packages/graph-retriever/src/graph_retriever/edges/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Any, TypeAlias

from immutabledict import immutabledict

from graph_retriever import Content


Expand Down Expand Up @@ -34,6 +36,17 @@ class MetadataEdge(Edge):
The value associated with the key for this edge
"""

def __init__(self, incoming_field: str, value: Any) -> None:
# `self.field = value` and `setattr(self, "field", value)` -- don't work
# because of frozen. we need to call `__setattr__` directly (as the
# default `__init__` would do) to initialize the fields of the frozen
# dataclass.
object.__setattr__(self, "incoming_field", incoming_field)

if isinstance(value, dict):
value = immutabledict(value)
object.__setattr__(self, "value", value)

incoming_field: str
value: Any

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ class AdapterComplianceCase(abc.ABC):

id: str
expected: list[str]

requires_nested: bool = False
requires_dict_in_list: bool = False


@dataclass
Expand Down Expand Up @@ -191,6 +193,50 @@ class AdjacentCase(AdapterComplianceCase):
"horse",
],
),
AdjacentCase(
id="numeric",
query="domesticated hunters",
edges={
MetadataEdge("number_of_legs", 0),
},
k=20, # more than match the filter so we get all
expected=[
"barracuda",
"cobra",
"dolphin",
"eel",
"fish",
"jellyfish",
"manatee",
"narwhal",
],
),
AdjacentCase(
id="two_edges_diff_field",
query="domesticated hunters",
edges={
MetadataEdge("type", "reptile"),
MetadataEdge("number_of_legs", 0),
},
k=20, # more than match the filter so we get all
expected=[
"alligator",
"barracuda",
"chameleon",
"cobra",
"crocodile",
"dolphin",
"eel",
"fish",
"gecko",
"iguana",
"jellyfish",
"komodo dragon",
"lizard",
"manatee",
"narwhal",
],
),
AdjacentCase(
id="one_ids",
query="domesticated hunters",
Expand Down Expand Up @@ -262,6 +308,39 @@ class AdjacentCase(AdapterComplianceCase):
"komodo dragon", # reptile
],
),
AdjacentCase(
id="dict_in_list",
query="domesticated hunters",
edges={
MetadataEdge("tags", {"a": 5, "b": 7}),
},
expected=[
"aardvark",
],
requires_dict_in_list=True,
),
AdjacentCase(
id="dict_in_list_multiple",
query="domesticated hunters",
edges={
MetadataEdge("tags", {"a": 5, "b": 7}),
MetadataEdge("tags", {"a": 5, "b": 8}),
},
expected=[
"aardvark",
"albatross",
],
requires_dict_in_list=True,
),
AdjacentCase(
id="absent_dict",
query="domesticated hunters",
edges={
MetadataEdge("tags", {"a": 5, "b": 10}),
},
expected=[],
requires_dict_in_list=True,
),
AdjacentCase(
id="nested",
query="domesticated hunters",
Expand Down Expand Up @@ -318,6 +397,10 @@ def supports_nested_metadata(self) -> bool:
"""Return whether nested metadata is expected to work."""
return True

def supports_dict_in_list(self) -> bool:
"""Return whether dicts can appear in list fields in metadata."""
return True

def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
"""
Override to change the expected behavior of a case.
Expand Down Expand Up @@ -346,6 +429,8 @@ def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
"""
if not self.supports_nested_metadata() and case.requires_nested:
pytest.xfail("nested metadata not supported")
if not self.supports_dict_in_list() and case.requires_dict_in_list:
pytest.xfail("dict-in-list fields is not supported")
return case.expected

@pytest.fixture(params=GET_CASES, ids=lambda c: c.id)
Expand Down
1 change: 1 addition & 0 deletions packages/langchain-graph-retriever/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ classifiers = [
dependencies = [
"backoff>=2.2.1",
"graph-retriever",
"immutabledict>=4.2.1",
"langchain-core>=0.3.29",
"networkx>=3.4.2",
"pydantic>=2.10.4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from graph_retriever.utils import merge
from graph_retriever.utils.batched import batched
from graph_retriever.utils.top_k import top_k
from immutabledict import immutabledict
from typing_extensions import override

try:
Expand Down Expand Up @@ -107,13 +108,27 @@ def with_user_filters(
) -> dict[str, Any]:
return filter if encoded else codec.encode_filter(filter)

def process_value(v: Any) -> Any:
if isinstance(v, immutabledict):
return dict(v)
else:
return v

for k, v in metadata.items():
for v_batch in batched(v, 100):
batch = list(v_batch)
if len(batch) == 1:
yield (with_user_filters({k: batch[0]}, encoded=False))
batch = [process_value(v) for v in v_batch]
if isinstance(batch[0], dict):
if len(batch) == 1:
yield with_user_filters({k: {"$all": [batch[0]]}}, encoded=False)
else:
yield with_user_filters(
{"$or": [{k: {"$all": [v]}} for v in batch]}, encoded=False
)
else:
yield (with_user_filters({k: {"$in": batch}}, encoded=False))
if len(batch) == 1:
yield (with_user_filters({k: batch[0]}, encoded=False))
else:
yield (with_user_filters({k: {"$in": batch}}, encoded=False))

for id_batch in batched(ids, 100):
ids = list(id_batch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from graph_retriever import Content
from graph_retriever.adapters import Adapter
from graph_retriever.edges import Edge, MetadataEdge
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
Expand Down Expand Up @@ -358,37 +357,6 @@ async def _aget(
**kwargs,
)

def _metadata_filter(
self,
base_filter: dict[str, Any] | None = None,
edge: Edge | None = None,
) -> dict[str, Any]:
"""
Return a filter for the `base_filter` and incoming edges from `edge`.
Parameters
----------
base_filter :
Any base metadata filter that should be used for search.
Generally corresponds to the user specified filters for the entire
traversal. Should be combined with the filters necessary to support
nodes with an *incoming* edge matching `edge`.
edge :
An optional edge which should be added to the filter.
Returns
-------
:
The metadata dictionary to use for the given filter.
"""
metadata_filter = {**(base_filter or {})}
assert isinstance(edge, MetadataEdge)
if edge is None:
metadata_filter
else:
metadata_filter[edge.incoming_field] = edge.value
return metadata_filter


class ShreddedLangchainAdapter(LangchainAdapter[StoreT]):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, vector_store: OpenSearchVectorSearch):
self._id_field = "_id"

def _build_filter(
self, filter: dict[str, str] | None = None
self, filter: dict[str, Any] | None = None
) -> list[dict[str, Any]] | None:
"""
Build a filter query for OpenSearch based on metadata.
Expand All @@ -68,17 +68,24 @@ def _build_filter(
-------
:
Filter query for OpenSearch.
Raises
------
ValueError
If the query is not supported by OpenSearch adapter.
"""
if filter is None:
return None
return [
{
"terms" if isinstance(value, list) else "term": {
f"metadata.{key}.keyword": value
}
}
for key, value in filter.items()
]

filters = []
for key, value in filter.items():
if isinstance(value, list):
filters.append({"terms": {f"metadata.{key}": value}})
elif isinstance(value, dict):
raise ValueError("Open Search doesn't suport dictionary searches.")
else:
filters.append({"term": {f"metadata.{key}": value}})
return filters

@override
def _search(
Expand All @@ -92,9 +99,8 @@ def _search(
# use an efficient_filter to collect results that
# are near the embedding vector until up to 'k'
# documents that match the filter are found.
kwargs["efficient_filter"] = {
"bool": {"must": self._build_filter(filter=filter)}
}
query = {"bool": {"must": self._build_filter(filter=filter)}}
kwargs["efficient_filter"] = query

if k == 0:
return []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ def restore_documents(
and value == self.static_value
):
original_key, original_value = split_key
value = json.loads(original_value)
if original_key not in new_doc.metadata:
new_doc.metadata[original_key] = []
new_doc.metadata[original_key].append(original_value)
new_doc.metadata[original_key].append(value)
else:
# Retain non-shredded metadata as is
new_doc.metadata[key] = value
Expand All @@ -137,7 +138,7 @@ def shredded_key(self, key: str, value: Any) -> str:
str
the shredded key
"""
return f"{key}{self.path_delimiter}{value}"
return f"{key}{self.path_delimiter}{json.dumps(value)}"

def shredded_value(self) -> str:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def astra_config(enabled_stores: set[str]) -> Iterator[_AstraConfig | None]:
assert found, f"Keyspace '{keyspace}' not created"
yield _AstraConfig(token=token, keyspace=keyspace, api_endpoint=api_endpoint)

admin.drop_keyspace(keyspace)
if keyspace != "default_keyspace":
admin.drop_keyspace(keyspace)


class TestAstraAdapter(AdapterComplianceSuite):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def adapter(
)
docs = list(shredder.transform_documents(animal_docs))
store.add_documents(docs)
yield CassandraAdapter(store, shredder, {"keywords"})
yield CassandraAdapter(store, shredder, {"keywords", "tags"})

if session:
session.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def remove_nested_metadata(doc: Document) -> Document:
collection_metadata={"hnsw:space": "cosine"},
)

yield ChromaAdapter(store, shredder, nested_metadata_fields={"keywords"})
yield ChromaAdapter(
store, shredder, nested_metadata_fields={"keywords", "tags"}
)

store.delete_collection()
Loading

0 comments on commit 468bde6

Please sign in to comment.