Replaces #306. Adds new LCEL-format chains `create_neptune_opencypher_qa_chain` and `create_neptune_sparql_qa_chain` for Amazon Neptune. These replace the legacy [`NeptuneOpenCypherQAChain`](https://python.langchain.com/api_reference/community/chains/langchain_community.chains.graph_qa.neptune_cypher.NeptuneOpenCypherQAChain.html#langchain_community.chains.graph_qa.neptune_cypher.NeptuneOpenCypherQAChain) and [`NeptuneSparqlQAChain`](https://python.langchain.com/api_reference/community/chains/langchain_community.chains.graph_qa.neptune_sparql.NeptuneSparqlQAChain.html) chains in `langchain-community`. To import them:

```python
from langchain_aws.chains import (
    create_neptune_opencypher_qa_chain,
    create_neptune_sparql_qa_chain,
)
```
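Not part of this commit, but for orientation, a minimal usage sketch of the new openCypher chain. The Neptune endpoint, port, and Bedrock model ID are placeholders, and `NeptuneGraph` / `ChatBedrockConverse` are assumed to be available from `langchain_aws` as in current releases:

```python
# Hypothetical end-to-end wiring of the new chain; endpoint and model ID are placeholders.
from langchain_aws import ChatBedrockConverse
from langchain_aws.chains import create_neptune_opencypher_qa_chain
from langchain_aws.graphs import NeptuneGraph

graph = NeptuneGraph(
    host="my-cluster.cluster-xxxx.us-east-1.neptune.amazonaws.com",  # placeholder endpoint
    port=8182,
)
llm = ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620-v1:0")  # placeholder model

chain = create_neptune_opencypher_qa_chain(
    llm=llm,
    graph=graph,
    allow_dangerous_requests=True,  # explicit acknowledgement required by the chain (see below)
)
print(chain.invoke({"query": "How many nodes are in the graph?"})["result"])
```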
1 parent 94cb00b · commit 38c28fa
Showing 7 changed files with 477 additions and 1 deletion.
@@ -0,0 +1,9 @@

```python
from langchain_aws.chains.graph_qa import (
    create_neptune_opencypher_qa_chain,
    create_neptune_sparql_qa_chain,
)

__all__ = [
    "create_neptune_opencypher_qa_chain",
    "create_neptune_sparql_qa_chain"
]
```
@@ -0,0 +1,7 @@

```python
from .neptune_cypher import create_neptune_opencypher_qa_chain
from .neptune_sparql import create_neptune_sparql_qa_chain

__all__ = [
    "create_neptune_opencypher_qa_chain",
    "create_neptune_sparql_qa_chain"
]
```
libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py: 180 additions & 0 deletions
@@ -0,0 +1,180 @@

```python
from __future__ import annotations

import re
from typing import Any, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough

from langchain_aws.graphs import BaseNeptuneGraph

from .prompts import (
    CYPHER_QA_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def trim_query(query: str) -> str:
    """Trim the query to only include Cypher keywords."""
    keywords = (
        "CALL",
        "CREATE",
        "DELETE",
        "DETACH",
        "LIMIT",
        "MATCH",
        "MERGE",
        "OPTIONAL",
        "ORDER",
        "REMOVE",
        "RETURN",
        "SET",
        "SKIP",
        "UNWIND",
        "WITH",
        "WHERE",
        "//",
    )

    lines = query.split("\n")
    new_query = ""

    for line in lines:
        if line.strip().upper().startswith(keywords):
            new_query += line + "\n"

    return new_query


def extract_cypher(text: str) -> str:
    """Extract Cypher code from text using Regex."""
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


def use_simple_prompt(llm: BaseLanguageModel) -> bool:
    """Decides whether to use the simple prompt"""
    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
        return True

    # Bedrock anthropic
    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
        return True

    return False


def get_prompt(llm: BaseLanguageModel) -> BasePromptTemplate:
    """Selects the final prompt"""
    if use_simple_prompt(llm):
        return NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT
    else:
        return NEPTUNE_OPENCYPHER_GENERATION_PROMPT


def create_neptune_opencypher_qa_chain(
    llm: BaseLanguageModel,
    graph: BaseNeptuneGraph,
    qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
    cypher_prompt: Optional[BasePromptTemplate] = None,
    return_intermediate_steps: bool = False,
    return_direct: bool = False,
    extra_instructions: Optional[str] = None,
    allow_dangerous_requests: bool = False,
) -> Runnable[dict[str, Any], dict]:
    """Chain for question-answering against a Neptune graph
    by generating openCypher statements.
    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.
    See https://python.langchain.com/docs/security for more information.
    Example:
        .. code-block:: python
        chain = create_neptune_opencypher_qa_chain(
            llm=llm,
            graph=graph
        )
        response = chain.invoke({"query": "your_query_here"})
    """

    if allow_dangerous_requests is not True:
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`. "
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database. "
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )

    qa_chain = qa_prompt | llm

    _cypher_prompt = cypher_prompt or get_prompt(llm)
    cypher_generation_chain = _cypher_prompt | llm

    def execute_graph_query(cypher_query: str) -> dict:
        return graph.query(cypher_query)

    def get_cypher_inputs(inputs: dict) -> dict:
        return {
            "question": inputs["query"],
            "schema": graph.get_schema,
            "extra_instructions": extra_instructions or "",
        }

    def get_qa_inputs(inputs: dict) -> dict:
        return {
            "question": inputs["query"],
            "context": inputs["context"],
        }

    def format_response(inputs: dict) -> dict:
        intermediate_steps = [{"query": inputs["cypher"]}]

        if return_direct:
            final_response = {"result": inputs["context"]}
        else:
            final_response = {"result": inputs["qa_result"]}
            intermediate_steps.append({"context": inputs["context"]})

        if return_intermediate_steps:
            final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return final_response

    chain_result = (
        RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
        | {
            "query": lambda x: x["query"],
            "cypher": (lambda x: x["cypher_generation_inputs"])
            | cypher_generation_chain
            | (lambda x: extract_cypher(x.content))
            | trim_query,
        }
        | RunnablePassthrough.assign(context=lambda x: execute_graph_query(x["cypher"]))
        | RunnablePassthrough.assign(qa_result=(lambda x: get_qa_inputs(x)) | qa_chain)
        | format_response
    )

    return chain_result
```
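As a quick illustration of the two helpers in this file, a small sketch (the model output below is made up) showing how `extract_cypher` and `trim_query` reduce a chatty, fenced LLM response to executable openCypher:

```python
from langchain_aws.chains.graph_qa.neptune_cypher import extract_cypher, trim_query

# Hypothetical LLM response: openCypher wrapped in triple backticks plus extra chatter.
fence = "`" * 3
raw = (
    "Here is the query you asked for:\n"
    f"{fence}\n"
    "MATCH (p:Person)-[:WORKS_AT]->(c:Company)\n"
    "RETURN c.name, count(p) AS employees\n"
    "ORDER BY employees DESC\n"
    "LIMIT 5\n"
    f"{fence}\n"
    "Let me know if you need anything else!"
)

cypher = trim_query(extract_cypher(raw))
print(cypher)
# MATCH (p:Person)-[:WORKS_AT]->(c:Company)
# RETURN c.name, count(p) AS employees
# ORDER BY employees DESC
# LIMIT 5
```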
libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py: 152 additions & 0 deletions
@@ -0,0 +1,152 @@

```python
"""
Question answering over an RDF or OWL graph using SPARQL.
"""

from __future__ import annotations

from typing import Any, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough

from langchain_aws.graphs import NeptuneRdfGraph

from .prompts import (
    NEPTUNE_SPARQL_GENERATION_PROMPT,
    NEPTUNE_SPARQL_GENERATION_TEMPLATE,
    SPARQL_QA_PROMPT,
)

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_sparql(query: str) -> str:
    """Extract SPARQL code from a text.
    Args:
        query: Text to extract SPARQL code from.
    Returns:
        SPARQL code extracted from the text.
    """
    query = query.strip()
    querytoks = query.split("```")
    if len(querytoks) == 3:
        query = querytoks[1]

        if query.startswith("sparql"):
            query = query[6:]
    elif query.startswith("<sparql>") and query.endswith("</sparql>"):
        query = query[8:-9]
    return query


def get_prompt(examples: str) -> BasePromptTemplate:
    """Selects the final prompt."""
    template_to_use = NEPTUNE_SPARQL_GENERATION_TEMPLATE
    if examples:
        template_to_use = template_to_use.replace("Examples:", "Examples: " + examples)
        return PromptTemplate(
            input_variables=["schema", "prompt"], template=template_to_use
        )
    return NEPTUNE_SPARQL_GENERATION_PROMPT


def create_neptune_sparql_qa_chain(
    llm: BaseLanguageModel,
    graph: NeptuneRdfGraph,
    qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
    sparql_prompt: Optional[BasePromptTemplate] = None,
    return_intermediate_steps: bool = False,
    return_direct: bool = False,
    extra_instructions: Optional[str] = None,
    allow_dangerous_requests: bool = False,
    examples: Optional[str] = None,
) -> Runnable[dict[str, Any], dict]:
    """Chain for question-answering against a Neptune graph
    by generating SPARQL statements.
    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.
    See https://python.langchain.com/docs/security for more information.
    Example:
        .. code-block:: python
        chain = create_neptune_sparql_qa_chain(
            llm=llm,
            graph=graph
        )
        response = chain.invoke({"query": "your_query_here"})
    """
    if allow_dangerous_requests is not True:
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`. "
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database. "
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )

    qa_chain = qa_prompt | llm

    _sparql_prompt = sparql_prompt or get_prompt(examples)
    sparql_generation_chain = _sparql_prompt | llm

    def execute_graph_query(sparql_query: str) -> dict:
        return graph.query(sparql_query)

    def get_sparql_inputs(inputs: dict) -> dict:
        return {
            "prompt": inputs["query"],
            "schema": graph.get_schema,
            "extra_instructions": extra_instructions or "",
        }

    def get_qa_inputs(inputs: dict) -> dict:
        return {
            "prompt": inputs["query"],
            "context": inputs["context"],
        }

    def format_response(inputs: dict) -> dict:
        intermediate_steps = [{"query": inputs["sparql"]}]

        if return_direct:
            final_response = {"result": inputs["context"]}
        else:
            final_response = {"result": inputs["qa_result"]}
            intermediate_steps.append({"context": inputs["context"]})

        if return_intermediate_steps:
            final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return final_response

    chain_result = (
        RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
        | {
            "query": lambda x: x["query"],
            "sparql": (lambda x: x["sparql_generation_inputs"])
            | sparql_generation_chain
            | (lambda x: extract_sparql(x.content)),
        }
        | RunnablePassthrough.assign(context=lambda x: execute_graph_query(x["sparql"]))
        | RunnablePassthrough.assign(qa_result=(lambda x: get_qa_inputs(x)) | qa_chain)
        | format_response
    )

    return chain_result
```
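For illustration, a small sketch (the model outputs are made up) of the two response shapes `extract_sparql` recognizes, a ```-fenced block with an optional `sparql` language tag and a `<sparql>...</sparql>`-tagged body. Separately, passing `examples` to `create_neptune_sparql_qa_chain` splices few-shot examples into the generation template via `get_prompt`.

```python
from langchain_aws.chains.graph_qa.neptune_sparql import extract_sparql

# Hypothetical model outputs in the two shapes extract_sparql recognizes.
fence = "`" * 3
fenced = f"{fence}sparql\nSELECT ?name WHERE {{ ?p foaf:name ?name }}\n{fence}"
tagged = "<sparql>SELECT ?name WHERE { ?p foaf:name ?name }</sparql>"

print(extract_sparql(fenced))  # fenced body with the leading "sparql" tag stripped
print(extract_sparql(tagged))  # body between the <sparql> ... </sparql> tags
```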