Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Jan 28, 2024
1 parent 29d3291 commit 40decef
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 52 deletions.
9 changes: 5 additions & 4 deletions nalgonda/custom_tools/search_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@


class SearchWeb(BaseTool):
"""Search the web with a search phrase and return the results."""
"""Search the web with a search query and return the results."""

phrase: str = Field(
query: str = Field(
...,
description="The search phrase you want to use. Optimize the search phrase for an internet search engine.",
description="The search query you want to use. Optimize the search query for an internet search engine.",
)
max_results: int = Field(default=10, description="The maximum number of search results to return, default is 10.")

def run(self) -> str:
with DDGS() as ddgs:
return "\n".join(str(result) for result in ddgs.text(self.phrase, max_results=self.max_results))
results = [str(r) for r in ddgs.text(self.query, max_results=self.max_results)]
return "\n".join(results) if results else "No results found."

Check warning on line 18 in nalgonda/custom_tools/search_web.py

View check run for this annotation

Codecov / codecov/patch

nalgonda/custom_tools/search_web.py#L17-L18

Added lines #L17 - L18 were not covered by tests
11 changes: 2 additions & 9 deletions nalgonda/routers/v1/api/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from typing import Annotated

from agency_swarm import Agency
from fastapi import APIRouter, Depends, HTTPException
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND

Expand Down Expand Up @@ -80,14 +79,8 @@ async def post_agency_message(
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Agency not found")

Check warning on line 79 in nalgonda/routers/v1/api/session.py

View check run for this annotation

Codecov / codecov/patch

nalgonda/routers/v1/api/session.py#L79

Added line #L79 was not covered by tests

try:
response = await process_message(user_message, agency)
response = await agency.get_completion(message=user_message, yield_messages=False)
return {"response": response}
except Exception as e:
logger.exception(e)
return {"error": str(e)}


async def process_message(user_message: str, agency: Agency) -> str:
"""Process a message from the user and return the response from the User Proxy."""
response = agency.get_completion(message=user_message, yield_messages=False)
return response
raise HTTPException(status_code=500, detail="Something went wrong") from e
20 changes: 14 additions & 6 deletions tests/functional/v1/api/test_agency_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from unittest import mock
from unittest.mock import AsyncMock, patch

import pytest

from nalgonda.models.agency_config import AgencyConfig
from tests.test_utils import TEST_USER_ID


def test_get_agency_list_success(client, mock_get_current_active_user, mock_firestore_client): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_get_agency_list_success(client, mock_firestore_client):
# Setup expected response
expected_agency = AgencyConfig(agency_id="agency1", owner_id="test_user_id", name="Test agency")
mock_firestore_client.setup_mock_data("agency_configs", "test_agency_id", expected_agency.model_dump())
Expand All @@ -16,7 +19,8 @@ def test_get_agency_list_success(client, mock_get_current_active_user, mock_fire
assert response.json() == [expected_agency.model_dump()]


def test_get_agency_config(client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_get_agency_config(client, mock_firestore_client):
mock_data = {
"agency_id": "test_agency_id",
"owner_id": TEST_USER_ID,
Expand All @@ -33,14 +37,16 @@ def test_get_agency_config(client, mock_firestore_client, mock_get_current_activ
assert response.json() == mock_data


def test_get_agency_config_not_found(client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_get_agency_config_not_found(client):
# Simulate non-existent agency by not setting up any data for it
response = client.get("/v1/api/agency?agency_id=non_existent_agency")
assert response.status_code == 404
assert response.json() == {"detail": "Agency not found"}


def test_create_agency_success(client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user", "mock_firestore_client")
def test_create_agency_success(client):
template_config = {
"agency_id": "template_agency_id",
"name": "Test agency",
Expand All @@ -65,7 +71,8 @@ def test_create_agency_success(client, mock_firestore_client, mock_get_current_a
assert mock_update_or_create_agency.mock_calls[0].args[0].agency_id != "template_agency_id"


def test_update_agency_success(client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_agency_success(client, mock_firestore_client):
# Setup initial data in mock Firestore client
initial_data = {
"agency_id": "test_agency_id",
Expand Down Expand Up @@ -93,7 +100,8 @@ def test_update_agency_success(client, mock_firestore_client, mock_get_current_a
assert mock_firestore_client.to_dict() == new_data


def test_update_agency_owner_id_mismatch(client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_agency_owner_id_mismatch(client, mock_firestore_client):
# Setup initial data in mock Firestore client
initial_data = {
"agency_id": "test_agency_id",
Expand Down
9 changes: 6 additions & 3 deletions tests/functional/v1/api/test_agent_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ def agent_data():
}


def test_get_agent_config(client, agent_data, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_get_agent_config(client, agent_data, mock_firestore_client):
mock_firestore_client.setup_mock_data("agent_configs", AGENT_ID, agent_data)

response = client.get(f"/v1/api/agent?agent_id={AGENT_ID}")
assert response.status_code == 200
assert response.json() == agent_data


def test_update_agent_config_success(client, agent_data, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_agent_config_success(client, agent_data, mock_firestore_client):
mock_firestore_client.setup_mock_data("agent_configs", AGENT_ID, agent_data)

with patch("nalgonda.services.agent_manager.AgentManager") as mock_agent_manager:
Expand All @@ -41,7 +43,8 @@ def test_update_agent_config_success(client, agent_data, mock_firestore_client,
assert response.json() == {"agent_id": AGENT_ID}


def test_update_agent_config_owner_id_mismatch(client, agent_data, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_agent_config_owner_id_mismatch(client, agent_data, mock_firestore_client):
agent_data_db = agent_data.copy()
agent_data_db["owner_id"] = "other_user"
mock_firestore_client.setup_mock_data("agent_configs", AGENT_ID, agent_data_db)
Expand Down
78 changes: 75 additions & 3 deletions tests/functional/v1/api/test_session_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import status

from nalgonda.models.request_models import ThreadPostRequest
Expand All @@ -8,7 +9,8 @@
from tests.test_utils import TEST_USER_ID


def test_create_session_success(client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_create_session_success(client, mock_firestore_client):
with patch.object(
AgencyManager, "get_agency", AsyncMock(return_value=MagicMock())
) as mock_get_agency, patch.object(
Expand All @@ -24,14 +26,15 @@ def test_create_session_success(client, mock_firestore_client, mock_get_current_
# Create a test client
response = client.post("/v1/api/session", json=request_data.model_dump())
# Assertions
assert response.status_code == status.HTTP_200_OK
assert response.status_code == 200
assert response.json() == {"session_id": "new_session_id"}
mock_get_agency.assert_awaited_once_with("test_agency_id", None)
mock_create_threads.assert_called_once_with(mock_get_agency.return_value)
mock_cache_agency.assert_awaited_once_with(mock_get_agency.return_value, "test_agency_id", "new_session_id")


def test_create_session_agency_not_found(client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_create_session_agency_not_found(client):
with patch.object(AgencyManager, "get_agency", AsyncMock(return_value=None)):
# Create request data
request_data = ThreadPostRequest(agency_id="test_agency_id")
Expand All @@ -40,3 +43,72 @@ def test_create_session_agency_not_found(client, mock_get_current_active_user):
# Assertions
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {"detail": "Agency not found"}


@pytest.fixture
def mock_get_agency():
get_agency_mock = AsyncMock(return_value=MagicMock(get_completion=AsyncMock(return_value="Hello, world!")))
with patch.object(AgencyManager, "get_agency", get_agency_mock) as mock_get_agency:
yield mock_get_agency


# Successful message sending
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_post_agency_message_success(client, mock_get_agency, mock_firestore_client):
agency_data = {"owner_id": "test_user_id", "agency_id": "test_agency_id", "name": "Test Agency"}
mock_firestore_client.setup_mock_data("agency_configs", "test_agency_id", agency_data)

# Sending a message
message_data = {"agency_id": "test_agency_id", "thread_id": "test_thread_id", "message": "Hello, world!"}

response = client.post("/v1/api/session/message", json=message_data)

assert response.status_code == 200
# We will check for the actual message we set up to be sent
assert response.json().get("response") == "Hello, world!"
mock_get_agency.assert_called_once_with("test_agency_id", "test_thread_id")


# Agency configuration not found
@pytest.mark.usefixtures("mock_get_current_active_user", "mock_firestore_client")
def test_post_agency_message_agency_config_not_found(client, mock_get_agency):
# Sending a message
message_data = {"agency_id": "test_agency", "thread_id": "test_thread", "message": "Hello, world!"}
response = client.post("/v1/api/session/message", json=message_data)

assert response.status_code == 404
assert response.json()["detail"] == "Agency not found"
mock_get_agency.assert_not_called()


# Current user not the owner of the agency
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_post_agency_message_unauthorized(client, mock_get_agency, mock_firestore_client):
agency_data = {"owner_id": "other_user_id", "agency_id": "test_agency", "name": "Test Agency"}
mock_firestore_client.setup_mock_data("agency_configs", "test_agency", agency_data)

# Sending a message
message_data = {"agency_id": "test_agency", "thread_id": "test_thread", "message": "Hello, world!"}
response = client.post("/v1/api/session/message", json=message_data)

assert response.status_code == 403
assert response.json()["detail"] == "Forbidden"
mock_get_agency.assert_not_called()


# Failure in message processing
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_post_agency_message_processing_failure(client, mock_get_agency, mock_firestore_client):
agency_data = {"owner_id": "test_user_id", "agency_id": "test_agency", "name": "Test Agency"}
mock_firestore_client.setup_mock_data("agency_configs", "test_agency", agency_data)

mock_get_agency.return_value.get_completion.side_effect = Exception("Test exception")

# Sending a message
message_data = {"agency_id": "test_agency", "thread_id": "test_thread", "message": "Hello, world!"}
response = client.post("/v1/api/session/message", json=message_data)

assert response.status_code == 500
assert response.json()["detail"] == "Something went wrong"

mock_get_agency.assert_called_once_with("test_agency", "test_thread")
17 changes: 8 additions & 9 deletions tests/functional/v1/api/test_tool_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ def tool_config_data():
}


def test_get_tool_list(tool_config_data, client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_get_tool_list(tool_config_data, client, mock_firestore_client):
mock_firestore_client.setup_mock_data("tool_configs", "tool1", tool_config_data)

response = client.get("/v1/api/tool/list")
assert response.status_code == 200
assert response.json() == [tool_config_data]


def test_approve_tool(tool_config_data, client, mock_firestore_client, mock_get_current_superuser): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_superuser")
def test_approve_tool(tool_config_data, client, mock_firestore_client):
mock_firestore_client.setup_mock_data("tool_configs", "tool1", tool_config_data)

response = client.post("/v1/api/tool/approve?tool_id=tool1")
Expand All @@ -40,7 +42,8 @@ def test_approve_tool(tool_config_data, client, mock_firestore_client, mock_get_


@patch("nalgonda.routers.v1.api.tool.generate_tool_description", MagicMock(return_value="Test description"))
def test_update_tool_config_success(tool_config_data, client, mock_firestore_client, mock_get_current_active_user): # noqa: ARG001
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_tool_config_success(tool_config_data, client, mock_firestore_client):
mock_firestore_client.setup_mock_data("tool_configs", "tool1", tool_config_data)

tool_config_data = tool_config_data.copy()
Expand All @@ -57,12 +60,8 @@ def test_update_tool_config_success(tool_config_data, client, mock_firestore_cli
assert updated_config.version == 2


def test_update_tool_config_owner_id_mismatch(
tool_config_data,
client,
mock_firestore_client,
mock_get_current_active_user, # noqa: ARG001
):
@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_tool_config_owner_id_mismatch(tool_config_data, client, mock_firestore_client):
tool_config_data["owner_id"] = "another_user"

mock_firestore_client.setup_mock_data("tool_configs", "tool1", tool_config_data)
Expand Down
20 changes: 2 additions & 18 deletions tests/unit/services/test_agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,9 @@
from tests.test_utils import TEST_USER_ID


class MockRedisCacheManager:
def __init__(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs): # noqa: ARG001, ARG002
return self

async def get(self, key): # noqa: ARG002
return None

async def set(self, key, value):
pass

async def delete(self, key):
pass


@pytest.fixture
def agency_manager(mock_firestore_client): # noqa: ARG001
@pytest.mark.usefixtures("mock_firestore_client")
def agency_manager():
yield AgencyManager(
cache_manager=MagicMock(), agent_manager=MagicMock(), agency_config_storage=AgencyConfigFirestoreStorage()
)
Expand Down

0 comments on commit 40decef

Please sign in to comment.