Skip to content

Commit

Permalink
'Cannot Pickle' RLock Bug fix (#18)
Browse files Browse the repository at this point in the history
* 'Cannot Pickle' RLock Bug fix; Update & Refactor tests; Minor refactoring
* Update packages
* AgencyConfig Validation bug fix; Update tests
* Add error handler for Pydantic validation errors
  • Loading branch information
bonk1t authored Jan 15, 2024
1 parent 1d3738f commit 671a33d
Show file tree
Hide file tree
Showing 10 changed files with 837 additions and 335 deletions.
47 changes: 41 additions & 6 deletions nalgonda/dependencies/agency_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import logging
from copy import copy
from uuid import uuid4

from agency_swarm import Agency, Agent
from agency_swarm.util import get_openai_client
from fastapi import Depends
from redis import asyncio as aioredis

Expand Down Expand Up @@ -31,21 +33,23 @@ async def create_agency(self, agency_id: str | None = None) -> str:
agency_config = await asyncio.to_thread(agency_config_storage.load_or_create)

agents = await self.load_and_construct_agents(agency_config)
agency = self.construct_agency(agency_config, agents)
agency = await asyncio.to_thread(self.construct_agency, agency_config, agents)

await self.cache_agency(agency, agency_id, None)
return agency_id

async def get_agency(self, agency_id: str, thread_id: str | None = None) -> Agency | None:
cache_key = self.get_cache_key(agency_id, thread_id)
agency = await self.cache_manager.get(cache_key)

if not agency:
# If agency is not found in the cache, re-populate the cache
agency = await self.repopulate_cache(agency_id)
if not agency:
logger.error(f"Agency configuration for {agency_id} could not be found in the Firestore database.")
return None

agency = self._restore_client_objects(agency)
return agency

async def update_agency(self, agency_config: AgencyConfig, updated_data: dict) -> None:
Expand All @@ -54,23 +58,26 @@ async def update_agency(self, agency_config: AgencyConfig, updated_data: dict) -

updated_data.pop("agency_id", None) # ensure agency_id is not modified
agency_config.update(updated_data)

AgencyConfig.model_validate(agency_config.model_dump())

agency_config_storage = AgencyConfigFirestoreStorage(agency_id)
await asyncio.to_thread(agency_config_storage.save, agency_config)

# Update the agency in the cache
await self.repopulate_cache(agency_id)

async def repopulate_cache(self, agency_id: str) -> Agency | None:
cache_key = self.get_cache_key(agency_id)
agency_config_storage = AgencyConfigFirestoreStorage(agency_id)
agency_config = await asyncio.to_thread(agency_config_storage.load)
if not agency_config:
logger.error(f"Agency with id {agency_id} not found.")
return None

agents = await self.load_and_construct_agents(agency_config)
agency = self.construct_agency(agency_config, agents)
await self.cache_manager.set(cache_key, agency)
agency = await asyncio.to_thread(self.construct_agency, agency_config, agents)

await self.cache_agency(agency, agency_id, None)
return agency

async def load_and_construct_agents(self, agency_config: AgencyConfig) -> dict[str, Agent]:
Expand Down Expand Up @@ -103,8 +110,8 @@ def construct_agency(agency_config: AgencyConfig, agents: dict[str, Agent]) -> A
async def cache_agency(self, agency: Agency, agency_id: str, thread_id: str | None) -> None:
"""Cache the agency."""
cache_key = self.get_cache_key(agency_id, thread_id)

await self.cache_manager.set(cache_key, agency)
agency_clean = self._remove_client_objects(agency)
await self.cache_manager.set(cache_key, agency_clean)

async def delete_agency_from_cache(self, agency_id: str, thread_id: str | None) -> None:
"""Delete the agency from the cache."""
Expand All @@ -116,6 +123,34 @@ async def delete_agency_from_cache(self, agency_id: str, thread_id: str | None)
def get_cache_key(agency_id: str, thread_id: str | None = None) -> str:
return f"{agency_id}/{thread_id}" if thread_id else agency_id

@staticmethod
def _remove_client_objects(agency: Agency) -> Agency:
"""Remove all client objects from the agency object"""
agency_copy = copy(agency)
agency_copy.agents = [copy(agent) for agent in agency_copy.agents]

for agent in agency_copy.agents:
agent.client = None

agency_copy.main_thread = copy(agency_copy.main_thread)
agency_copy.main_thread.client = None

agency_copy.main_thread.recipient_agent = copy(agency_copy.main_thread.recipient_agent)
agency_copy.main_thread.recipient_agent.client = None

agency_copy.ceo = copy(agency_copy.ceo)
agency_copy.ceo.client = None

return agency_copy

@staticmethod
def _restore_client_objects(agency: Agency) -> Agency:
"""Restore all client objects from the agency object"""
for agent in agency.agents:
agent.client = get_openai_client()
agency.main_thread.client = get_openai_client()
return agency


def get_agency_manager(
redis: aioredis.Redis = Depends(get_redis), agent_manager: AgentManager = Depends(get_agent_manager)
Expand Down
21 changes: 1 addition & 20 deletions nalgonda/dependencies/caching/redis_cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pickle

from agency_swarm import Agency
from agency_swarm.util import get_openai_client
from redis import asyncio as aioredis

from nalgonda.constants import DEFAULT_CACHE_EXPIRATION
Expand All @@ -24,13 +23,11 @@ async def get(self, key: str) -> Agency | None:
return None

loaded = pickle.loads(serialized_data)
loaded = self.restore_client_objects(loaded)
return loaded

async def set(self, key: str, value: Agency, expire: int = DEFAULT_CACHE_EXPIRATION) -> None:
"""Sets the value for the given key in the cache"""
value_copy = self.remove_client_objects(value)
serialized_data = pickle.dumps(value_copy)
serialized_data = pickle.dumps(value)
await self.redis.set(key, serialized_data, ex=expire)

async def delete(self, key: str) -> None:
Expand All @@ -40,19 +37,3 @@ async def delete(self, key: str) -> None:
async def close(self) -> None:
"""Closes the Redis connection"""
await self.redis.close()

@staticmethod
def remove_client_objects(agency: Agency) -> Agency:
"""Remove all client objects from the agency object"""
for agent in agency.agents:
agent.client = None
agency.main_thread.client = None
return agency

@staticmethod
def restore_client_objects(agency: Agency) -> Agency:
"""Restore all client objects from the agency object"""
for agent in agency.agents:
agent.client = get_openai_client()
agency.main_thread.client = get_openai_client()
return agency
14 changes: 14 additions & 0 deletions nalgonda/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging
from http import HTTPStatus

from fastapi import Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)


async def bad_request_exception_handler(request: Request, exc: ValueError): # noqa: ARG001
return JSONResponse(
status_code=HTTPStatus.BAD_REQUEST,
content={"detail": str(exc)},
)
3 changes: 3 additions & 0 deletions nalgonda/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from firebase_admin import credentials
from pydantic import ValidationError
from starlette.staticfiles import StaticFiles

from nalgonda.constants import BASE_DIR
from nalgonda.exception_handlers import bad_request_exception_handler
from nalgonda.routers.v1 import v1_router
from nalgonda.settings import settings
from nalgonda.utils import init_webserver_folders
Expand Down Expand Up @@ -48,6 +50,7 @@

v1_api_app = FastAPI(root_path="/v1")
v1_api_app.include_router(v1_router)
v1_api_app.add_exception_handler(ValidationError, bad_request_exception_handler)

# mount an api route such that the main route serves the ui and the /api
app.mount("/v1", v1_api_app)
Expand Down
Loading

0 comments on commit 671a33d

Please sign in to comment.