From 5829eeb7181795c350243a96fe425df403740380 Mon Sep 17 00:00:00 2001 From: Nikita Bobrovskiy <39348559+bonk1t@users.noreply.github.com> Date: Mon, 29 Jan 2024 22:34:35 +0200 Subject: [PATCH] Fix cache management; rename session_id->thread_id --- nalgonda/models/request_models.py | 2 +- nalgonda/routers/v1/api/session.py | 6 +++--- nalgonda/services/agency_manager.py | 8 +++++--- tests/functional/v1/api/test_session_endpoints.py | 6 +++--- tests/unit/services/test_agency_manager.py | 6 +++--- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/nalgonda/models/request_models.py b/nalgonda/models/request_models.py index e7348708..c7582a22 100644 --- a/nalgonda/models/request_models.py +++ b/nalgonda/models/request_models.py @@ -7,5 +7,5 @@ class ThreadPostRequest(BaseModel): class AgencyMessagePostRequest(BaseModel): agency_id: str = Field(..., description="The unique identifier for the agency.") - message: str = Field(..., description="The message to be sent to the agency.") thread_id: str = Field(..., description="The identifier for the conversational thread.") + message: str = Field(..., description="The message to be sent to the agency.") diff --git a/nalgonda/routers/v1/api/session.py b/nalgonda/routers/v1/api/session.py index 86bc9d85..0aca33cb 100644 --- a/nalgonda/routers/v1/api/session.py +++ b/nalgonda/routers/v1/api/session.py @@ -48,10 +48,10 @@ async def create_session( if not agency: raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Agency not found") - session_id = thread_manager.create_threads(agency) + thread_id = thread_manager.create_threads(agency) - await agency_manager.cache_agency(agency, agency_id, session_id) - return {"session_id": session_id} + await agency_manager.cache_agency(agency, agency_id, thread_id) + return {"thread_id": thread_id} @session_router.post("/session/message") diff --git a/nalgonda/services/agency_manager.py b/nalgonda/services/agency_manager.py index afd63371..eb6c10e4 100644 --- a/nalgonda/services/agency_manager.py +++ b/nalgonda/services/agency_manager.py @@ -29,7 +29,7 @@ async def get_agency(self, agency_id: str, thread_id: str | None = None) -> Agen if not agency: # If agency is not found in the cache, re-populate the cache - agency = await self.repopulate_cache_and_update_assistants(agency_id) + agency = await self.repopulate_cache_and_update_assistants(agency_id, thread_id) if not agency: logger.error(f"Agency configuration for {agency_id} could not be found in the Firestore database.") return None @@ -46,7 +46,9 @@ async def update_or_create_agency(self, agency_config: AgencyConfig) -> str: await self.repopulate_cache_and_update_assistants(agency_id) return agency_id - async def repopulate_cache_and_update_assistants(self, agency_id: str) -> Agency | None: + async def repopulate_cache_and_update_assistants( + self, agency_id: str, thread_id: str | None = None + ) -> Agency | None: """Gets the agency config from the Firestore, constructs agents and agency (agency-swarm also updates assistants), and saves the Agency instance to Redis (with expiration period, see constants.DEFAULT_CACHE_EXPIRATION). @@ -59,7 +61,7 @@ async def repopulate_cache_and_update_assistants(self, agency_id: str) -> Agency agents = await self.load_and_construct_agents(agency_config) agency = await asyncio.to_thread(self.construct_agency, agency_config, agents) - await self.cache_agency(agency, agency_id, None) + await self.cache_agency(agency, agency_id, thread_id) return agency async def load_and_construct_agents(self, agency_config: AgencyConfig) -> dict[str, Agent]: diff --git a/tests/functional/v1/api/test_session_endpoints.py b/tests/functional/v1/api/test_session_endpoints.py index 7d37afdd..920c5716 100644 --- a/tests/functional/v1/api/test_session_endpoints.py +++ b/tests/functional/v1/api/test_session_endpoints.py @@ -14,7 +14,7 @@ def test_create_session_success(client, mock_firestore_client): with patch.object( AgencyManager, "get_agency", AsyncMock(return_value=MagicMock()) ) as mock_get_agency, patch.object( - ThreadManager, "create_threads", MagicMock(return_value="new_session_id") + ThreadManager, "create_threads", MagicMock(return_value="new_thread_id") ) as mock_create_threads, patch.object(AgencyManager, "cache_agency", AsyncMock()) as mock_cache_agency: # mock Firestore to pass the security owner_id check mock_firestore_client.setup_mock_data( @@ -27,10 +27,10 @@ def test_create_session_success(client, mock_firestore_client): response = client.post("/v1/api/session", json=request_data.model_dump()) # Assertions assert response.status_code == 200 - assert response.json() == {"session_id": "new_session_id"} + assert response.json() == {"thread_id": "new_thread_id"} mock_get_agency.assert_awaited_once_with("test_agency_id", None) mock_create_threads.assert_called_once_with(mock_get_agency.return_value) - mock_cache_agency.assert_awaited_once_with(mock_get_agency.return_value, "test_agency_id", "new_session_id") + mock_cache_agency.assert_awaited_once_with(mock_get_agency.return_value, "test_agency_id", "new_thread_id") @pytest.mark.usefixtures("mock_get_current_active_user") diff --git a/tests/unit/services/test_agency_manager.py b/tests/unit/services/test_agency_manager.py index a787104c..4ec45d73 100644 --- a/tests/unit/services/test_agency_manager.py +++ b/tests/unit/services/test_agency_manager.py @@ -38,7 +38,7 @@ async def test_get_agency_repopulate_cache(agency_manager): agency = await agency_manager.get_agency("test_agency_id") assert agency is not None mock_get.assert_called_once_with("test_agency_id") - mock_repopulate.assert_called_once_with("test_agency_id") + mock_repopulate.assert_called_once_with("test_agency_id", None) @pytest.mark.asyncio @@ -70,7 +70,7 @@ async def test_repopulate_cache_no_config(agency_manager): ) as mock_async_to_thread: mock_async_to_thread.return_value = None - result = await agency_manager.repopulate_cache_and_update_assistants("nonexistent_agency_id") + result = await agency_manager.repopulate_cache_and_update_assistants("nonexistent_agency_id", None) assert result is None mock_async_to_thread.assert_called_once() mock_logger.error.assert_called_once_with("Agency with id nonexistent_agency_id not found.") @@ -98,7 +98,7 @@ async def test_repopulate_cache_success(agency_manager, mock_firestore_client): mock_load_agents.return_value = {"agent1": agent} mock_construct_agency.return_value = Agency([], "manifesto") - result = await agency_manager.repopulate_cache_and_update_assistants("test_agency_id") + result = await agency_manager.repopulate_cache_and_update_assistants("test_agency_id", None) assert result is not None mock_load_agents.assert_called_once_with(agency_config) mock_construct_agency.assert_called_once_with(agency_config, {"agent1": agent})