Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Feb 28, 2024
1 parent 527ee8c commit 6fff6a4
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import unittest

from agency_swarm.tools import CodeInterpreter
from agency_swarm.tools import CodeInterpreter, Retrieval

sys.path.insert(0, '../agency-swarm')
from agency_swarm.util import create_agent_template
Expand Down Expand Up @@ -78,7 +78,7 @@ def save_thread_callback(agents_and_thread_ids):
instructions="Your task is to say test to another test agent using SendMessage tool. "
"If the agent, does not "
"respond or something goes wrong please say 'error' and nothing else. "
"Otherwise say 'success' and nothing else.")
"Otherwise say 'success' and nothing else.", code_interpreter=True)
create_agent_template("TestAgent2", "Test Agent 2", path="./test_agents",
instructions="Please respond to the user that test was a success.")

Expand All @@ -97,7 +97,7 @@ def save_thread_callback(agents_and_thread_ids):
from test_agents import CEO, TestAgent1, TestAgent2
cls.ceo = CEO()
cls.agent1 = TestAgent1()
cls.agent1.add_tool(CodeInterpreter)
cls.agent1.add_tool(Retrieval)
cls.agent2 = TestAgent2()

def test_1_init_agency(self):
Expand Down Expand Up @@ -171,6 +171,7 @@ def test_5_load_from_db(self):

from test_agents import CEO, TestAgent1, TestAgent2
agent1 = TestAgent1()
agent1.add_tool(Retrieval)
agent2 = TestAgent2()
ceo = CEO()

Expand Down Expand Up @@ -261,19 +262,21 @@ def check_agent_settings(self, agent, async_mode=False):
self.assertTrue(assistant)
self.assertTrue(agent._check_parameters(assistant.model_dump()))
if agent.name == "TestAgent1":
num_tools = 2 if not async_mode else 3
num_tools = 3 if not async_mode else 4
self.assertTrue(len(assistant.file_ids) == self.__class__.num_files)
for file_id in assistant.file_ids:
self.assertTrue(file_id in agent.file_ids)
# check retrieval tools is there
print("assistant tools", assistant.tools)
self.assertTrue(len(assistant.tools) == num_tools)
self.assertTrue(len(agent.tools) == num_tools)
self.assertTrue(assistant.tools[0].type == "retrieval")
self.assertTrue(assistant.tools[1].type == "function")
self.assertTrue(assistant.tools[1].function.name == "SendMessage")
self.assertTrue(assistant.tools[0].type == "code_interpreter")
self.assertTrue(assistant.tools[1].type == "retrieval")
self.assertTrue(assistant.tools[2].type == "function")
self.assertTrue(assistant.tools[2].function.name == "SendMessage")
if async_mode:
self.assertTrue(assistant.tools[2].type == "function")
self.assertTrue(assistant.tools[2].function.name == "GetResponse")
self.assertTrue(assistant.tools[3].type == "function")
self.assertTrue(assistant.tools[3].function.name == "GetResponse")
elif agent.name == "TestAgent2":
self.assertTrue(len(assistant.tools) == self.__class__.num_schemas)
for tool in assistant.tools:
Expand Down

0 comments on commit 6fff6a4

Please sign in to comment.