Skip to content

Commit

Permalink
track some files
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustafa Kerem Kurban committed Oct 3, 2024
1 parent 36c6e9f commit ed2ceea
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Add get morphoelectric (me) model tool
- Filter me model by username
- Add bluenaas tool
- Add hierarcichal multi agent for bluenass

## [0.1.1] - 26.09.2024

Expand Down
126 changes: 126 additions & 0 deletions src/neuroagent/agents/bluenaas_sim_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import Any, AsyncIterator
from pydantic import BaseModel, Field, ValidationError
from langgraph.graph import StateGraph, START, END
from langgraph.errors import NodeInterrupt
from neuroagent.tools.bluenaas_tool import BlueNaaSTool, InputBlueNaaS, BlueNaaSOutput
# from neuroagent.tools.get_me_model_tool import GetMEModelTool
# from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool
# from neuroagent.app.dependencies import get_settings, get_kg_token, get_httpx_client
from neuroagent.agents import BaseAgent

class BluenaasSimAgent(BaseAgent):
"""Agent for running BlueNaaS simulations with iterative configuration improvement."""

async def arun(self, query: str) -> Any:
"""Run the agent against a query."""
state_graph = StateGraph()
state_graph.add_node("parse_input", self.parse_input)
state_graph.add_node("validate_config", self.validate_config)
state_graph.add_node("prompt_user_for_missing_fields", self.prompt_user_for_missing_fields)
state_graph.add_node("finalize_config", self.finalize_config)
state_graph.add_node("run_simulation", self.run_simulation)
state_graph.add_node("process_results", self.process_results)
state_graph.add_node("handle_interruption", self.handle_interruption)

state_graph.add_edge("parse_input", "validate_config")
state_graph.add_edge("validate_config", "prompt_user_for_missing_fields", condition=lambda x: not x["valid"])
state_graph.add_edge("validate_config", "finalize_config", condition=lambda x: x["valid"])
state_graph.add_edge("prompt_user_for_missing_fields", "validate_config")
state_graph.add_edge("finalize_config", "run_simulation")
state_graph.add_edge("run_simulation", "process_results")
state_graph.add_edge("finalize_config", "handle_interruption", condition=lambda x: x.get("interrupted", False))

initial_state = {"query": query}
result = await state_graph.run(initial_state)
return result

async def parse_input(self, state: dict) -> dict:
"""Parse user input to create initial simulation configuration."""
# Implement parsing logic here
parsed_config = {
"me_model_id": None, # Placeholder, should be parsed from user input
"currentInjection": {
"injectTo": "soma",
"stimulus": {
"stimulusType": "current_clamp",
"stimulusProtocol": "fire_pattern",
"amplitudes": [0.05]
}
},
"recordFrom": [
{"section": "soma", "offset": 0.5}
],
"conditions": {
"celsius": 34.0,
"vinit": -70.0,
"hypamp": 0.1,
"max_time": 1000.0,
"time_step": 0.025,
"seed": 42
},
"simulationType": "single-neuron-simulation",
"simulationDuration": 1000
}
state["config"] = parsed_config
return state

async def validate_config(self, state: dict) -> dict:
"""Validate the simulation configuration using Pydantic."""
try:
config = InputBlueNaaS(**state["config"])
state["valid"] = True
except ValidationError as e:
state["valid"] = False
state["errors"] = e.errors()
return state

async def prompt_user_for_missing_fields(self, state: dict) -> dict:
"""Prompt the user for missing fields in the configuration."""
# Implement logic to prompt user for missing fields
missing_fields = [error["loc"][0] for error in state["errors"]]
user_response = await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": f"The following fields are missing or invalid: {missing_fields}"},
{"role": "user", "content": "Please provide the missing values."}
]
})
# Update state with user-provided values
state["config"].update(user_response)
return state

async def finalize_config(self, state: dict) -> dict:
"""Finalize the simulation configuration and prompt user for approval."""
user_response = await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": "Here is the final simulation configuration:"},
{"role": "system", "content": str(state["config"])},
{"role": "user", "content": "Do you approve this configuration? (yes/no)"}
]
})
if user_response.lower() != "yes":
state["interrupted"] = True
raise NodeInterrupt("User did not approve the configuration.")
return state

async def run_simulation(self, state: dict) -> dict:
"""Run the simulation using the BlueNaaSTool."""
tool = BlueNaaSTool(metadata=self.metadata)
result = await tool._arun(**state["config"])
state["simulation_result"] = result
return state

async def process_results(self, state: dict) -> dict:
"""Process the simulation results and run electrophysiological analysis."""
# Implement logic to process simulation results and run electrophysiological analysis
return state

async def handle_interruption(self, state: dict) -> dict:
"""Handle interruptions in the state graph."""
# Implement logic to handle interruptions, such as user disapproval
await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": "The simulation configuration was not approved by the user."},
{"role": "user", "content": "Please provide the necessary changes to proceed."}
]
})
return state
202 changes: 202 additions & 0 deletions src/neuroagent/multi_agents/agent_supervisor_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
## %% [markdown]
# # Agent Supervisor
#
# The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent.
#
# We can also choose to use an LLM to orchestrate the different agents.
#
# Below, we will create an agent group, with an agent supervisor to help delegate tasks.
#
# ![diagram](attachment:8ee0a8ce-f0a8-4019-b5bf-b20933e40956.png)
#
# To simplify the code in each agent node, we will use the AgentExecutor class from LangChain. This and other "advanced agent" notebooks are designed to show how you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.
#
# ## Setup
#
# First, let's install required packages and set our API keys

