diff --git a/data/animals.jsonl b/data/animals.jsonl index e22a006f..dc27aa37 100644 --- a/data/animals.jsonl +++ b/data/animals.jsonl @@ -1,9 +1,9 @@ {"id": "aardvark", "text": "the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["burrowing", "nocturnal", "ants", "savanna"], "habitat": "savanna"}} {"id": "albatross", "text": "the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.", "metadata": {"type": "bird", "number_of_legs": 2, "keywords": ["seabird", "wingspan", "ocean"], "habitat": "marine"}} -{"id": "alligator", "text": "alligators are large reptiles with powerful jaws and are commonly found in freshwater wetlands.", "metadata": {"type": "reptile", "number_of_legs": 4, "keywords": ["reptile", "jaws", "wetlands"], "diet": "carnivorous"}} -{"id": "alpaca", "text": "alpacas are domesticated mammals valued for their soft wool and friendly demeanor.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["wool", "domesticated", "friendly"], "origin": "south america"}} -{"id": "ant", "text": "ants are social insects that live in colonies and are known for their teamwork and strength.", "metadata": {"type": "insect", "number_of_legs": 6, "keywords": ["social", "colonies", "strength", "pollinator"], "diet": "omnivorous"}} -{"id": "anteater", "text": "anteaters use their long tongues to eat thousands of ants and termites each day.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["ants", "tongue", "termites"], "diet": "insectivore"}} +{"id": "alligator", "text": "alligators are large reptiles with powerful jaws and are commonly found in freshwater wetlands.", "metadata": {"type": "reptile", "number_of_legs": 4, "keywords": ["reptile", "jaws", "wetlands"], "diet": "carnivorous", "nested": { "a": 5 }}} +{"id": "alpaca", "text": "alpacas are domesticated mammals valued for their soft wool and friendly demeanor.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["wool", "domesticated", "friendly"], "origin": "south america", "nested": { "a": 5 }}} +{"id": "ant", "text": "ants are social insects that live in colonies and are known for their teamwork and strength.", "metadata": {"type": "insect", "number_of_legs": 6, "keywords": ["social", "colonies", "strength", "pollinator"], "diet": "omnivorous", "nested": { "a": 6 }}} +{"id": "anteater", "text": "anteaters use their long tongues to eat thousands of ants and termites each day.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["ants", "tongue", "termites"], "diet": "insectivore", "nested": { "b": 5 }}} {"id": "antelope", "text": "antelopes are graceful herbivorous mammals that are often prey for large predators in the wild.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["graceful", "herbivore", "prey"], "habitat": "grasslands"}} {"id": "armadillo", "text": "armadillos have hard, protective shells and are known for their ability to roll into a ball.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["protective", "shell", "rolling"], "diet": "insectivore"}} {"id": "baboon", "text": "baboons are highly social primates with complex group dynamics and strong bonds.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["social", "primates", "group"], "diet": "omnivorous"}} diff --git a/packages/graph-retriever/src/graph_retriever/adapters/in_memory.py b/packages/graph-retriever/src/graph_retriever/adapters/in_memory.py index 76240058..de8cb858 100644 --- a/packages/graph-retriever/src/graph_retriever/adapters/in_memory.py +++ b/packages/graph-retriever/src/graph_retriever/adapters/in_memory.py @@ -113,7 +113,11 @@ def _matches(self, filter: dict[str, Any] | None, content: Content) -> bool: return True for key, filter_value in filter.items(): - content_value = content.metadata.get(key, SENTINEL) + content_value = content.metadata + for key_part in key.split("."): + content_value = content_value.get(key_part, SENTINEL) + if content_value is SENTINEL: + break if not self._value_matches(filter_value, content_value): return False return True diff --git a/packages/graph-retriever/src/graph_retriever/edges/metadata.py b/packages/graph-retriever/src/graph_retriever/edges/metadata.py index 7d31d08c..9ec3998e 100644 --- a/packages/graph-retriever/src/graph_retriever/edges/metadata.py +++ b/packages/graph-retriever/src/graph_retriever/edges/metadata.py @@ -41,6 +41,15 @@ class Id: """ +def _nested_get(metadata: dict[str, Any], key: str) -> Any: + value = metadata + for key_part in key.split("."): + value = value.get(key_part, SENTINEL) + if value is SENTINEL: + break + return value + + class MetadataEdgeFunction: """ Helper for extracting and encoding edges in metadata. @@ -116,7 +125,7 @@ def mk_edge(v) -> Edge: if isinstance(source_key, Id): edges.add(mk_edge(id)) else: - value = metadata.get(source_key, SENTINEL) + value = _nested_get(metadata, source_key) if isinstance(value, BASIC_TYPES): edges.add(mk_edge(value)) elif isinstance(value, Iterable): diff --git a/packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py b/packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py index 7440992a..bd3080a4 100644 --- a/packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py +++ b/packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py @@ -40,7 +40,7 @@ def assert_ids_any_order( assert set(result_ids) == set(expected), "should contain exactly expected IDs" -@dataclass +@dataclass(kw_only=True) class AdapterComplianceCase(abc.ABC): """ Base dataclass for test cases. @@ -56,6 +56,7 @@ class AdapterComplianceCase(abc.ABC): id: str expected: list[str] + requires_nested: bool = False @dataclass @@ -261,6 +262,46 @@ class AdjacentCase(AdapterComplianceCase): "komodo dragon", # reptile ], ), + AdjacentCase( + id="nested", + query="domesticated hunters", + edges={ + MetadataEdge("nested.a", 5), + }, + expected=[ + "alligator", + "alpaca", + ], + requires_nested=True, + ), + AdjacentCase( + id="nested_same_field", + query="domesticated hunters", + edges={ + MetadataEdge("nested.a", 5), + MetadataEdge("nested.a", 6), + }, + expected=[ + "alligator", + "alpaca", + "ant", + ], + requires_nested=True, + ), + AdjacentCase( + id="nested_diff_field", + query="domesticated hunters", + edges={ + MetadataEdge("nested.a", 5), + MetadataEdge("nested.b", 5), + }, + expected=[ + "alligator", + "alpaca", + "anteater", + ], + requires_nested=True, + ), ] @@ -273,6 +314,10 @@ class AdapterComplianceSuite(abc.ABC): loaded. """ + def supports_nested_metadata(self) -> bool: + """Return whether nested metadata is expected to work.""" + return True + def expected(self, method: str, case: AdapterComplianceCase) -> list[str]: """ Override to change the expected behavior of a case. @@ -299,6 +344,8 @@ def expected(self, method: str, case: AdapterComplianceCase) -> list[str]: : The expected animals. """ + if not self.supports_nested_metadata() and case.requires_nested: + pytest.xfail("nested metadata not supported") return case.expected @pytest.fixture(params=GET_CASES, ids=lambda c: c.id) diff --git a/packages/graph-retriever/tests/edges/test_metadata.py b/packages/graph-retriever/tests/edges/test_metadata.py index fdb51835..b6ffe849 100644 --- a/packages/graph-retriever/tests/edges/test_metadata.py +++ b/packages/graph-retriever/tests/edges/test_metadata.py @@ -38,6 +38,14 @@ def test_edge_function(): ) +def test_nested_edge(): + edge_function = MetadataEdgeFunction([("a.b", "b.c")]) + assert edge_function(mk_node({"a": {"b": 5}, "b": {"c": 7}})) == Edges( + {MetadataEdge("b.c", 7)}, + {MetadataEdge("b.c", 5)}, + ) + + def test_link_to_id(): edge_function = MetadataEdgeFunction([("mentions", Id())]) result = edge_function(mk_node({"mentions": ["a", "c"]})) diff --git a/packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/in_memory.py b/packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/in_memory.py index 2a4d9a03..f7af1d4b 100644 --- a/packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/in_memory.py +++ b/packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/in_memory.py @@ -99,7 +99,11 @@ def _equals_or_contains( True if and only if `metadata[key] == value` or `metadata[key]` is a list containing `value`. """ - actual = metadata.get(key, SENTINEL) + actual = metadata + for key_part in key.split("."): + actual = actual.get(key_part, SENTINEL) + if actual is SENTINEL: + break if actual == value: return True diff --git a/packages/langchain-graph-retriever/tests/adapters/test_cassandra.py b/packages/langchain-graph-retriever/tests/adapters/test_cassandra.py index f211b9bb..47b886f0 100644 --- a/packages/langchain-graph-retriever/tests/adapters/test_cassandra.py +++ b/packages/langchain-graph-retriever/tests/adapters/test_cassandra.py @@ -48,6 +48,9 @@ def cluster( class TestCassandraAdapter(AdapterComplianceSuite): + def supports_nested_metadata(self) -> bool: + return False + @pytest.fixture(scope="class") def adapter( self, diff --git a/packages/langchain-graph-retriever/tests/adapters/test_chroma.py b/packages/langchain-graph-retriever/tests/adapters/test_chroma.py index 9363203b..7a7affe3 100644 --- a/packages/langchain-graph-retriever/tests/adapters/test_chroma.py +++ b/packages/langchain-graph-retriever/tests/adapters/test_chroma.py @@ -9,6 +9,9 @@ class TestChroma(AdapterComplianceSuite): + def supports_nested_metadata(self) -> bool: + return False + @pytest.fixture(scope="class") def adapter( self, @@ -25,6 +28,16 @@ def adapter( ) shredder = ShreddingTransformer() + + # Chroma doesn't even support *writing* nested data currently, so we + # filter it out. + def remove_nested_metadata(doc: Document) -> Document: + metadata = doc.metadata.copy() + metadata.pop("nested", None) + return Document(id=doc.id, page_content=doc.page_content, metadata=metadata) + + animal_docs = [remove_nested_metadata(doc) for doc in animal_docs] + docs = list(shredder.transform_documents(animal_docs)) store = Chroma.from_documents( docs, diff --git a/packages/langchain-graph-retriever/tests/adapters/test_open_search.py b/packages/langchain-graph-retriever/tests/adapters/test_open_search.py index 40a1fa7a..01565c10 100644 --- a/packages/langchain-graph-retriever/tests/adapters/test_open_search.py +++ b/packages/langchain-graph-retriever/tests/adapters/test_open_search.py @@ -8,6 +8,9 @@ class TestOpenSearch(AdapterComplianceSuite): + def supports_nested_metadata(self) -> bool: + return False + @pytest.fixture(scope="class") def adapter( self,