Skip to content

Commit

Permalink
hierarchical chat multi agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustafa Kerem Kurban committed Oct 3, 2024
1 parent ce6a0c7 commit 4ff3e1e
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 238 deletions.
3 changes: 1 addition & 2 deletions src/neuroagent/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from neuroagent.agents.base_agent import AgentOutput, AgentStep, BaseAgent
from neuroagent.agents.simple_agent import SimpleAgent
from neuroagent.agents.simple_chat_agent import SimpleChatAgent
from neuroagent.agents.bluenaas_sim_agent import BluenaasSimAgent
# from neuroagent.agents.bluenaas_sim_agent import BluenaasSimAgent

__all__ = [
"AgentOutput",
"AgentStep",
"BaseAgent",
"SimpleChatAgent",
"SimpleAgent",
"BluenaasSimAgent"
]
229 changes: 115 additions & 114 deletions src/neuroagent/agents/bluenaas_sim_agent.py
Original file line number Diff line number Diff line change
@@ -1,125 +1,126 @@
from typing import Any, AsyncIterator
from pydantic import BaseModel, Field, ValidationError
from langgraph.graph import StateGraph, START, END
from langgraph.errors import NodeInterrupt
from neuroagent.tools.bluenaas_tool import BlueNaaSTool, InputBlueNaaS, BlueNaaSOutput
from neuroagent.tools.get_me_model_tool import GetMEModelTool
from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool
from neuroagent.app.dependencies import get_settings, get_kg_token, get_httpx_client
# from typing import Any, AsyncIterator
# from pydantic import BaseModel, Field, ValidationError
# from langgraph.graph import StateGraph, START, END
# from langgraph.errors import NodeInterrupt
# from neuroagent.tools.bluenaas_tool import BlueNaaSTool, InputBlueNaaS, BlueNaaSOutput
# # from neuroagent.tools.get_me_model_tool import GetMEModelTool
# # from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool
# # from neuroagent.app.dependencies import get_settings, get_kg_token, get_httpx_client
# from neuroagent.agents import BaseAgent

class BluenaasSimAgent(BaseAgent):
"""Agent for running BlueNaaS simulations with iterative configuration improvement."""
# class BluenaasSimAgent(BaseAgent):
# """Agent for running BlueNaaS simulations with iterative configuration improvement."""

async def arun(self, query: str) -> Any:
"""Run the agent against a query."""
state_graph = StateGraph()
state_graph.add_node("parse_input", self.parse_input)
state_graph.add_node("validate_config", self.validate_config)
state_graph.add_node("prompt_user_for_missing_fields", self.prompt_user_for_missing_fields)
state_graph.add_node("finalize_config", self.finalize_config)
state_graph.add_node("run_simulation", self.run_simulation)
state_graph.add_node("process_results", self.process_results)
state_graph.add_node("handle_interruption", self.handle_interruption)
# async def arun(self, query: str) -> Any:
# """Run the agent against a query."""
# state_graph = StateGraph()
# state_graph.add_node("parse_input", self.parse_input)
# state_graph.add_node("validate_config", self.validate_config)
# state_graph.add_node("prompt_user_for_missing_fields", self.prompt_user_for_missing_fields)
# state_graph.add_node("finalize_config", self.finalize_config)
# state_graph.add_node("run_simulation", self.run_simulation)
# state_graph.add_node("process_results", self.process_results)
# state_graph.add_node("handle_interruption", self.handle_interruption)

state_graph.add_edge("parse_input", "validate_config")
state_graph.add_edge("validate_config", "prompt_user_for_missing_fields", condition=lambda x: not x["valid"])
state_graph.add_edge("validate_config", "finalize_config", condition=lambda x: x["valid"])
state_graph.add_edge("prompt_user_for_missing_fields", "validate_config")
state_graph.add_edge("finalize_config", "run_simulation")
state_graph.add_edge("run_simulation", "process_results")
state_graph.add_edge("finalize_config", "handle_interruption", condition=lambda x: x.get("interrupted", False))
# state_graph.add_edge("parse_input", "validate_config")
# state_graph.add_edge("validate_config", "prompt_user_for_missing_fields", condition=lambda x: not x["valid"])
# state_graph.add_edge("validate_config", "finalize_config", condition=lambda x: x["valid"])
# state_graph.add_edge("prompt_user_for_missing_fields", "validate_config")
# state_graph.add_edge("finalize_config", "run_simulation")
# state_graph.add_edge("run_simulation", "process_results")
# state_graph.add_edge("finalize_config", "handle_interruption", condition=lambda x: x.get("interrupted", False))

initial_state = {"query": query}
result = await state_graph.run(initial_state)
return result
# initial_state = {"query": query}
# result = await state_graph.run(initial_state)
# return result

async def parse_input(self, state: dict) -> dict:
"""Parse user input to create initial simulation configuration."""
# Implement parsing logic here
parsed_config = {
"me_model_id": None, # Placeholder, should be parsed from user input
"currentInjection": {
"injectTo": "soma",
"stimulus": {
"stimulusType": "current_clamp",
"stimulusProtocol": "fire_pattern",
"amplitudes": [0.05]
}
},
"recordFrom": [
{"section": "soma", "offset": 0.5}
],
"conditions": {
"celsius": 34.0,
"vinit": -70.0,
"hypamp": 0.1,
"max_time": 1000.0,
"time_step": 0.025,
"seed": 42
},
"simulationType": "single-neuron-simulation",
"simulationDuration": 1000
}
state["config"] = parsed_config
return state
# async def parse_input(self, state: dict) -> dict:
# """Parse user input to create initial simulation configuration."""
# # Implement parsing logic here
# parsed_config = {
# "me_model_id": None, # Placeholder, should be parsed from user input
# "currentInjection": {
# "injectTo": "soma",
# "stimulus": {
# "stimulusType": "current_clamp",
# "stimulusProtocol": "fire_pattern",
# "amplitudes": [0.05]
# }
# },
# "recordFrom": [
# {"section": "soma", "offset": 0.5}
# ],
# "conditions": {
# "celsius": 34.0,
# "vinit": -70.0,
# "hypamp": 0.1,
# "max_time": 1000.0,
# "time_step": 0.025,
# "seed": 42
# },
# "simulationType": "single-neuron-simulation",
# "simulationDuration": 1000
# }
# state["config"] = parsed_config
# return state