## %%

## %%
import getpass
import os


def _set_if_undefined(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"Please provide your {var}")


_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

## %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>

## %% [markdown]
# ## Create tools
#
# For this example, you will make an agent to do web research with a search engine, and one agent to create plots. Define the tools they'll use below:

## %%
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool

tavily_tool = TavilySearchResults(max_results=5)

# This executes code locally, which can be unsafe
python_repl_tool = PythonREPLTool()

## %% [markdown]
# ## Helper Utilities

## %% [markdown]
# Define a helper function that we will use to create the nodes in the graph - it takes care of converting the agent response to a human message. This is important because that is how we will add it the global state of the graph

## %%
from langchain_core.messages import HumanMessage


def agent_node(state, agent, name):
result = agent.invoke(state)
return {
"messages": [HumanMessage(content=result["messages"][-1].content, name=name)]
}

## %% [markdown]
# ### Create Agent Supervisor
#
# It will use function calling to choose the next worker node OR finish processing.

## %%
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import Literal

members = ["Researcher", "Coder"]
system_prompt = (
"You are a supervisor tasked with managing a conversation between the"
" following workers: {members}. Given the following user request,"
" respond with the worker to act next. Each worker will perform a"
" task and respond with their results and status. When finished,"
" respond with FINISH."
)
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = ["FINISH"] + members


class routeResponse(BaseModel):
next: Literal[*options]


prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder(variable_name="messages"),
(
"system",
"Given the conversation above, who should act next?"
" Or should we FINISH? Select one of: {options}",
),
]
).partial(options=str(options), members=", ".join(members))


llm = ChatOpenAI(model="gpt-4o")


def supervisor_agent(state):
supervisor_chain = prompt | llm.with_structured_output(routeResponse)
return supervisor_chain.invoke(state)

## %% [markdown]
# ## Construct Graph
#
# We're ready to start building the graph. Below, define the state and worker nodes using the function we just defined.

## %%
import functools
import operator
from typing import Sequence
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage

from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent


# The agent state is the input to each node in the graph
class AgentState(TypedDict):
# The annotation tells the graph that new messages will always
# be added to the current states
messages: Annotated[Sequence[BaseMessage], operator.add]
# The 'next' field indicates where to route to next
next: str


research_agent = create_react_agent(llm, tools=[tavily_tool])
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
code_agent = create_react_agent(llm, tools=[python_repl_tool])
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("supervisor", supervisor_agent)

## %% [markdown]
# Now connect all the edges in the graph.

## %%
for member in members:
# We want our workers to ALWAYS "report back" to the supervisor when done
workflow.add_edge(member, "supervisor")
# The supervisor populates the "next" field in the graph state
# which routes to a node or finishes
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
# Finally, add entrypoint
workflow.add_edge(START, "supervisor")

graph = workflow.compile()

## %% [markdown]
# ## Invoke the team
#
# With the graph created, we can now invoke it and see how it performs!

## %%
for s in graph.stream(
{
"messages": [
HumanMessage(content="Code hello world and print it to the terminal")
]
}
):
if "__end__" not in s:
print(s)
print("----")

## %%
for s in graph.stream(
{"messages": [HumanMessage(content="Write a brief research report on pikas.")]},
{"recursion_limit": 100},
):
if "__end__" not in s:
print(s)
print("----")


8 changes: 8 additions & 0 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_traces_tool,
get_update_kg_hierarchy,
get_user_id,
run_single_cell_sim_tool
)
from neuroagent.tools import (
ElectrophysFeatureTool,
Expand Down Expand Up @@ -214,6 +215,9 @@ def test_get_agent(monkeypatch, patch_required_env):
me_model_tool = get_me_model_tool(
settings=settings, token=token, httpx_client=httpx_client
)
bluenaas_tool = run_single_cell_sim_tool(
settings=settings, token=token, httpx_client=httpx_client
)

agent = get_agent(
llm=language_model,
Expand All @@ -226,6 +230,7 @@ def test_get_agent(monkeypatch, patch_required_env):
traces_tool=traces_tool,
settings=settings,
me_model_tool=me_model_tool,
bluenaas_tool=bluenaas_tool
)

assert isinstance(agent, SimpleAgent)
Expand Down Expand Up @@ -274,8 +279,11 @@ async def test_get_chat_agent(monkeypatch, db_connection, patch_required_env):
morphology_feature_tool=morphology_feature_tool,
kg_morpho_feature_tool=kg_morpho_feature_tool,
electrophys_feature_tool=electrophys_feature_tool,
me_model_tool=get_me_model_tool,
bluenaas_tool=run_single_cell_sim_tool,
traces_tool=traces_tool,
memory=memory,
settings=settings
)

assert isinstance(agent, SimpleChatAgent)
Expand Down

0 comments on commit ed2ceea

Please sign in to comment.