From ed2ceea8a6629b508d165e06ff7702be6d84f9cd Mon Sep 17 00:00:00 2001
From: Mustafa Kerem Kurban
Date: Thu, 3 Oct 2024 19:52:52 +0200
Subject: [PATCH] track some files

---
 CHANGELOG.md                                       |   3 +
 src/neuroagent/agents/bluenaas_sim_agent.py        | 126 +++++++++++
 .../multi_agents/agent_supervisor_ex.py            | 202 ++++++++++++++++++
 tests/app/test_dependencies.py                     |   8 +
 4 files changed, 339 insertions(+)
 create mode 100644 src/neuroagent/agents/bluenaas_sim_agent.py
 create mode 100644 src/neuroagent/multi_agents/agent_supervisor_ex.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2cffd59..120674d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Add get morphoelectric (me) model tool
+- Filter me model by username
+- Add bluenaas tool
+- Add hierarchical multi-agent for bluenaas
 
 ## [0.1.1] - 26.09.2024
 
diff --git a/src/neuroagent/agents/bluenaas_sim_agent.py b/src/neuroagent/agents/bluenaas_sim_agent.py
new file mode 100644
index 0000000..1d27749
--- /dev/null
+++ b/src/neuroagent/agents/bluenaas_sim_agent.py
@@ -0,0 +1,126 @@
+"""Agent for running BlueNaaS single-neuron simulations."""
+
+from typing import Any
+
+from langgraph.errors import NodeInterrupt
+from langgraph.graph import END, START, StateGraph
+from pydantic import ValidationError
+from typing_extensions import TypedDict
+
+from neuroagent.agents import BaseAgent
+from neuroagent.tools.bluenaas_tool import BlueNaaSTool, InputBlueNaaS
+
+
+class SimAgentState(TypedDict, total=False):
+    """State carried between the nodes of the simulation graph."""
+
+    query: str
+    config: dict[str, Any]
+    valid: bool
+    errors: list[dict[str, Any]]
+    interrupted: bool
+    simulation_result: Any
+
+
+class BluenaasSimAgent(BaseAgent):
+    """Agent for running BlueNaaS simulations with iterative configuration improvement."""
+
+    async def arun(self, query: str) -> Any:
+        """Run the agent against a query."""
+        state_graph = StateGraph(SimAgentState)
+        state_graph.add_node("parse_input", self.parse_input)
+        state_graph.add_node("validate_config", self.validate_config)
+        state_graph.add_node("prompt_user_for_missing_fields", self.prompt_user_for_missing_fields)
+        state_graph.add_node("finalize_config", self.finalize_config)
+        state_graph.add_node("run_simulation", self.run_simulation)
+        state_graph.add_node("process_results", self.process_results)
+        state_graph.add_node("handle_interruption", self.handle_interruption)
+
+        # Conditional routing must go through add_conditional_edges;
+        # add_edge does not accept a condition argument.
+        state_graph.add_edge(START, "parse_input")
+        state_graph.add_edge("parse_input", "validate_config")
+        state_graph.add_conditional_edges(
+            "validate_config",
+            lambda state: "finalize_config" if state["valid"] else "prompt_user_for_missing_fields",
+        )
+        state_graph.add_edge("prompt_user_for_missing_fields", "validate_config")
+        state_graph.add_conditional_edges(
+            "finalize_config",
+            lambda state: "handle_interruption" if state.get("interrupted") else "run_simulation",
+        )
+        state_graph.add_edge("run_simulation", "process_results")
+        state_graph.add_edge("process_results", END)
+        state_graph.add_edge("handle_interruption", END)
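+
+        # Resulting topology (sketch):
+        #   START -> parse_input -> validate_config
+        #   validate_config --(invalid)--> prompt_user_for_missing_fields -> validate_config
+        #   validate_config --(valid)----> finalize_config
+        #   finalize_config --(approved)-> run_simulation -> process_results -> END
+        #   finalize_config --(rejected)-> handle_interruption -> END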
"conditions": { + "celsius": 34.0, + "vinit": -70.0, + "hypamp": 0.1, + "max_time": 1000.0, + "time_step": 0.025, + "seed": 42 + }, + "simulationType": "single-neuron-simulation", + "simulationDuration": 1000 + } + state["config"] = parsed_config + return state + + async def validate_config(self, state: dict) -> dict: + """Validate the simulation configuration using Pydantic.""" + try: + config = InputBlueNaaS(**state["config"]) + state["valid"] = True + except ValidationError as e: + state["valid"] = False + state["errors"] = e.errors() + return state + + async def prompt_user_for_missing_fields(self, state: dict) -> dict: + """Prompt the user for missing fields in the configuration.""" + # Implement logic to prompt user for missing fields + missing_fields = [error["loc"][0] for error in state["errors"]] + user_response = await self.metadata["llm"].ainvoke({ + "messages": [ + {"role": "system", "content": f"The following fields are missing or invalid: {missing_fields}"}, + {"role": "user", "content": "Please provide the missing values."} + ] + }) + # Update state with user-provided values + state["config"].update(user_response) + return state + + async def finalize_config(self, state: dict) -> dict: + """Finalize the simulation configuration and prompt user for approval.""" + user_response = await self.metadata["llm"].ainvoke({ + "messages": [ + {"role": "system", "content": "Here is the final simulation configuration:"}, + {"role": "system", "content": str(state["config"])}, + {"role": "user", "content": "Do you approve this configuration? (yes/no)"} + ] + }) + if user_response.lower() != "yes": + state["interrupted"] = True + raise NodeInterrupt("User did not approve the configuration.") + return state + + async def run_simulation(self, state: dict) -> dict: + """Run the simulation using the BlueNaaSTool.""" + tool = BlueNaaSTool(metadata=self.metadata) + result = await tool._arun(**state["config"]) + state["simulation_result"] = result + return state + + async def process_results(self, state: dict) -> dict: + """Process the simulation results and run electrophysiological analysis.""" + # Implement logic to process simulation results and run electrophysiological analysis + return state + + async def handle_interruption(self, state: dict) -> dict: + """Handle interruptions in the state graph.""" + # Implement logic to handle interruptions, such as user disapproval + await self.metadata["llm"].ainvoke({ + "messages": [ + {"role": "system", "content": "The simulation configuration was not approved by the user."}, + {"role": "user", "content": "Please provide the necessary changes to proceed."} + ] + }) + return state \ No newline at end of file diff --git a/src/neuroagent/multi_agents/agent_supervisor_ex.py b/src/neuroagent/multi_agents/agent_supervisor_ex.py new file mode 100644 index 0000000..aedf7b4 --- /dev/null +++ b/src/neuroagent/multi_agents/agent_supervisor_ex.py @@ -0,0 +1,202 @@ +## %% [markdown] +# # Agent Supervisor +# +# The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent. +# +# We can also choose to use an LLM to orchestrate the different agents. +# +# Below, we will create an agent group, with an agent supervisor to help delegate tasks. +# +# ![diagram](attachment:8ee0a8ce-f0a8-4019-b5bf-b20933e40956.png) +# +# To simplify the code in each agent node, we will use the AgentExecutor class from LangChain. 
diff --git a/src/neuroagent/multi_agents/agent_supervisor_ex.py b/src/neuroagent/multi_agents/agent_supervisor_ex.py
new file mode 100644
index 0000000..aedf7b4
--- /dev/null
+++ b/src/neuroagent/multi_agents/agent_supervisor_ex.py
@@ -0,0 +1,202 @@
+## %% [markdown]
+# # Agent Supervisor
+#
+# The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent.
+#
+# We can also choose to use an LLM to orchestrate the different agents.
+#
+# Below, we will create an agent group, with an agent supervisor to help delegate tasks.
+#
+# ![diagram](attachment:8ee0a8ce-f0a8-4019-b5bf-b20933e40956.png)
+#
+# To simplify the code in each agent node, we will use LangGraph's prebuilt `create_react_agent` helper. This and other "advanced agent" notebooks are designed to show how you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.
+#
+# ## Setup
+#
+# First, let's set our API keys:
+
+## %%
+import getpass
+import os
+
+
+def _set_if_undefined(var: str):
+    if not os.environ.get(var):
+        os.environ[var] = getpass.getpass(f"Please provide your {var}")
+
+
+_set_if_undefined("OPENAI_API_KEY")
+_set_if_undefined("TAVILY_API_KEY")
+
+## %% [markdown]
+# **Tip: set up LangSmith for LangGraph development.** Sign up for LangSmith to
+# quickly spot issues and improve the performance of your LangGraph projects.
+# LangSmith lets you use trace data to debug, test, and monitor your LLM apps
+# built with LangGraph. Read more about how to get started in the LangSmith docs.
+
+## %% [markdown]
+# ## Create tools
+#
+# For this example, you will make an agent to do web research with a search engine, and one agent to create plots. Define the tools they'll use below:
+
+## %%
+from typing import Annotated
+
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_experimental.tools import PythonREPLTool
+
+tavily_tool = TavilySearchResults(max_results=5)
+
+# This executes code locally, which can be unsafe.
+python_repl_tool = PythonREPLTool()
+
+## %% [markdown]
+# ## Helper Utilities
+#
+# Define a helper function that we will use to create the nodes in the graph. It takes care of converting the agent response to a human message, which matters because that is how we will add it to the global state of the graph.
+
+## %%
+from langchain_core.messages import HumanMessage
+
+
+def agent_node(state, agent, name):
+    result = agent.invoke(state)
+    return {
+        "messages": [HumanMessage(content=result["messages"][-1].content, name=name)]
+    }
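+
+# For example, if the researcher's final message were "Pikas are small alpine
+# lagomorphs.", the node would return (sketch):
+#
+#     {"messages": [HumanMessage(content="Pikas are small alpine lagomorphs.", name="Researcher")]}
+#
+# so each worker's output lands in the shared message list as a named message.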
+
+## %% [markdown]
+# ### Create Agent Supervisor
+#
+# It will use function calling to choose the next worker node OR finish processing.
+
+## %%
+from typing import Literal
+
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel
+
+members = ["Researcher", "Coder"]
+system_prompt = (
+    "You are a supervisor tasked with managing a conversation between the"
+    " following workers: {members}. Given the following user request,"
+    " respond with the worker to act next. Each worker will perform a"
+    " task and respond with their results and status. When finished,"
+    " respond with FINISH."
+)
+# Our team supervisor is an LLM node. It just picks the next agent to process
+# and decides when the work is completed.
+options = ["FINISH"] + members
+
+
+class RouteResponse(BaseModel):
+    # Unpacking a list into Literal requires Python 3.11+.
+    next: Literal[*options]
+
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", system_prompt),
+        MessagesPlaceholder(variable_name="messages"),
+        (
+            "system",
+            "Given the conversation above, who should act next?"
+            " Or should we FINISH? Select one of: {options}",
+        ),
+    ]
+).partial(options=str(options), members=", ".join(members))
+
+
+llm = ChatOpenAI(model="gpt-4o")
+
+
+def supervisor_agent(state):
+    supervisor_chain = prompt | llm.with_structured_output(RouteResponse)
+    # The structured output supplies the "next" value used for routing below.
+    return supervisor_chain.invoke(state)
+
+## %% [markdown]
+# ## Construct Graph
+#
+# We're ready to start building the graph. Below, define the state and worker nodes using the function we just defined.
+
+## %%
+import functools
+import operator
+from typing import Sequence
+
+from typing_extensions import TypedDict
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph import END, START, StateGraph
+from langgraph.prebuilt import create_react_agent
+
+
+# The agent state is the input to each node in the graph.
+class AgentState(TypedDict):
+    # The annotation tells the graph that new messages will always
+    # be appended to the current state.
+    messages: Annotated[Sequence[BaseMessage], operator.add]
+    # The 'next' field indicates where to route to next.
+    next: str
+
+
+research_agent = create_react_agent(llm, tools=[tavily_tool])
+research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
+
+# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION.
+code_agent = create_react_agent(llm, tools=[python_repl_tool])
+code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
+
+workflow = StateGraph(AgentState)
+workflow.add_node("Researcher", research_node)
+workflow.add_node("Coder", code_node)
+workflow.add_node("supervisor", supervisor_agent)
+
+## %% [markdown]
+# Now connect all the edges in the graph.
+
+## %%
+for member in members:
+    # We want our workers to ALWAYS "report back" to the supervisor when done.
+    workflow.add_edge(member, "supervisor")
+# The supervisor populates the "next" field in the graph state,
+# which routes to a worker node or finishes.
+conditional_map = {k: k for k in members}
+conditional_map["FINISH"] = END
+workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
+# Finally, add the entrypoint.
+workflow.add_edge(START, "supervisor")
+
+graph = workflow.compile()
+
+## %% [markdown]
+# ## Invoke the team
+#
+# With the graph created, we can now invoke it and see how it performs!
+
+## %%
+for s in graph.stream(
+    {
+        "messages": [
+            HumanMessage(content="Code hello world and print it to the terminal")
+        ]
+    }
+):
+    if "__end__" not in s:
+        print(s)
+        print("----")
+
+## %%
+for s in graph.stream(
+    {"messages": [HumanMessage(content="Write a brief research report on pikas.")]},
+    {"recursion_limit": 100},
+):
+    if "__end__" not in s:
+        print(s)
+        print("----")
diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py
index cde38e3..a66c452 100644
--- a/tests/app/test_dependencies.py
+++ b/tests/app/test_dependencies.py
@@ -31,6 +31,7 @@
     get_traces_tool,
     get_update_kg_hierarchy,
     get_user_id,
+    run_single_cell_sim_tool,
 )
 from neuroagent.tools import (
     ElectrophysFeatureTool,
@@ -214,6 +215,9 @@ def test_get_agent(monkeypatch, patch_required_env):
     me_model_tool = get_me_model_tool(
         settings=settings, token=token, httpx_client=httpx_client
     )
+    bluenaas_tool = run_single_cell_sim_tool(
+        settings=settings, token=token, httpx_client=httpx_client
+    )
 
     agent = get_agent(
         llm=language_model,
@@ -226,6 +230,7 @@
         traces_tool=traces_tool,
         settings=settings,
         me_model_tool=me_model_tool,
+        bluenaas_tool=bluenaas_tool,
     )
     assert isinstance(agent, SimpleAgent)
@@ -274,8 +279,11 @@ async def test_get_chat_agent(monkeypatch, db_connection, patch_required_env):
         morphology_feature_tool=morphology_feature_tool,
         kg_morpho_feature_tool=kg_morpho_feature_tool,
         electrophys_feature_tool=electrophys_feature_tool,
+        me_model_tool=get_me_model_tool(
+            settings=settings, token=token, httpx_client=httpx_client
+        ),
+        bluenaas_tool=run_single_cell_sim_tool(
+            settings=settings, token=token, httpx_client=httpx_client
+        ),
         traces_tool=traces_tool,
         memory=memory,
+        settings=settings,
     )
     assert isinstance(agent, SimpleChatAgent)