Commit ed2ceea
Mustafa Kerem Kurban committed on Oct 3, 2024
1 parent: 36c6e9f
Showing 4 changed files with 339 additions and 0 deletions.
@@ -0,0 +1,126 @@
from typing import Any, AsyncIterator, TypedDict

from pydantic import BaseModel, Field, ValidationError

from langgraph.graph import StateGraph, START, END
from langgraph.errors import NodeInterrupt

from neuroagent.tools.bluenaas_tool import BlueNaaSTool, InputBlueNaaS, BlueNaaSOutput
# from neuroagent.tools.get_me_model_tool import GetMEModelTool
# from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool
# from neuroagent.app.dependencies import get_settings, get_kg_token, get_httpx_client
from neuroagent.agents import BaseAgent


# StateGraph needs a state schema at construction time; every key is optional
# because the nodes below fill them in progressively.
class SimulationState(TypedDict, total=False):
    query: str
    config: dict
    valid: bool
    errors: list
    interrupted: bool
    simulation_result: Any


class BluenaasSimAgent(BaseAgent):
    """Agent for running BlueNaaS simulations with iterative configuration improvement."""

    async def arun(self, query: str) -> Any:
        """Run the agent against a query."""
        state_graph = StateGraph(SimulationState)
        state_graph.add_node("parse_input", self.parse_input)
        state_graph.add_node("validate_config", self.validate_config)
        state_graph.add_node("prompt_user_for_missing_fields", self.prompt_user_for_missing_fields)
        state_graph.add_node("finalize_config", self.finalize_config)
        state_graph.add_node("run_simulation", self.run_simulation)
        state_graph.add_node("process_results", self.process_results)
        state_graph.add_node("handle_interruption", self.handle_interruption)

        state_graph.add_edge(START, "parse_input")
        state_graph.add_edge("parse_input", "validate_config")
        # Loop between validation and the user until the configuration is valid.
        state_graph.add_conditional_edges(
            "validate_config",
            lambda state: "finalize_config" if state["valid"] else "prompt_user_for_missing_fields",
        )
        state_graph.add_edge("prompt_user_for_missing_fields", "validate_config")
        # Route to interruption handling if the user rejected the configuration.
        state_graph.add_conditional_edges(
            "finalize_config",
            lambda state: "handle_interruption" if state.get("interrupted", False) else "run_simulation",
        )
        state_graph.add_edge("run_simulation", "process_results")
        state_graph.add_edge("process_results", END)
        state_graph.add_edge("handle_interruption", END)

        initial_state = {"query": query}
        return await state_graph.compile().ainvoke(initial_state)

    async def parse_input(self, state: dict) -> dict:
        """Parse user input to create initial simulation configuration."""
        # Implement parsing logic here (the config below is a hard-coded placeholder).
        parsed_config = {
            "me_model_id": None,  # Placeholder, should be parsed from user input
            "currentInjection": {
                "injectTo": "soma",
                "stimulus": {
                    "stimulusType": "current_clamp",
                    "stimulusProtocol": "fire_pattern",
                    "amplitudes": [0.05],
                },
            },
            "recordFrom": [
                {"section": "soma", "offset": 0.5},
            ],
            "conditions": {
                "celsius": 34.0,
                "vinit": -70.0,
                "hypamp": 0.1,
                "max_time": 1000.0,
                "time_step": 0.025,
                "seed": 42,
            },
            "simulationType": "single-neuron-simulation",
            "simulationDuration": 1000,
        }
        state["config"] = parsed_config
        return state

    async def validate_config(self, state: dict) -> dict:
        """Validate the simulation configuration using Pydantic."""
        try:
            InputBlueNaaS(**state["config"])
            state["valid"] = True
        except ValidationError as e:
            state["valid"] = False
            state["errors"] = e.errors()
        return state

    async def prompt_user_for_missing_fields(self, state: dict) -> dict:
        """Prompt the user for missing fields in the configuration."""
        missing_fields = [error["loc"][0] for error in state["errors"]]
        user_response = await self.metadata["llm"].ainvoke({
            "messages": [
                {"role": "system", "content": f"The following fields are missing or invalid: {missing_fields}"},
                {"role": "user", "content": "Please provide the missing values."},
            ]
        })
        # Placeholder: the response still needs to be parsed into a
        # {field: value} mapping before being merged into the configuration.
        state["config"].update(user_response)
        return state

    async def finalize_config(self, state: dict) -> dict:
        """Finalize the simulation configuration and prompt the user for approval."""
        user_response = await self.metadata["llm"].ainvoke({
            "messages": [
                {"role": "system", "content": "Here is the final simulation configuration:"},
                {"role": "system", "content": str(state["config"])},
                {"role": "user", "content": "Do you approve this configuration? (yes/no)"},
            ]
        })
        # The LLM client may return a message object rather than a plain string.
        answer = getattr(user_response, "content", user_response)
        if str(answer).strip().lower() != "yes":
            state["interrupted"] = True
            # NodeInterrupt pauses the graph; resuming requires a checkpointer.
            raise NodeInterrupt("User did not approve the configuration.")
        return state

    async def run_simulation(self, state: dict) -> dict:
        """Run the simulation using the BlueNaaSTool."""
        tool = BlueNaaSTool(metadata=self.metadata)
        result = await tool._arun(**state["config"])
        state["simulation_result"] = result
        return state

    async def process_results(self, state: dict) -> dict:
        """Process the simulation results and run electrophysiological analysis."""
        # Implement logic to process simulation results and run electrophysiological analysis
        return state

    async def handle_interruption(self, state: dict) -> dict:
        """Handle interruptions in the state graph."""
        # Implement logic to handle interruptions, such as user disapproval
        await self.metadata["llm"].ainvoke({
            "messages": [
                {"role": "system", "content": "The simulation configuration was not approved by the user."},
                {"role": "user", "content": "Please provide the necessary changes to proceed."},
            ]
        })
        return state
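

if __name__ == "__main__":
    # Illustrative usage sketch: how the agent above might be driven end to end.
    # The constructor signature and the contents of `metadata` are assumptions
    # (only the "llm" key is actually read by the nodes above), and ChatOpenAI
    # is just one possible LLM client.
    import asyncio

    from langchain_openai import ChatOpenAI

    async def main() -> None:
        agent = BluenaasSimAgent(metadata={"llm": ChatOpenAI(model="gpt-4o")})
        result = await agent.arun("Run a fire_pattern current clamp simulation on the soma.")
        print(result.get("simulation_result"))

    asyncio.run(main())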
@@ -0,0 +1,202 @@
## %% [markdown]
# # Agent Supervisor
#
# The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent.
#
# We can also choose to use an LLM to orchestrate the different agents.
#
# Below, we will create an agent group, with an agent supervisor to help delegate tasks.
#
# ![diagram](attachment:8ee0a8ce-f0a8-4019-b5bf-b20933e40956.png)
#
# To simplify the code in each agent node, we will use LangGraph's prebuilt create_react_agent helper. This and other "advanced agent" notebooks are designed to show how you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.
#
# ## Setup
#
# First, let's install the required packages and set our API keys.
## %%
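# Install the packages this notebook relies on. The package list below is
# inferred from the imports used later in the notebook (an assumption);
# versions are left unpinned.
%pip install -U langgraph langchain_core langchain_openai langchain_community langchain_experimental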

## %%
import getpass
import os


def _set_if_undefined(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}")


_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

## %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
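
## %%
# Optional: enable LangSmith tracing as suggested in the tip above. These are
# the standard LangSmith environment variables; skip this cell if you are not
# using LangSmith.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
_set_if_undefined("LANGCHAIN_API_KEY")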

## %% [markdown]
# ## Create tools
#
# For this example, you will make an agent to do web research with a search engine, and one agent to create plots. Define the tools they'll use below:

## %%
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool

tavily_tool = TavilySearchResults(max_results=5)

# This executes code locally, which can be unsafe
python_repl_tool = PythonREPLTool()

## %% [markdown]
# ## Helper Utilities
## %% [markdown]
# Define a helper function that we will use to create the nodes in the graph. It takes care of converting the agent response to a human message, which is important because that is how we will add it to the global state of the graph.

## %%
from langchain_core.messages import HumanMessage


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {
        "messages": [HumanMessage(content=result["messages"][-1].content, name=name)]
    }

## %% [markdown]
# ### Create Agent Supervisor
#
# It will use function calling to choose the next worker node OR finish processing.

## %%
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import Literal

members = ["Researcher", "Coder"]
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = ["FINISH"] + members


class routeResponse(BaseModel):
    next: Literal[*options]


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))


llm = ChatOpenAI(model="gpt-4o")


def supervisor_agent(state):
    supervisor_chain = prompt | llm.with_structured_output(routeResponse)
    return supervisor_chain.invoke(state)

## %% [markdown]
# ## Construct Graph
#
# We're ready to start building the graph. Below, define the state and worker nodes using the function we just defined.

## %%
import functools
import operator
from typing import Sequence
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage

from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent


# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    # The annotation tells the graph that new messages will always
    # be added to the current state
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str

research_agent = create_react_agent(llm, tools=[tavily_tool])
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
code_agent = create_react_agent(llm, tools=[python_repl_tool])
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("supervisor", supervisor_agent)

## %% [markdown]
# Now connect all the edges in the graph.

## %%
for member in members:
    # We want our workers to ALWAYS "report back" to the supervisor when done
    workflow.add_edge(member, "supervisor")
# The supervisor populates the "next" field in the graph state
# which routes to a node or finishes
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
# Finally, add entrypoint
workflow.add_edge(START, "supervisor")

graph = workflow.compile()

## %% [markdown]
# ## Invoke the team
#
# With the graph created, we can now invoke it and see how it performs!

## %%
for s in graph.stream(
    {
        "messages": [
            HumanMessage(content="Code hello world and print it to the terminal")
        ]
    }
):
    if "__end__" not in s:
        print(s)
        print("----")

## %%
for s in graph.stream(
    {"messages": [HumanMessage(content="Write a brief research report on pikas.")]},
    {"recursion_limit": 100},
):
    if "__end__" not in s:
        print(s)
        print("----")