Skip to content

Commit

Permalink
Fix cache management; rename session_id->thread_id
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Jan 29, 2024
1 parent 73d88ba commit 5829eeb
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nalgonda/models/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ class ThreadPostRequest(BaseModel):

class AgencyMessagePostRequest(BaseModel):
agency_id: str = Field(..., description="The unique identifier for the agency.")
message: str = Field(..., description="The message to be sent to the agency.")
thread_id: str = Field(..., description="The identifier for the conversational thread.")
message: str = Field(..., description="The message to be sent to the agency.")
6 changes: 3 additions & 3 deletions nalgonda/routers/v1/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ async def create_session(
if not agency:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Agency not found")

session_id = thread_manager.create_threads(agency)
thread_id = thread_manager.create_threads(agency)

await agency_manager.cache_agency(agency, agency_id, session_id)
return {"session_id": session_id}
await agency_manager.cache_agency(agency, agency_id, thread_id)
return {"thread_id": thread_id}


@session_router.post("/session/message")
Expand Down
8 changes: 5 additions & 3 deletions nalgonda/services/agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def get_agency(self, agency_id: str, thread_id: str | None = None) -> Agen

if not agency:
# If agency is not found in the cache, re-populate the cache
agency = await self.repopulate_cache_and_update_assistants(agency_id)
agency = await self.repopulate_cache_and_update_assistants(agency_id, thread_id)
if not agency:
logger.error(f"Agency configuration for {agency_id} could not be found in the Firestore database.")
return None
Expand All @@ -46,7 +46,9 @@ async def update_or_create_agency(self, agency_config: AgencyConfig) -> str:
await self.repopulate_cache_and_update_assistants(agency_id)
return agency_id

async def repopulate_cache_and_update_assistants(self, agency_id: str) -> Agency | None:
async def repopulate_cache_and_update_assistants(
self, agency_id: str, thread_id: str | None = None
) -> Agency | None:
"""Gets the agency config from the Firestore, constructs agents and agency
(agency-swarm also updates assistants), and saves the Agency instance to Redis
(with expiration period, see constants.DEFAULT_CACHE_EXPIRATION).
Expand All @@ -59,7 +61,7 @@ async def repopulate_cache_and_update_assistants(self, agency_id: str) -> Agency
agents = await self.load_and_construct_agents(agency_config)
agency = await asyncio.to_thread(self.construct_agency, agency_config, agents)

await self.cache_agency(agency, agency_id, None)
await self.cache_agency(agency, agency_id, thread_id)
return agency

async def load_and_construct_agents(self, agency_config: AgencyConfig) -> dict[str, Agent]:
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/v1/api/test_session_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_create_session_success(client, mock_firestore_client):
with patch.object(
AgencyManager, "get_agency", AsyncMock(return_value=MagicMock())
) as mock_get_agency, patch.object(
ThreadManager, "create_threads", MagicMock(return_value="new_session_id")
ThreadManager, "create_threads", MagicMock(return_value="new_thread_id")
) as mock_create_threads, patch.object(AgencyManager, "cache_agency", AsyncMock()) as mock_cache_agency:
# mock Firestore to pass the security owner_id check
mock_firestore_client.setup_mock_data(
Expand All @@ -27,10 +27,10 @@ def test_create_session_success(client, mock_firestore_client):
response = client.post("/v1/api/session", json=request_data.model_dump())
# Assertions
assert response.status_code == 200
assert response.json() == {"session_id": "new_session_id"}
assert response.json() == {"thread_id": "new_thread_id"}
mock_get_agency.assert_awaited_once_with("test_agency_id", None)
mock_create_threads.assert_called_once_with(mock_get_agency.return_value)
mock_cache_agency.assert_awaited_once_with(mock_get_agency.return_value, "test_agency_id", "new_session_id")
mock_cache_agency.assert_awaited_once_with(mock_get_agency.return_value, "test_agency_id", "new_thread_id")


@pytest.mark.usefixtures("mock_get_current_active_user")
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/services/test_agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def test_get_agency_repopulate_cache(agency_manager):
agency = await agency_manager.get_agency("test_agency_id")
assert agency is not None
mock_get.assert_called_once_with("test_agency_id")
mock_repopulate.assert_called_once_with("test_agency_id")
mock_repopulate.assert_called_once_with("test_agency_id", None)


@pytest.mark.asyncio
Expand Down Expand Up @@ -70,7 +70,7 @@ async def test_repopulate_cache_no_config(agency_manager):
) as mock_async_to_thread:
mock_async_to_thread.return_value = None

result = await agency_manager.repopulate_cache_and_update_assistants("nonexistent_agency_id")
result = await agency_manager.repopulate_cache_and_update_assistants("nonexistent_agency_id", None)
assert result is None
mock_async_to_thread.assert_called_once()
mock_logger.error.assert_called_once_with("Agency with id nonexistent_agency_id not found.")
Expand Down Expand Up @@ -98,7 +98,7 @@ async def test_repopulate_cache_success(agency_manager, mock_firestore_client):
mock_load_agents.return_value = {"agent1": agent}
mock_construct_agency.return_value = Agency([], "manifesto")

result = await agency_manager.repopulate_cache_and_update_assistants("test_agency_id")
result = await agency_manager.repopulate_cache_and_update_assistants("test_agency_id", None)
assert result is not None
mock_load_agents.assert_called_once_with(agency_config)
mock_construct_agency.assert_called_once_with(agency_config, {"agent1": agent})
Expand Down

0 comments on commit 5829eeb

Please sign in to comment.