Skip to content

Commit

Permalink
Added test for streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Mar 24, 2024
1 parent 9edfea6 commit 235a037
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
3 changes: 3 additions & 0 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def get_completion_stream(self, message: str, event_handler: type(AgencyEventHan
if self.async_mode:
raise Exception("Streaming is not supported in async mode.")

if not inspect.isclass(event_handler):
raise Exception("Event handler must not be an instance.")

gen = self.main_thread.get_completion_stream(message=message, event_handler=event_handler,
message_files=message_files, recipient_agent=recipient_agent)

Expand Down
2 changes: 1 addition & 1 deletion tests/demos/streaming_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def run(self):
])

def test_demo(self):
self.agency.run_demo()
self.agency.demo_gradio()


if __name__ == '__main__':
Expand Down
74 changes: 68 additions & 6 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
import time
import unittest

from openai.types.beta.threads.runs import ToolCall

from agency_swarm.tools import CodeInterpreter, Retrieval

sys.path.insert(0, '../agency-swarm')
from agency_swarm.util import create_agent_template

from agency_swarm import set_openai_key, Agent, Agency
from agency_swarm import set_openai_key, Agent, Agency, AgencyEventHandler
from typing_extensions import override
from agency_swarm.tools import BaseTool


class AgencyTest(unittest.TestCase):
TestTool = None
agency = None
agent2 = None
agent1 = None
Expand Down Expand Up @@ -94,11 +99,29 @@ def save_thread_callback(agents_and_thread_ids):
shutil.copyfile("./data/schemas/" + file, "./test_agents/TestAgent2/schemas/" + file)
cls.num_schemas += 1

class TestTool(BaseTool):
"""
A simple test tool that returns "Test Successful" to demonstrate the functionality of a custom tool within the Agency Swarm framework.
"""

# This tool does not require any input fields, but you can define them similarly for other tools.

def run(self):
"""
Executes the test tool's main functionality. In this case, it simply returns a success message.
"""
self.shared_state.set("test_tool_used", True)

return "Test Successful"

cls.TestTool = TestTool

from test_agents import CEO, TestAgent1, TestAgent2
cls.ceo = CEO()
cls.agent1 = TestAgent1()
cls.agent1.add_tool(Retrieval)
cls.agent2 = TestAgent2()
cls.agent2.add_tool(cls.TestTool)

def test_1_init_agency(self):
"""it should initialize agency with agents"""
Expand Down Expand Up @@ -162,7 +185,44 @@ def test_4_agent_communication(self):
for agent in self.__class__.agency.agents:
self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings])

def test_5_load_from_db(self):
def test_5_agent_communication_stream(self):
"""it should communicate between agents using streaming"""
print("TestAgent1 tools", self.__class__.agent1.tools)

test_tool_used = False
test_agent2_used = False

class EventHandler(AgencyEventHandler):
@override
def on_text_created(self, text) -> None:
# get the name of the agent that is sending the message
if self.recipient_agent_name == "TestAgent2":
nonlocal test_agent2_used
test_agent2_used = True

def on_tool_call_done(self, tool_call: ToolCall) -> None:
if tool_call.function.name == "TestTool":
nonlocal test_tool_used
test_tool_used = True

message = self.__class__.agency.get_completion_stream("Please tell TestAgent1 to tell TestAgent 2 to use test tool.",
event_handler=EventHandler)

self.assertFalse('error' in message.lower())

self.assertTrue(test_tool_used)
self.assertTrue(test_agent2_used)

self.assertTrue(self.__class__.TestTool.shared_state.get("test_tool_used"))

for agent_name, threads in self.__class__.agency.agents_and_threads.items():
for other_agent_name, thread in threads.items():
self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name])

for agent in self.__class__.agency.agents:
self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings])

def test_6_load_from_db(self):
"""it should load agents from db"""
# os.rename("settings.json", "settings2.json")

Expand All @@ -173,6 +233,8 @@ def test_5_load_from_db(self):
agent1 = TestAgent1()
agent1.add_tool(Retrieval)
agent2 = TestAgent2()
agent2.add_tool(self.__class__.TestTool)

ceo = CEO()

# check that agents are loaded
Expand Down Expand Up @@ -205,7 +267,7 @@ def test_5_load_from_db(self):
self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings])
self.assertTrue(agent.id in [settings['id'] for settings in previous_loaded_agents_settings])

def test_6_init_async_agency(self):
def test_7_init_async_agency(self):
"""it should initialize agency with agents"""
# reset loaded thread ids
self.__class__.loaded_thread_ids = {}
Expand All @@ -222,7 +284,7 @@ def test_6_init_async_agency(self):

self.check_all_agents_settings(True)

def test_7_async_agent_communication(self):
def test_8_async_agent_communication(self):
"""it should communicate between agents asynchronously"""
print("TestAgent1 tools", self.__class__.agent1.tools)
self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.",
Expand Down Expand Up @@ -278,7 +340,7 @@ def check_agent_settings(self, agent, async_mode=False):
self.assertTrue(assistant.tools[3].type == "function")
self.assertTrue(assistant.tools[3].function.name == "GetResponse")
elif agent.name == "TestAgent2":
self.assertTrue(len(assistant.tools) == self.__class__.num_schemas)
self.assertTrue(len(assistant.tools) == self.__class__.num_schemas + 1)
for tool in assistant.tools:
self.assertTrue(tool.type == "function")
self.assertTrue(tool.function.name in [tool.__name__ for tool in agent.tools])
Expand All @@ -287,7 +349,7 @@ def check_agent_settings(self, agent, async_mode=False):
self.assertTrue(len(assistant.file_ids) == 0)
self.assertTrue(len(assistant.tools) == num_tools)
else:
raise Exception("Unknown agent name")
pass
except Exception as e:
print("Error checking agent settings ", agent.name)
raise e
Expand Down

0 comments on commit 235a037

Please sign in to comment.