From c930f04e69047c41f29e3a801bd4490ccc3837e0 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:01:22 +0000 Subject: [PATCH 01/11] Replaced value_sanitize with _value_sanitize from neo4j-graphrag --- .../langchain_neo4j/graphs/neo4j_graph.py | 51 +--------------- .../unit_tests/graphs/test_neo4j_graph.py | 58 ------------------- 2 files changed, 3 insertions(+), 106 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 2eda35b..c0bc6b2 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Type from langchain_core.utils import get_from_dict_or_env +from neo4j_graphrag.schema import _value_sanitize from langchain_neo4j.graphs.graph_document import GraphDocument from langchain_neo4j.graphs.graph_store import GraphStore @@ -65,52 +66,6 @@ def clean_string_values(text: str) -> str: return text.replace("\n", " ").replace("\r", " ") -def value_sanitize(d: Any) -> Any: - """Sanitize the input dictionary or list. - - Sanitizes the input by removing embedding-like values, - lists with more than 128 elements, that are mostly irrelevant for - generating answers in a LLM context. These properties, if left in - results, can occupy significant context space and detract from - the LLM's performance by introducing unnecessary noise and cost. - - Args: - d (Any): The input dictionary or list to sanitize. - - Returns: - Any: The sanitized dictionary or list. - """ - if isinstance(d, dict): - new_dict = {} - for key, value in d.items(): - if isinstance(value, dict): - sanitized_value = value_sanitize(value) - if ( - sanitized_value is not None - ): # Check if the sanitized value is not None - new_dict[key] = sanitized_value - elif isinstance(value, list): - if len(value) < LIST_LIMIT: - sanitized_value = value_sanitize(value) - if ( - sanitized_value is not None - ): # Check if the sanitized value is not None - new_dict[key] = sanitized_value - # Do not include the key if the list is oversized - else: - new_dict[key] = value - return new_dict - elif isinstance(d, list): - if len(d) < LIST_LIMIT: - return [ - value_sanitize(item) for item in d if value_sanitize(item) is not None - ] - else: - return None - else: - return d - - def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str: if baseEntityLabel: return ( @@ -460,7 +415,7 @@ def query( ) json_data = [r.data() for r in data] if self.sanitize: - json_data = [value_sanitize(el) for el in json_data] + json_data = [_value_sanitize(el) for el in json_data] return json_data except Neo4jError as e: if not ( @@ -490,7 +445,7 @@ def query( result = session.run(Query(text=query, timeout=self.timeout), params) json_data = [r.data() for r in result] if self.sanitize: - json_data = [value_sanitize(el) for el in json_data] + json_data = [_value_sanitize(el) for el in json_data] return json_data def refresh_schema(self) -> None: diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 9ee2396..85583e2 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -10,7 +10,6 @@ LIST_LIMIT, Neo4jGraph, _format_schema, - value_sanitize, ) @@ -25,63 +24,6 @@ def mock_neo4j_driver() -> Generator[MagicMock, None, None]: yield mock_driver_instance -@pytest.mark.parametrize( - "description, input_value, expected_output", - [ - ( - "Small list", - {"key1": "value1", "small_list": list(range(15))}, - {"key1": "value1", "small_list": list(range(15))}, - ), - ( - "Oversized list", - {"key1": "value1", "oversized_list": list(range(LIST_LIMIT + 1))}, - {"key1": "value1"}, - ), - ( - "Nested oversized list", - {"key1": "value1", "oversized_list": {"key": list(range(150))}}, - {"key1": "value1", "oversized_list": {}}, - ), - ( - "Dict in list", - { - "key1": "value1", - "oversized_list": [1, 2, {"key": list(range(LIST_LIMIT + 1))}], - }, - {"key1": "value1", "oversized_list": [1, 2, {}]}, - ), - ( - "Dict in nested list", - { - "key1": "value1", - "deeply_nested_lists": [ - [[[{"final_nested_key": list(range(LIST_LIMIT + 1))}]]] - ], - }, - {"key1": "value1", "deeply_nested_lists": [[[[{}]]]]}, - ), - ( - "Bare oversized list", - list(range(LIST_LIMIT + 1)), - None, - ), - ( - "None value", - None, - None, - ), - ], -) -def test_value_sanitize( - description: str, input_value: Dict[str, Any], expected_output: Any -) -> None: - """Test the value_sanitize function.""" - assert ( - value_sanitize(input_value) == expected_output - ), f"Failed test case: {description}" - - def test_driver_state_management(mock_neo4j_driver: MagicMock) -> None: """Comprehensive test for driver state management.""" # Create graph instance From fb6a65da884e5fd7c2f03d4ba736895ed943be09 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:14:58 +0000 Subject: [PATCH 02/11] Replaced node_properties_query, rel_properties_query, and rel_query with NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, and REL_QUERY from neo4j-graphrag --- .../langchain_neo4j/graphs/neo4j_graph.py | 43 +++++-------------- .../integration_tests/graphs/test_neo4j.py | 16 +++---- .../unit_tests/graphs/test_neo4j_graph.py | 2 +- 3 files changed, 17 insertions(+), 44 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index c0bc6b2..2c98f36 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -2,7 +2,12 @@ from typing import Any, Dict, List, Optional, Type from langchain_core.utils import get_from_dict_or_env -from neo4j_graphrag.schema import _value_sanitize +from neo4j_graphrag.schema import ( + NODE_PROPERTIES_QUERY, + REL_PROPERTIES_QUERY, + REL_QUERY, + _value_sanitize, +) from langchain_neo4j.graphs.graph_document import GraphDocument from langchain_neo4j.graphs.graph_store import GraphStore @@ -15,34 +20,6 @@ # Threshold for returning all available prop values in graph schema DISTINCT_VALUE_LIMIT = 10 -node_properties_query = """ -CALL apoc.meta.data() -YIELD label, other, elementType, type, property -WHERE NOT type = "RELATIONSHIP" AND elementType = "node" - AND NOT label IN $EXCLUDED_LABELS -WITH label AS nodeLabels, collect({property:property, type:type}) AS properties -RETURN {labels: nodeLabels, properties: properties} AS output - -""" - -rel_properties_query = """ -CALL apoc.meta.data() -YIELD label, other, elementType, type, property -WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" - AND NOT label in $EXCLUDED_LABELS -WITH label AS nodeLabels, collect({property:property, type:type}) AS properties -RETURN {type: nodeLabels, properties: properties} AS output -""" - -rel_query = """ -CALL apoc.meta.data() -YIELD label, other, elementType, type, property -WHERE type = "RELATIONSHIP" AND elementType = "node" -UNWIND other AS other_node -WITH * WHERE NOT label IN $EXCLUDED_LABELS - AND NOT other_node IN $EXCLUDED_LABELS -RETURN {start: label, type: property, end: toString(other_node)} AS output -""" include_docs_query = ( "MERGE (d:Document {id:$document.metadata.id}) " @@ -461,20 +438,20 @@ def refresh_schema(self) -> None: node_properties = [ el["output"] for el in self.query( - node_properties_query, + NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] rel_properties = [ el["output"] for el in self.query( - rel_properties_query, params={"EXCLUDED_LABELS": EXCLUDED_RELS} + REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} ) ] relationships = [ el["output"] for el in self.query( - rel_query, + REL_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] @@ -494,7 +471,7 @@ def refresh_schema(self) -> None: index = [] self.structured_schema = { - "node_props": {el["labels"]: el["properties"] for el in node_properties}, + "node_props": {el["label"]: el["properties"] for el in node_properties}, "rel_props": {el["type"]: el["properties"] for el in rel_properties}, "relationships": relationships, "metadata": {"constraint": constraint, "index": index}, diff --git a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py index bd47454..0521f81 100644 --- a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py @@ -3,15 +3,11 @@ import pytest from langchain_core.documents import Document +from neo4j_graphrag.schema import NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY from langchain_neo4j import Neo4jGraph from langchain_neo4j.graphs.graph_document import GraphDocument, Node, Relationship -from langchain_neo4j.graphs.neo4j_graph import ( - BASE_ENTITY_LABEL, - node_properties_query, - rel_properties_query, - rel_query, -) +from langchain_neo4j.graphs.neo4j_graph import BASE_ENTITY_LABEL test_data = [ GraphDocument( @@ -73,20 +69,20 @@ def test_cypher_return_correct_schema() -> None: graph.refresh_schema() node_properties = graph.query( - node_properties_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) relationships_properties = graph.query( - rel_properties_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) relationships = graph.query( - rel_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + REL_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) expected_node_properties = [ { "output": { "properties": [{"property": "property_a", "type": "STRING"}], - "labels": "LabelA", + "label": "LabelA", } } ] diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 85583e2..4e4e0f2 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -239,7 +239,7 @@ def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> No { "output": { "properties": [{"property": "property_a", "type": "STRING"}], - "labels": "LabelA", + "label": "LabelA", } } ] From 6e64972de132fb1c511b71c502420bb9c82901dc Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:15:35 +0000 Subject: [PATCH 03/11] Linting fix --- libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 4e4e0f2..3f42ff5 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Any, Dict, Generator, Mapping, Sequence, Union +from typing import Dict, Generator, Mapping, Sequence, Union from unittest.mock import MagicMock, patch import pytest From 1df800e71536de5ffe2d75cceb93d5ef36c1aebb Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:18:48 +0000 Subject: [PATCH 04/11] Replaced neo4j_graph constants with neo4j-graphrag equivalents --- libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 2c98f36..7496acf 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -3,6 +3,12 @@ from langchain_core.utils import get_from_dict_or_env from neo4j_graphrag.schema import ( + BASE_ENTITY_LABEL, + DISTINCT_VALUE_LIMIT, + EXCLUDED_LABELS, + EXCLUDED_RELS, + EXHAUSTIVE_SEARCH_LIMIT, + LIST_LIMIT, NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY, @@ -12,15 +18,6 @@ from langchain_neo4j.graphs.graph_document import GraphDocument from langchain_neo4j.graphs.graph_store import GraphStore -BASE_ENTITY_LABEL = "__Entity__" -EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] -EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] -EXHAUSTIVE_SEARCH_LIMIT = 10000 -LIST_LIMIT = 128 -# Threshold for returning all available prop values in graph schema -DISTINCT_VALUE_LIMIT = 10 - - include_docs_query = ( "MERGE (d:Document {id:$document.metadata.id}) " "SET d.text = $document.page_content " From 8cb675de7cb4117342f1e1fcc25c88cca852e00d Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:20:47 +0000 Subject: [PATCH 05/11] Replaced clean_string_values with _clean_string_values from neo4j-graphrag --- .../langchain_neo4j/graphs/neo4j_graph.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 7496acf..c6a974e 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -12,6 +12,7 @@ NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY, + _clean_string_values, _value_sanitize, ) @@ -26,20 +27,6 @@ ) -def clean_string_values(text: str) -> str: - """Clean string values for schema. - - Cleans the input text by replacing newline and carriage return characters. - - Args: - text (str): The input text to clean. - - Returns: - str: The cleaned text. - """ - return text.replace("\n", " ").replace("\r", " ") - - def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str: if baseEntityLabel: return ( @@ -99,7 +86,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: if prop["type"] == "STRING" and prop.get("values"): if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: example = ( - f'Example: "{clean_string_values(prop["values"][0])}"' + f'Example: "{_clean_string_values(prop["values"][0])}"' if prop["values"] else "" ) @@ -107,7 +94,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: example = ( ( "Available options: " - f'{[clean_string_values(el) for el in prop["values"]]}' + f'{[_clean_string_values(el) for el in prop["values"]]}' ) if prop["values"] else "" @@ -147,7 +134,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: if prop["type"] == "STRING" and prop.get("values"): if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: example = ( - f'Example: "{clean_string_values(prop["values"][0])}"' + f'Example: "{_clean_string_values(prop["values"][0])}"' if prop["values"] else "" ) @@ -155,7 +142,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: example = ( ( "Available options: " - f'{[clean_string_values(el) for el in prop["values"]]}' + f'{[_clean_string_values(el) for el in prop["values"]]}' ) if prop["values"] else "" From 286600cd9f315245249727775636ff9bbbdce2a8 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:27:25 +0000 Subject: [PATCH 06/11] Replaced _format_schema with format_schema from neo4j-graphrag --- .../langchain_neo4j/graphs/neo4j_graph.py | 134 +------ .../unit_tests/graphs/test_neo4j_graph.py | 375 +----------------- 2 files changed, 5 insertions(+), 504 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index c6a974e..251fedf 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -8,12 +8,11 @@ EXCLUDED_LABELS, EXCLUDED_RELS, EXHAUSTIVE_SEARCH_LIMIT, - LIST_LIMIT, NODE_PROPERTIES_QUERY, REL_PROPERTIES_QUERY, REL_QUERY, - _clean_string_values, _value_sanitize, + format_schema, ) from langchain_neo4j.graphs.graph_document import GraphDocument @@ -74,135 +73,6 @@ def _get_rel_import_query(baseEntityLabel: bool) -> str: ) -def _format_schema(schema: Dict, is_enhanced: bool) -> str: - formatted_node_props = [] - formatted_rel_props = [] - if is_enhanced: - # Enhanced formatting for nodes - for node_type, properties in schema["node_props"].items(): - formatted_node_props.append(f"- **{node_type}**") - for prop in properties: - example = "" - if prop["type"] == "STRING" and prop.get("values"): - if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: - example = ( - f'Example: "{_clean_string_values(prop["values"][0])}"' - if prop["values"] - else "" - ) - else: # If less than 10 possible values return all - example = ( - ( - "Available options: " - f'{[_clean_string_values(el) for el in prop["values"]]}' - ) - if prop["values"] - else "" - ) - - elif prop["type"] in [ - "INTEGER", - "FLOAT", - "DATE", - "DATE_TIME", - "LOCAL_DATE_TIME", - ]: - if prop.get("min") and prop.get("max"): - example = f'Min: {prop["min"]}, Max: {prop["max"]}' - else: - example = ( - f'Example: "{prop["values"][0]}"' - if prop.get("values") - else "" - ) - elif prop["type"] == "LIST": - # Skip embeddings - if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT: - continue - example = ( - f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' - ) - formatted_node_props.append( - f" - `{prop['property']}`: {prop['type']} {example}" - ) - - # Enhanced formatting for relationships - for rel_type, properties in schema["rel_props"].items(): - formatted_rel_props.append(f"- **{rel_type}**") - for prop in properties: - example = "" - if prop["type"] == "STRING" and prop.get("values"): - if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: - example = ( - f'Example: "{_clean_string_values(prop["values"][0])}"' - if prop["values"] - else "" - ) - else: # If less than 10 possible values return all - example = ( - ( - "Available options: " - f'{[_clean_string_values(el) for el in prop["values"]]}' - ) - if prop["values"] - else "" - ) - elif prop["type"] in [ - "INTEGER", - "FLOAT", - "DATE", - "DATE_TIME", - "LOCAL_DATE_TIME", - ]: - if prop.get("min") and prop.get("max"): # If we have min/max - example = f'Min: {prop["min"]}, Max: {prop["max"]}' - else: # return a single value - example = ( - f'Example: "{prop["values"][0]}"' if prop["values"] else "" - ) - elif prop["type"] == "LIST": - # Skip embeddings - if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT: - continue - example = ( - f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' - ) - formatted_rel_props.append( - f" - `{prop['property']}`: {prop['type']} {example}" - ) - else: - # Format node properties - for label, props in schema["node_props"].items(): - props_str = ", ".join( - [f"{prop['property']}: {prop['type']}" for prop in props] - ) - formatted_node_props.append(f"{label} {{{props_str}}}") - - # Format relationship properties using structured_schema - for type, props in schema["rel_props"].items(): - props_str = ", ".join( - [f"{prop['property']}: {prop['type']}" for prop in props] - ) - formatted_rel_props.append(f"{type} {{{props_str}}}") - - # Format relationships - formatted_rels = [ - f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" - for el in schema["relationships"] - ] - - return "\n".join( - [ - "Node properties:", - "\n".join(formatted_node_props), - "Relationship properties:", - "\n".join(formatted_rel_props), - "The relationships:", - "\n".join(formatted_rels), - ] - ) - - def _remove_backticks(text: str) -> str: return text.replace("`", "") @@ -519,7 +389,7 @@ def refresh_schema(self) -> None: except CypherTypeError: continue - schema = _format_schema(self.structured_schema, self._enhanced_schema) + schema = format_schema(self.structured_schema, self._enhanced_schema) self.schema = schema diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 3f42ff5..11a46e2 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -1,16 +1,13 @@ from types import ModuleType -from typing import Dict, Generator, Mapping, Sequence, Union +from typing import Generator, Mapping, Sequence, Union from unittest.mock import MagicMock, patch import pytest from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError +from neo4j_graphrag.schema import LIST_LIMIT from langchain_neo4j.graphs.graph_document import GraphDocument, Node, Relationship -from langchain_neo4j.graphs.neo4j_graph import ( - LIST_LIMIT, - Neo4jGraph, - _format_schema, -) +from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph @pytest.fixture @@ -315,372 +312,6 @@ def test_add_graph_docs_inc_src_err(mock_neo4j_driver: MagicMock) -> None: ) -@pytest.mark.parametrize( - "description, schema, is_enhanced, expected_output", - [ - ( - "Enhanced, string property with high distinct count", - { - "node_props": { - "Person": [ - { - "property": "name", - "type": "STRING", - "values": ["Alice", "Bob", "Charlie"], - "distinct_count": 11, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Person**\n" - ' - `name`: STRING Example: "Alice"\n' - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, string property with low distinct count", - { - "node_props": { - "Animal": [ - { - "property": "species", - "type": "STRING", - "values": ["Cat", "Dog"], - "distinct_count": 2, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Animal**\n" - " - `species`: STRING Available options: ['Cat', 'Dog']\n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, numeric property with min and max", - { - "node_props": { - "Person": [ - {"property": "age", "type": "INTEGER", "min": 20, "max": 70} - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Person**\n" - " - `age`: INTEGER Min: 20, Max: 70\n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, numeric property with values", - { - "node_props": { - "Event": [ - { - "property": "date", - "type": "DATE", - "values": ["2021-01-01", "2021-01-02"], - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Event**\n" - ' - `date`: DATE Example: "2021-01-01"\n' - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, list property that should be skipped", - { - "node_props": { - "Document": [ - { - "property": "embedding", - "type": "LIST", - "min_size": 150, - "max_size": 200, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Document**\n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, list property that should be included", - { - "node_props": { - "Document": [ - { - "property": "keywords", - "type": "LIST", - "min_size": 2, - "max_size": 5, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Document**\n" - " - `keywords`: LIST Min Size: 2, Max Size: 5\n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship string property with high distinct count", - { - "node_props": {}, - "rel_props": { - "KNOWS": [ - { - "property": "since", - "type": "STRING", - "values": ["2000", "2001", "2002"], - "distinct_count": 15, - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **KNOWS**\n" - ' - `since`: STRING Example: "2000"\n' - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship string property with low distinct count", - { - "node_props": {}, - "rel_props": { - "LIKES": [ - { - "property": "intensity", - "type": "STRING", - "values": ["High", "Medium", "Low"], - "distinct_count": 3, - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **LIKES**\n" - " - `intensity`: STRING Available options: ['High', 'Medium', 'Low']\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship numeric property with min and max", - { - "node_props": {}, - "rel_props": { - "WORKS_WITH": [ - { - "property": "since", - "type": "INTEGER", - "min": 1995, - "max": 2020, - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **WORKS_WITH**\n" - " - `since`: INTEGER Min: 1995, Max: 2020\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship list property that should be skipped", - { - "node_props": {}, - "rel_props": { - "KNOWS": [ - { - "property": "embedding", - "type": "LIST", - "min_size": 150, - "max_size": 200, - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **KNOWS**\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship list property that should be included", - { - "node_props": {}, - "rel_props": { - "KNOWS": [ - { - "property": "messages", - "type": "LIST", - "min_size": 2, - "max_size": 5, - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **KNOWS**\n" - " - `messages`: LIST Min Size: 2, Max Size: 5\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, relationship numeric property without min and max", - { - "node_props": {}, - "rel_props": { - "OWES": [ - { - "property": "amount", - "type": "FLOAT", - "values": [3.14, 2.71], - } - ] - }, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "\n" - "Relationship properties:\n" - "- **OWES**\n" - ' - `amount`: FLOAT Example: "3.14"\n' - "The relationships:\n" - ), - ), - ( - "Enhanced, property with empty values list", - { - "node_props": { - "Person": [ - { - "property": "name", - "type": "STRING", - "values": [], - "distinct_count": 15, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Person**\n" - " - `name`: STRING \n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ( - "Enhanced, property with missing values", - { - "node_props": { - "Person": [ - { - "property": "name", - "type": "STRING", - "distinct_count": 15, - } - ] - }, - "rel_props": {}, - "relationships": [], - }, - True, - ( - "Node properties:\n" - "- **Person**\n" - " - `name`: STRING \n" - "Relationship properties:\n" - "\n" - "The relationships:\n" - ), - ), - ], -) -def test_format_schema( - description: str, schema: Dict, is_enhanced: bool, expected_output: str -) -> None: - result = _format_schema(schema, is_enhanced) - assert result == expected_output, f"Failed test case: {description}" - - # _enhanced_schema_cypher tests From eca69ec9afe3e2c420ff5018f97270523f0771c5 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 11:41:38 +0000 Subject: [PATCH 07/11] Replaced _enhanced_schema_cypher with get_enhanced_schema_cypher from neo4j-graphrag --- .../langchain_neo4j/graphs/neo4j_graph.py | 184 ++-------------- .../unit_tests/graphs/test_neo4j_graph.py | 199 ------------------ 2 files changed, 13 insertions(+), 370 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 251fedf..c495772 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -4,7 +4,6 @@ from langchain_core.utils import get_from_dict_or_env from neo4j_graphrag.schema import ( BASE_ENTITY_LABEL, - DISTINCT_VALUE_LIMIT, EXCLUDED_LABELS, EXCLUDED_RELS, EXHAUSTIVE_SEARCH_LIMIT, @@ -13,6 +12,7 @@ REL_QUERY, _value_sanitize, format_schema, + get_enhanced_schema_cypher, ) from langchain_neo4j.graphs.graph_document import GraphDocument @@ -346,8 +346,12 @@ def refresh_schema(self) -> None: node_props = self.structured_schema["node_props"].get(node["name"]) if not node_props: # The node has no properties continue - enhanced_cypher = self._enhanced_schema_cypher( - node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT + enhanced_cypher = get_enhanced_schema_cypher( + driver=self._driver, + structured_schema=self.structured_schema, + label_or_type=node["name"], + properties=node_props, + exhaustive=node["count"] < EXHAUSTIVE_SEARCH_LIMIT, ) # Due to schema-flexible nature of neo4j errors can happen try: @@ -374,10 +378,12 @@ def refresh_schema(self) -> None: rel_props = self.structured_schema["rel_props"].get(rel["name"]) if not rel_props: # The rel has no properties continue - enhanced_cypher = self._enhanced_schema_cypher( - rel["name"], - rel_props, - rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, + enhanced_cypher = get_enhanced_schema_cypher( + driver=self._driver, + structured_schema=self.structured_schema, + label_or_type=rel["name"], + properties=rel_props, + exhaustive=rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, is_relationship=True, ) try: @@ -488,170 +494,6 @@ def add_graph_documents( }, ) - def _enhanced_schema_cypher( - self, - label_or_type: str, - properties: List[Dict[str, Any]], - exhaustive: bool, - is_relationship: bool = False, - ) -> str: - if is_relationship: - match_clause = f"MATCH ()-[n:`{label_or_type}`]->()" - else: - match_clause = f"MATCH (n:`{label_or_type}`)" - - with_clauses = [] - return_clauses = [] - output_dict = {} - if exhaustive: - for prop in properties: - prop_name = prop["property"] - prop_type = prop["type"] - if prop_type == "STRING": - with_clauses.append( - ( - f"collect(distinct substring(toString(n.`{prop_name}`)" - f", 0, 50)) AS `{prop_name}_values`" - ) - ) - return_clauses.append( - ( - f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}]," - f" distinct_count: size(`{prop_name}_values`)" - ) - ) - elif prop_type in [ - "INTEGER", - "FLOAT", - "DATE", - "DATE_TIME", - "LOCAL_DATE_TIME", - ]: - with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`") - with_clauses.append(f"max(n.`{prop_name}`) AS `{prop_name}_max`") - with_clauses.append( - f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" - ) - return_clauses.append( - ( - f"min: toString(`{prop_name}_min`), " - f"max: toString(`{prop_name}_max`), " - f"distinct_count: `{prop_name}_distinct`" - ) - ) - elif prop_type == "LIST": - with_clauses.append( - ( - f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " - f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" - ) - ) - return_clauses.append( - f"min_size: `{prop_name}_size_min`, " - f"max_size: `{prop_name}_size_max`" - ) - elif prop_type in ["BOOLEAN", "POINT", "DURATION"]: - continue - output_dict[prop_name] = "{" + return_clauses.pop() + "}" - else: - # Just sample 5 random nodes - match_clause += " WITH n LIMIT 5" - for prop in properties: - prop_name = prop["property"] - prop_type = prop["type"] - - # Check if indexed property, we can still do exhaustive - prop_index = [ - el - for el in self.structured_schema["metadata"]["index"] - if el["label"] == label_or_type - and el["properties"] == [prop_name] - and el["type"] == "RANGE" - ] - if prop_type == "STRING": - if ( - prop_index - and prop_index[0].get("size") > 0 - and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT - ): - distinct_values = self.query( - f"CALL apoc.schema.properties.distinct(" - f"'{label_or_type}', '{prop_name}') YIELD value" - )[0]["value"] - return_clauses.append( - ( - f"values: {distinct_values}," - f" distinct_count: {len(distinct_values)}" - ) - ) - else: - with_clauses.append( - ( - f"collect(distinct substring(toString(n.`{prop_name}`)" - f", 0, 50)) AS `{prop_name}_values`" - ) - ) - return_clauses.append(f"values: `{prop_name}_values`") - elif prop_type in [ - "INTEGER", - "FLOAT", - "DATE", - "DATE_TIME", - "LOCAL_DATE_TIME", - ]: - if not prop_index: - with_clauses.append( - f"collect(distinct toString(n.`{prop_name}`)) " - f"AS `{prop_name}_values`" - ) - return_clauses.append(f"values: `{prop_name}_values`") - else: - with_clauses.append( - f"min(n.`{prop_name}`) AS `{prop_name}_min`" - ) - with_clauses.append( - f"max(n.`{prop_name}`) AS `{prop_name}_max`" - ) - with_clauses.append( - f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" - ) - return_clauses.append( - ( - f"min: toString(`{prop_name}_min`), " - f"max: toString(`{prop_name}_max`), " - f"distinct_count: `{prop_name}_distinct`" - ) - ) - - elif prop_type == "LIST": - with_clauses.append( - ( - f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " - f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" - ) - ) - return_clauses.append( - ( - f"min_size: `{prop_name}_size_min`, " - f"max_size: `{prop_name}_size_max`" - ) - ) - elif prop_type in ["BOOLEAN", "POINT", "DURATION"]: - continue - - output_dict[prop_name] = "{" + return_clauses.pop() + "}" - - with_clause = "WITH " + ",\n ".join(with_clauses) - return_clause = ( - "RETURN {" - + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items()) - + "} AS output" - ) - - # Combine all parts of the Cypher query - cypher_query = "\n".join([match_clause, with_clause, return_clause]) - return cypher_query - def close(self) -> None: """ Explicitly close the Neo4j driver connection. diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 11a46e2..918c7b3 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -310,202 +310,3 @@ def test_add_graph_docs_inc_src_err(mock_neo4j_driver: MagicMock) -> None: "include_source is set to True, but at least one document has no `source`." in str(exc_info.value) ) - - -# _enhanced_schema_cypher tests - - -def test_enhanced_schema_cypher_integer_exhaustive_true( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - - graph.structured_schema = {"metadata": {"index": []}} - properties = [{"property": "age", "type": "INTEGER"}] - query = graph._enhanced_schema_cypher("Person", properties, exhaustive=True) - assert "min(n.`age`) AS `age_min`" in query - assert "max(n.`age`) AS `age_max`" in query - assert "count(distinct n.`age`) AS `age_distinct`" in query - assert ( - "min: toString(`age_min`), max: toString(`age_max`), " - "distinct_count: `age_distinct`" in query - ) - - -def test_enhanced_schema_cypher_list_exhaustive_true( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - graph.structured_schema = {"metadata": {"index": []}} - properties = [{"property": "tags", "type": "LIST"}] - query = graph._enhanced_schema_cypher("Article", properties, exhaustive=True) - assert "min(size(n.`tags`)) AS `tags_size_min`" in query - assert "max(size(n.`tags`)) AS `tags_size_max`" in query - assert "min_size: `tags_size_min`, max_size: `tags_size_max`" in query - - -def test_enhanced_schema_cypher_boolean_exhaustive_true( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "active", "type": "BOOLEAN"}] - query = graph._enhanced_schema_cypher("User", properties, exhaustive=True) - # BOOLEAN types should be skipped, so their properties should not be in the query - assert "n.`active`" not in query - - -def test_enhanced_schema_cypher_integer_exhaustive_false_no_index( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - graph.structured_schema = {"metadata": {"index": []}} - properties = [{"property": "age", "type": "INTEGER"}] - query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) - assert "collect(distinct toString(n.`age`)) AS `age_values`" in query - assert "values: `age_values`" in query - - -def test_enhanced_schema_cypher_integer_exhaustive_false_with_index( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - graph.structured_schema = { - "metadata": { - "index": [ - { - "label": "Person", - "properties": ["age"], - "type": "RANGE", - } - ] - } - } - properties = [{"property": "age", "type": "INTEGER"}] - query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) - assert "min(n.`age`) AS `age_min`" in query - assert "max(n.`age`) AS `age_max`" in query - assert "count(distinct n.`age`) AS `age_distinct`" in query - assert ( - "min: toString(`age_min`), max: toString(`age_max`), " - "distinct_count: `age_distinct`" in query - ) - - -def test_enhanced_schema_cypher_list_exhaustive_false( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "tags", "type": "LIST"}] - query = graph._enhanced_schema_cypher("Article", properties, exhaustive=False) - assert "min(size(n.`tags`)) AS `tags_size_min`" in query - assert "max(size(n.`tags`)) AS `tags_size_max`" in query - assert "min_size: `tags_size_min`, max_size: `tags_size_max`" in query - - -def test_enhanced_schema_cypher_boolean_exhaustive_false( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "active", "type": "BOOLEAN"}] - query = graph._enhanced_schema_cypher("User", properties, exhaustive=False) - # BOOLEAN types should be skipped, so their properties should not be in the query - assert "n.`active`" not in query - - -def test_enhanced_schema_cypher_string_exhaustive_false_with_index( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - graph.structured_schema = { - "metadata": { - "index": [ - { - "label": "Person", - "properties": ["status"], - "type": "RANGE", - "size": 5, - "distinctValues": 5, - } - ] - } - } - graph.query = MagicMock(return_value=[{"value": ["Single", "Married", "Divorced"]}]) # type: ignore[method-assign] - properties = [{"property": "status", "type": "STRING"}] - query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) - assert "values: ['Single', 'Married', 'Divorced'], distinct_count: 3" in query - - -def test_enhanced_schema_cypher_string_exhaustive_false_no_index( - mock_neo4j_driver: MagicMock, -) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - graph.structured_schema = {"metadata": {"index": []}} - properties = [{"property": "status", "type": "STRING"}] - query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False) - assert ( - "collect(distinct substring(toString(n.`status`), 0, 50)) AS `status_values`" - in query - ) - assert "values: `status_values`" in query - - -def test_enhanced_schema_cypher_point_type(mock_neo4j_driver: MagicMock) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "location", "type": "POINT"}] - query = graph._enhanced_schema_cypher("Place", properties, exhaustive=True) - # POINT types should be skipped - assert "n.`location`" not in query - - -def test_enhanced_schema_cypher_duration_type(mock_neo4j_driver: MagicMock) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "duration", "type": "DURATION"}] - query = graph._enhanced_schema_cypher("Event", properties, exhaustive=False) - # DURATION types should be skipped - assert "n.`duration`" not in query - - -def test_enhanced_schema_cypher_relationship(mock_neo4j_driver: MagicMock) -> None: - graph = Neo4jGraph( - url="bolt://localhost:7687", username="neo4j", password="password" - ) - properties = [{"property": "since", "type": "INTEGER"}] - - query = graph._enhanced_schema_cypher( - label_or_type="FRIENDS_WITH", - properties=properties, - exhaustive=True, - is_relationship=True, - ) - - assert query.startswith("MATCH ()-[n:`FRIENDS_WITH`]->()") - assert "min(n.`since`) AS `since_min`" in query - assert "max(n.`since`) AS `since_max`" in query - assert "count(distinct n.`since`) AS `since_distinct`" in query - expected_return_clause = ( - "`since`: {min: toString(`since_min`), max: toString(`since_max`), " - "distinct_count: `since_distinct`}" - ) - assert expected_return_clause in query From 3e2cb1ffad830ad21f78c5a195d620fc2f1545bf Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 14:38:34 +0000 Subject: [PATCH 08/11] Updated refresh_schema to use neo4j-graphrag schema methods --- .../langchain_neo4j/graphs/neo4j_graph.py | 126 ++---------------- .../unit_tests/graphs/test_neo4j_graph.py | 67 ++++++---- 2 files changed, 51 insertions(+), 142 deletions(-) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index c495772..9d4c4bb 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -4,15 +4,9 @@ from langchain_core.utils import get_from_dict_or_env from neo4j_graphrag.schema import ( BASE_ENTITY_LABEL, - EXCLUDED_LABELS, - EXCLUDED_RELS, - EXHAUSTIVE_SEARCH_LIMIT, - NODE_PROPERTIES_QUERY, - REL_PROPERTIES_QUERY, - REL_QUERY, _value_sanitize, format_schema, - get_enhanced_schema_cypher, + get_structured_schema, ) from langchain_neo4j.graphs.graph_document import GraphDocument @@ -287,117 +281,13 @@ def refresh_schema(self) -> None: RuntimeError: If the connection has been closed. """ self._check_driver_state() - from neo4j.exceptions import ClientError, CypherTypeError - - node_properties = [ - el["output"] - for el in self.query( - NODE_PROPERTIES_QUERY, - params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, - ) - ] - rel_properties = [ - el["output"] - for el in self.query( - REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} - ) - ] - relationships = [ - el["output"] - for el in self.query( - REL_QUERY, - params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, - ) - ] - - # Get constraints & indexes - try: - constraint = self.query("SHOW CONSTRAINTS") - index = self.query( - "CALL apoc.schema.nodes() YIELD label, properties, type, size, " - "valuesSelectivity WHERE type = 'RANGE' RETURN *, " - "size * valuesSelectivity as distinctValues" - ) - except ( - ClientError - ): # Read-only user might not have access to schema information - constraint = [] - index = [] - - self.structured_schema = { - "node_props": {el["label"]: el["properties"] for el in node_properties}, - "rel_props": {el["type"]: el["properties"] for el in rel_properties}, - "relationships": relationships, - "metadata": {"constraint": constraint, "index": index}, - } - if self._enhanced_schema: - schema_counts = self.query( - "CALL apoc.meta.graph({sample: 1000, maxRels: 100}) " - "YIELD nodes, relationships " - "RETURN nodes, [rel in relationships | {name:apoc.any.property" - "(rel, 'type'), count: apoc.any.property(rel, 'count')}]" - " AS relationships" - ) - # Update node info - for node in schema_counts[0]["nodes"]: - # Skip bloom labels - if node["name"] in EXCLUDED_LABELS: - continue - node_props = self.structured_schema["node_props"].get(node["name"]) - if not node_props: # The node has no properties - continue - enhanced_cypher = get_enhanced_schema_cypher( - driver=self._driver, - structured_schema=self.structured_schema, - label_or_type=node["name"], - properties=node_props, - exhaustive=node["count"] < EXHAUSTIVE_SEARCH_LIMIT, - ) - # Due to schema-flexible nature of neo4j errors can happen - try: - enhanced_info = self.query( - enhanced_cypher, - # Disable the - # Neo.ClientNotification.Statement.AggregationSkippedNull - # notifications raised by the use of collect in the enhanced - # schema query - session_params={ - "notifications_disabled_categories": ["UNRECOGNIZED"] - }, - )[0]["output"] - for prop in node_props: - if prop["property"] in enhanced_info: - prop.update(enhanced_info[prop["property"]]) - except CypherTypeError: - continue - # Update rel info - for rel in schema_counts[0]["relationships"]: - # Skip bloom labels - if rel["name"] in EXCLUDED_RELS: - continue - rel_props = self.structured_schema["rel_props"].get(rel["name"]) - if not rel_props: # The rel has no properties - continue - enhanced_cypher = get_enhanced_schema_cypher( - driver=self._driver, - structured_schema=self.structured_schema, - label_or_type=rel["name"], - properties=rel_props, - exhaustive=rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, - is_relationship=True, - ) - try: - enhanced_info = self.query(enhanced_cypher)[0]["output"] - for prop in rel_props: - if prop["property"] in enhanced_info: - prop.update(enhanced_info[prop["property"]]) - # Due to schema-flexible nature of neo4j errors can happen - except CypherTypeError: - continue - - schema = format_schema(self.structured_schema, self._enhanced_schema) - - self.schema = schema + self.structured_schema = get_structured_schema( + driver=self._driver, + is_enhanced=self._enhanced_schema, + ) + self.schema = format_schema( + schema=self.structured_schema, is_enhanced=self._enhanced_schema + ) def add_graph_documents( self, diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 918c7b3..07bbf4a 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -3,6 +3,9 @@ from unittest.mock import MagicMock, patch import pytest +from neo4j._data import Record +from neo4j._sync.driver import EagerResult +from neo4j._work.summary import ResultSummary from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError from neo4j_graphrag.schema import LIST_LIMIT @@ -16,7 +19,11 @@ def mock_neo4j_driver() -> Generator[MagicMock, None, None]: mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance mock_driver_instance.verify_connectivity.return_value = None - mock_driver_instance.execute_query = MagicMock(return_value=([], None, None)) + mock_driver_instance.execute_query = MagicMock( + return_value=EagerResult( + records=[], summary=MagicMock(spec=ResultSummary), keys=[] + ) + ) mock_driver_instance._closed = False yield mock_driver_instance @@ -200,6 +207,7 @@ def test_query_fallback_execution(mock_neo4j_driver: MagicMock) -> None: password="password", database="test_db", sanitize=True, + refresh_schema=False, ) mock_session = MagicMock() mock_result = MagicMock() @@ -231,37 +239,47 @@ def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> No username="neo4j", password="password", database="test_db", + refresh_schema=False, ) node_properties = [ - { - "output": { - "properties": [{"property": "property_a", "type": "STRING"}], - "label": "LabelA", + Record( + { + "output": { + "properties": [{"property": "property_a", "type": "STRING"}], + "label": "LabelA", + } } - } + ) ] relationships_properties = [ - { - "output": { - "type": "REL_TYPE", - "properties": [{"property": "rel_prop", "type": "STRING"}], + Record( + { + "output": { + "type": "REL_TYPE", + "properties": [{"property": "rel_prop", "type": "STRING"}], + } } - } + ) ] relationships = [ - {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}}, - {"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}}, + Record({"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}}), + Record({"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}}), ] - # Mock the query method to raise ClientError for constraint and index queries - graph.query = MagicMock( # type: ignore[method-assign] - side_effect=[ - node_properties, - relationships_properties, - relationships, - ClientError("Mock ClientError"), - ] - ) + mock_neo4j_driver.execute_query.side_effect = [ + EagerResult( + records=node_properties, summary=MagicMock(spec=ResultSummary), keys=[] + ), + EagerResult( + records=relationships_properties, + summary=MagicMock(spec=ResultSummary), + keys=[], + ), + EagerResult( + records=relationships, summary=MagicMock(spec=ResultSummary), keys=[] + ), + ClientError("Mock ClientError"), + ] graph.refresh_schema() # Assertions @@ -270,8 +288,9 @@ def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> No assert graph.structured_schema["metadata"]["index"] == [] # Ensure the query method was called as expected - assert graph.query.call_count == 4 - graph.query.assert_any_call("SHOW CONSTRAINTS") + assert mock_neo4j_driver.execute_query.call_count == 4 + calls = mock_neo4j_driver.execute_query.call_args_list + assert any(call.args[0].text == "SHOW CONSTRAINTS" for call in calls) def test_get_schema(mock_neo4j_driver: MagicMock) -> None: From 849ee6c9aa9b853f01061c7a2253fd9bc929253b Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Thu, 20 Feb 2025 16:35:57 +0000 Subject: [PATCH 09/11] Added missing args to get_structured_schema call --- libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 9d4c4bb..e61d2c5 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -284,6 +284,9 @@ def refresh_schema(self) -> None: self.structured_schema = get_structured_schema( driver=self._driver, is_enhanced=self._enhanced_schema, + database=self._database, + timeout=self.timeout, + sanitize=self.sanitize, ) self.schema = format_schema( schema=self.structured_schema, is_enhanced=self._enhanced_schema From 94c0c551b884fb25759c0aed86cb093c2a97c2ee Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 26 Feb 2025 10:25:04 +0000 Subject: [PATCH 10/11] Fixed neo4j.Record import --- libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 07bbf4a..a9990b4 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from neo4j._data import Record +from neo4j import Record from neo4j._sync.driver import EagerResult from neo4j._work.summary import ResultSummary from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError From e54e3debdec0864f5f16d3f6a72d7cbd9642d927 Mon Sep 17 00:00:00 2001 From: Alex Thomas Date: Wed, 26 Feb 2025 10:29:49 +0000 Subject: [PATCH 11/11] Updated Neo4jGraph integration tests to use MagicMock rather than EagerResult --- libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index a9990b4..cd4d7c0 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -4,7 +4,6 @@ import pytest from neo4j import Record -from neo4j._sync.driver import EagerResult from neo4j._work.summary import ResultSummary from neo4j.exceptions import ClientError, ConfigurationError, Neo4jError from neo4j_graphrag.schema import LIST_LIMIT @@ -20,7 +19,7 @@ def mock_neo4j_driver() -> Generator[MagicMock, None, None]: mock_driver.return_value = mock_driver_instance mock_driver_instance.verify_connectivity.return_value = None mock_driver_instance.execute_query = MagicMock( - return_value=EagerResult( + return_value=MagicMock( records=[], summary=MagicMock(spec=ResultSummary), keys=[] ) ) @@ -267,15 +266,15 @@ def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> No ] mock_neo4j_driver.execute_query.side_effect = [ - EagerResult( + MagicMock( records=node_properties, summary=MagicMock(spec=ResultSummary), keys=[] ), - EagerResult( + MagicMock( records=relationships_properties, summary=MagicMock(spec=ResultSummary), keys=[], ), - EagerResult( + MagicMock( records=relationships, summary=MagicMock(spec=ResultSummary), keys=[] ), ClientError("Mock ClientError"),