Skip to content

Commit

Permalink
Extend the mock_firestore_client.py; add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Jan 28, 2024
1 parent 40decef commit 856f118
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 31 deletions.
20 changes: 12 additions & 8 deletions nalgonda/routers/v1/api/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def update_or_create_agency(
"""Create or update an agency and return its id"""
# support template configs:
if not agency_config.owner_id:
logger.info(f"Creating agency for user: {current_user.id}")
logger.info(f"Creating agency for user: {current_user.id}, agency: {agency_config.name}")
agency_config.agency_id = None
else:
# check if the current_user has permissions
Expand All @@ -65,13 +65,17 @@ async def update_or_create_agency(
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Agency not found")

Check warning on line 65 in nalgonda/routers/v1/api/agency.py

View check run for this annotation

Codecov / codecov/patch

nalgonda/routers/v1/api/agency.py#L65

Added line #L65 was not covered by tests
if agency_config_db.owner_id != current_user.id:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Forbidden")
# check that all used agents belong to the current user
for agent_id in agency_config.agents:
get_result = await agent_manager.get_agent(agent_id)
if get_result:
_, agent_config = get_result
if agent_config.owner_id != current_user.id:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Forbidden")

# check that all used agents belong to the current user
for agent_id in agency_config.agents:
get_result = await agent_manager.get_agent(agent_id)
if get_result:
_, agent_config = get_result
if agent_config.owner_id != current_user.id:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Forbidden")
# FIXME: current limitation: all agents must belong to the current user.
# to fix: If the agent is a template (agent_config.owner_id is None), it should be copied for the current user
# (reuse the code from api/agent.py)

# Ensure the agency is associated with the current user
agency_config.owner_id = current_user.id
Expand Down
10 changes: 5 additions & 5 deletions nalgonda/routers/v1/api/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Body, Depends, HTTPException
Expand All @@ -11,6 +12,7 @@
from nalgonda.persistence.agent_config_firestore_storage import AgentConfigFirestoreStorage
from nalgonda.services.agent_manager import AgentManager

logger = logging.getLogger(__name__)
agent_router = APIRouter(tags=["agent"])


Expand Down Expand Up @@ -52,6 +54,7 @@ async def create_or_update_agent(
) -> dict[str, str]:
# support template configs:
if not agent_config.owner_id:
logger.info(f"Creating agent for user: {current_user.id}, agent: {agent_config.name}")
agent_config.agent_id = None

Check warning on line 58 in nalgonda/routers/v1/api/agent.py

View check run for this annotation

Codecov / codecov/patch

nalgonda/routers/v1/api/agent.py#L57-L58

Added lines #L57 - L58 were not covered by tests
else:
# check if the current_user has permissions
Expand All @@ -63,14 +66,11 @@ async def create_or_update_agent(
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Forbidden")
# Ensure the agent name has not been changed
if agent_config.name != agent_config_db.name:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Agent name cannot be changed")
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Renaming agents is not supported yet")

Check warning on line 69 in nalgonda/routers/v1/api/agent.py

View check run for this annotation

Codecov / codecov/patch

nalgonda/routers/v1/api/agent.py#L69

Added line #L69 was not covered by tests

# Ensure the agent is associated with the current user
agent_config.owner_id = current_user.id

# FIXME: a workaround explained at the top of the file
if not agent_config.name.endswith(f" ({agent_config.owner_id})"):
agent_config.name = f"{agent_config.name} ({agent_config.owner_id})"

agent_id = await agent_manager.create_or_update_agent(agent_config)

return {"agent_id": agent_id}
5 changes: 5 additions & 0 deletions nalgonda/services/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ async def create_or_update_agent(self, agent_config: AgentConfig) -> str:
Returns:
str: agent_id
"""

# FIXME: a workaround explained at the top of the file api/agent.py
if not agent_config.name.endswith(f" ({agent_config.owner_id})"):
agent_config.name = f"{agent_config.name} ({agent_config.owner_id})"

agent = self._construct_agent(agent_config)
agent.init_oai() # initialize the openai agent to get the id
agent_config.agent_id = agent.id
Expand Down
39 changes: 38 additions & 1 deletion tests/functional/v1/api/test_agency_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from unittest import mock
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agency_swarm import Agent

from nalgonda.models.agency_config import AgencyConfig
from nalgonda.models.agent_config import AgentConfig
from tests.test_utils import TEST_USER_ID


Expand Down Expand Up @@ -120,3 +122,38 @@ def test_update_agency_owner_id_mismatch(client, mock_firestore_client):

assert response.status_code == 403
assert response.json() == {"detail": "Forbidden"}


@pytest.mark.usefixtures("mock_get_current_active_user")
def test_update_agency_with_foreign_agent(client, mock_firestore_client):
agency_config_data = {
"agency_id": "test_agency_id",
"owner_id": TEST_USER_ID,
"name": "Test Agency",
"agents": ["foreign_agent_id"],
}
foreign_agent_config = AgentConfig(
name="Foreign Agent", owner_id="foreign_owner_id", description="Test Agent", instructions="Test Instructions"
)
mock_firestore_client.setup_mock_data("agency_configs", "test_agency_id", agency_config_data)
mock_firestore_client.setup_mock_data("agent_configs", "foreign_agent_id", foreign_agent_config.model_dump())

agent_mock = MagicMock(spec=Agent)
expected_agent_return_value = (agent_mock, foreign_agent_config)

# Mock the AgentManager to return an agent with a different owner when get_agent is called
with patch("nalgonda.services.agent_manager.AgentManager.get_agent", new_callable=AsyncMock) as mock_get_agent:
mock_get_agent.return_value = expected_agent_return_value

# Simulate a PUT request to update the agency with agents belonging to a different user
response = client.put("/v1/api/agency", json=agency_config_data)

# Check if the server responds with a 403 Forbidden
assert response.status_code == 403
assert response.json() == {"detail": "Forbidden"}

# Check if the agent manager was called with the correct arguments
mock_get_agent.assert_called_once_with("foreign_agent_id")

# Check if the agency config was not updated
assert mock_firestore_client.collection("agency_configs").to_dict() == agency_config_data
42 changes: 25 additions & 17 deletions tests/test_utils/mock_firestore_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,45 @@ def to_dict(self):
class MockFirestoreClient:
def __init__(self):
self._collections = {}
self.current_collection = None
self.current_document = None
self.current_document_id = None
self._current_collection = None
self._current_documents = {} # Tracks the current document ID for each collection

def collection(self, collection_name):
self.current_collection = collection_name
self._current_collection = collection_name
if collection_name not in self._current_documents:
self._current_documents[collection_name] = {"current_document": None, "current_document_id": None}
return self

def document(self, document_name):
self.current_document = document_name
if self._current_collection:
self._current_documents[self._current_collection]["current_document"] = document_name
return self

def get(self):
return self

@property
def exists(self):
collection = self._collections.get(self.current_collection, {})
return self.current_document in collection
collection = self._collections.get(self._current_collection, {})
current_doc = self._current_documents.get(self._current_collection, {}).get("current_document")
return current_doc in collection

def set(self, data: dict):
self._collections.setdefault(self.current_collection, {})[self.current_document] = data
collection = self._current_collection
current_doc = self._current_documents[collection]["current_document"]
self._collections.setdefault(collection, {})[current_doc] = data

def to_dict(self):
collection = self._collections.get(self.current_collection, {})
return collection.get(self.current_document, {})
collection = self._collections.get(self._current_collection, {})
current_doc = self._current_documents.get(self._current_collection, {}).get("current_document")
return collection.get(current_doc, {})

def setup_mock_data(self, collection_name, document_name, data, doc_id=None):
self.current_collection = collection_name
self.current_document = document_name
self.current_document_id = doc_id
self._current_collection = collection_name
if collection_name not in self._current_documents:
self._current_documents[collection_name] = {}
self._current_documents[collection_name]["current_document"] = document_name
self._current_documents[collection_name]["current_document_id"] = doc_id
self.set(data)

def where(self, filter: FieldFilter):
Expand All @@ -54,16 +62,16 @@ def where(self, filter: FieldFilter):
return self

def stream(self):
# This method should return a list of mock documents
# matching the criteria set in the 'where' method.
matching_docs = []
for doc_id, doc in self._collections.get(self.current_collection, {}).items():
for doc_id, doc in self._collections.get(self._current_collection, {}).items():
if doc.get(self._where_field) == self._where_value:
matching_docs.append(MockDocumentSnapshot(doc_id, doc))
return matching_docs

def add(self, data) -> list:
# This method should add a new document to the collection
# and return a list with the new document.
collection = self._current_collection
current_doc_id = self._current_documents[collection].get("current_document_id")
self.set(data)
return [[], MockDocumentSnapshot(self.current_document_id, data)]
return [[], MockDocumentSnapshot(current_doc_id, data)]

0 comments on commit 856f118

Please sign in to comment.