async def validate_config(self, state: dict) -> dict:
"""Validate the simulation configuration using Pydantic."""
try:
config = InputBlueNaaS(**state["config"])
state["valid"] = True
except ValidationError as e:
state["valid"] = False
state["errors"] = e.errors()
return state
# async def validate_config(self, state: dict) -> dict:
# """Validate the simulation configuration using Pydantic."""
# try:
# config = InputBlueNaaS(**state["config"])
# state["valid"] = True
# except ValidationError as e:
# state["valid"] = False
# state["errors"] = e.errors()
# return state

async def prompt_user_for_missing_fields(self, state: dict) -> dict:
"""Prompt the user for missing fields in the configuration."""
# Implement logic to prompt user for missing fields
missing_fields = [error["loc"][0] for error in state["errors"]]
user_response = await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": f"The following fields are missing or invalid: {missing_fields}"},
{"role": "user", "content": "Please provide the missing values."}
]
})
# Update state with user-provided values
state["config"].update(user_response)
return state
# async def prompt_user_for_missing_fields(self, state: dict) -> dict:
# """Prompt the user for missing fields in the configuration."""
# # Implement logic to prompt user for missing fields
# missing_fields = [error["loc"][0] for error in state["errors"]]
# user_response = await self.metadata["llm"].ainvoke({
# "messages": [
# {"role": "system", "content": f"The following fields are missing or invalid: {missing_fields}"},
# {"role": "user", "content": "Please provide the missing values."}
# ]
# })
# # Update state with user-provided values
# state["config"].update(user_response)
# return state

async def finalize_config(self, state: dict) -> dict:
"""Finalize the simulation configuration and prompt user for approval."""
user_response = await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": "Here is the final simulation configuration:"},
{"role": "system", "content": str(state["config"])},
{"role": "user", "content": "Do you approve this configuration? (yes/no)"}
]
})
if user_response.lower() != "yes":
state["interrupted"] = True
raise NodeInterrupt("User did not approve the configuration.")
return state
# async def finalize_config(self, state: dict) -> dict:
# """Finalize the simulation configuration and prompt user for approval."""
# user_response = await self.metadata["llm"].ainvoke({
# "messages": [
# {"role": "system", "content": "Here is the final simulation configuration:"},
# {"role": "system", "content": str(state["config"])},
# {"role": "user", "content": "Do you approve this configuration? (yes/no)"}
# ]
# })
# if user_response.lower() != "yes":
# state["interrupted"] = True
# raise NodeInterrupt("User did not approve the configuration.")
# return state

async def run_simulation(self, state: dict) -> dict:
"""Run the simulation using the BlueNaaSTool."""
tool = BlueNaaSTool(metadata=self.metadata)
result = await tool._arun(**state["config"])
state["simulation_result"] = result
return state
# async def run_simulation(self, state: dict) -> dict:
# """Run the simulation using the BlueNaaSTool."""
# tool = BlueNaaSTool(metadata=self.metadata)
# result = await tool._arun(**state["config"])
# state["simulation_result"] = result
# return state

async def process_results(self, state: dict) -> dict:
"""Process the simulation results and run electrophysiological analysis."""
# Implement logic to process simulation results and run electrophysiological analysis
return state
# async def process_results(self, state: dict) -> dict:
# """Process the simulation results and run electrophysiological analysis."""
# # Implement logic to process simulation results and run electrophysiological analysis
# return state

async def handle_interruption(self, state: dict) -> dict:
"""Handle interruptions in the state graph."""
# Implement logic to handle interruptions, such as user disapproval
await self.metadata["llm"].ainvoke({
"messages": [
{"role": "system", "content": "The simulation configuration was not approved by the user."},
{"role": "user", "content": "Please provide the necessary changes to proceed."}
]
})
return state
# async def handle_interruption(self, state: dict) -> dict:
# """Handle interruptions in the state graph."""
# # Implement logic to handle interruptions, such as user disapproval
# await self.metadata["llm"].ainvoke({
# "messages": [
# {"role": "system", "content": "The simulation configuration was not approved by the user."},
# {"role": "user", "content": "Please provide the necessary changes to proceed."}
# ]
# })
# return state
3 changes: 1 addition & 2 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class SettingsTools(BaseModel):
trace: SettingsTrace = SettingsTrace()
kg_morpho_features: SettingsKGMorpho = SettingsKGMorpho()
me_model: SettingsGetMEModel = SettingsGetMEModel()
blue_naas: SettingsBlueNaaS = SettingsBlueNaaS()
bluenaas: SettingsBlueNaaS = SettingsBlueNaaS()

model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -220,7 +220,6 @@ class Settings(BaseSettings):
logging: SettingsLogging = SettingsLogging() # has no required
keycloak: SettingsKeycloak = SettingsKeycloak() # has no required
misc: SettingsMisc = SettingsMisc() # has no required
# langsmith: SettingsLangsmith = SettingsLangsmith()


model_config = SettingsConfigDict(
Expand Down
Loading

0 comments on commit 4ff3e1e

Please sign in to comment.