-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Julien
committed
Nov 13, 2024
1 parent
d8b536a
commit 1205065
Showing
6 changed files
with
207 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# main.py | ||
import sys | ||
from pathlib import Path | ||
import pytest | ||
|
||
|
||
# Ajoute le répertoire racine du projet au chemin Python | ||
project_root = Path(__file__).parent | ||
sys.path.append(str(project_root)) | ||
|
||
from swarm import Swarm | ||
from agents.base_agents import kubectl_agent, stripe_agent, grafana_agent, db_agent | ||
|
||
client = Swarm() | ||
|
||
def run_and_get_tool_calls(agent, query, get="tool_calls"): | ||
"""Helper function to run a query with an agent and return the specified attribute from response messages.""" | ||
message = {"role": "user", "content": query} | ||
response = client.run(agent=agent, messages=[message], execute_tools=False) | ||
return response.messages[-1].get(get) | ||
|
||
# Configuration de tests | ||
test_cases = [ | ||
{"agent": kubectl_agent, "query": "get pods count", "expected_function": "kubectl"}, | ||
{"agent": kubectl_agent, "query": "get app chatbot", "expected_function": "get_app", "expected_arguments": '{"app_name":"chatbot"}'}, | ||
{"agent": kubectl_agent, "query": "unsync app chatbot", "expected_function": "update_sync_policy", "expected_arguments": '{"app_name":"chatbot","is_sync":"false"}'}, | ||
{"agent": stripe_agent, "query": "get user [email protected]", "expected_function": "stripe_query", "expected_arguments": '{"email":"[email protected]"}'}, | ||
{"agent": stripe_agent, "query": "get payments of customer id 1234RTY78", "expected_function": "stripe_payments_list", "expected_arguments": '{"customer_id":"1234RTY78"}'}, | ||
{"agent": grafana_agent, "query": "show user [email protected] in grafana", "expected_function": "grafana_query", "expected_arguments": '{"email":"[email protected]"}'}, | ||
{"agent": db_agent, "query": "find user [email protected]", "expected_function": "transfer_to_query"}, | ||
] | ||
|
||
@pytest.mark.parametrize("case", test_cases) | ||
def test_tool_calls(case): | ||
"""Test tool calls based on provided cases in `test_cases`.""" | ||
tool_calls = run_and_get_tool_calls(case["agent"], case["query"]) | ||
|
||
assert tool_calls and len(tool_calls) == 1 | ||
assert tool_calls[0]["function"]["name"] == case["expected_function"] | ||
|
||
if "expected_arguments" in case: | ||
assert tool_calls[0]["function"]["arguments"] == case["expected_arguments"] | ||
|
||
@pytest.mark.parametrize( | ||
"query", | ||
[ | ||
"cancel the subscription xxxxxx", | ||
"refund payment with charge id xxxxxx", | ||
], | ||
) | ||
def test_confirm_content(query): | ||
"""Test to confirm specific content in response messages.""" | ||
content = run_and_get_tool_calls(stripe_agent, query, "content") | ||
assert "confirm" in content.lower() or "xxxxxx" in content | ||
|
||
@pytest.mark.parametrize( | ||
"query", | ||
[ | ||
"select all users since yesterday", | ||
"count all customers since today", | ||
#"count all tesla users since today", | ||
"find services from [email protected]", | ||
], | ||
) | ||
def test_db_transfer_queries(query): | ||
"""Test db_agent calls to ensure the 'transfer_to_query' function is used.""" | ||
tool_calls = run_and_get_tool_calls(db_agent, query) | ||
assert tool_calls and len(tool_calls) == 1 | ||
assert tool_calls[0]["function"]["name"] == "transfer_to_query" |
Oops, something went wrong.