Skip to content

Commit

Permalink
Ensure thread-safely
Browse files Browse the repository at this point in the history
Update README.md
  • Loading branch information
bonk1t committed Dec 10, 2023
1 parent 046f281 commit 1e33730
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 19 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

## Overview

... Project description ...
Project Nalgonda is a tool for managing and executing AI agents.
It is built on top of the [OpenAI Assitants API](https://platform.openai.com/docs/assistants/overview)
and provides a simple interface for configuring agents and executing them.

## Features

Expand Down
13 changes: 13 additions & 0 deletions src/nalgonda/agency_config_lock_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import threading


class AgencyConfigLockManager:
"""Lock manager for agency config files"""

_locks: dict[str, threading.Lock] = {}

@classmethod
def get_lock(cls, agency_id):
if agency_id not in cls._locks:
cls._locks[agency_id] = threading.Lock()
return cls._locks[agency_id]
25 changes: 15 additions & 10 deletions src/nalgonda/agency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,35 @@
import time

from agency_swarm import Agency, Agent
from agency_swarm.util.oai import get_openai_client
from nalgonda.config import AgencyConfig
from nalgonda.custom_tools import TOOL_MAPPING

client = get_openai_client()

logger = logging.getLogger(__name__)


class AgencyManager:
def __init__(self):
self.cache = {} # agency_id: agency
self.lock = asyncio.Lock()

async def get_or_create_agency(self, agency_id: str) -> Agency:
"""Get or create the agency for the given session ID"""
if agency_id in self.cache:
return self.cache[agency_id]
async with self.lock:
if agency_id in self.cache:
return self.cache[agency_id]

# Async-to-Sync Bridge
agency = await asyncio.to_thread(self.load_agency_from_config, agency_id)
self.cache[agency_id] = agency
return agency
# Note: Async-to-Sync Bridge
agency = await asyncio.to_thread(self.load_agency_from_config, agency_id)
self.cache[agency_id] = agency
return agency

@staticmethod
def load_agency_from_config(agency_id: str) -> Agency:
"""Load the agency from the config file"""
"""Load the agency from the config file. The agency is created using the agency-swarm library.
This code is synchronous and should be run in a single thread.
The code is currently not thread safe (due to agency-swarm limitations).
"""

start = time.time()
config = AgencyConfig.load(agency_id)
Expand All @@ -51,6 +54,8 @@ def load_agency_from_config(agency_id: str) -> Agency:
for chart in config.agency_chart
]

# Create the agency using external library agency-swarm. It is a wrapper around OpenAI API.
# It saves all the settings in the settings.json file (in the root folder, not thread safe)
agency = Agency(agency_chart, shared_instructions=config.agency_manifesto)

config.update_agent_ids_in_config(agency_id, agents=agency.agents)
Expand Down
7 changes: 5 additions & 2 deletions src/nalgonda/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

from agency_config_lock_manager import AgencyConfigLockManager
from agency_swarm import Agent
from nalgonda.constants import CONFIG_FILE, DEFAULT_CONFIG_FILE
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -50,12 +51,14 @@ def load(cls, agency_id: str) -> "AgencyConfig":
config_file_name = cls.get_config_name(agency_id)
config_file_name = config_file_name if config_file_name.exists() else DEFAULT_CONFIG_FILE

with open(config_file_name) as f:
lock = AgencyConfigLockManager.get_lock(agency_id)
with lock as _, open(config_file_name) as f:
return cls.model_validate_json(f.read())

def save(self, agency_id: str) -> None:
"""Save the config to a file"""
with open(self.get_config_name(agency_id), "w") as f:
lock = AgencyConfigLockManager.get_lock(agency_id)
with lock as _, open(self.get_config_name(agency_id), "w") as f:
f.write(self.model_dump_json(indent=2))

@staticmethod
Expand Down
15 changes: 9 additions & 6 deletions src/nalgonda/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ async def create_agency():
class ConnectionManager:
def __init__(self):
self.active_connections: list[WebSocket] = []
self.connections_lock = asyncio.Lock()

async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
async with self.connections_lock:
self.active_connections.append(websocket)

def disconnect(self, websocket: WebSocket):
if websocket in self.active_connections:
self.active_connections.remove(websocket)
async def disconnect(self, websocket: WebSocket):
async with self.connections_lock:
if websocket in self.active_connections:
self.active_connections.remove(websocket)

async def send_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
Expand All @@ -70,7 +73,7 @@ async def websocket_endpoint(websocket: WebSocket, agency_id: str):

if not user_message.strip():
await ws_manager.send_message("message not provided", websocket)
ws_manager.disconnect(websocket)
await ws_manager.disconnect(websocket)
await websocket.close(code=1003)
return

Expand All @@ -80,7 +83,7 @@ async def websocket_endpoint(websocket: WebSocket, agency_id: str):
await ws_manager.send_message(response_text, websocket)

except WebSocketDisconnect:
ws_manager.disconnect(websocket)
await ws_manager.disconnect(websocket)
logger.info(f"WebSocket disconnected for agency_id: {agency_id}")


Expand Down

0 comments on commit 1e33730

Please sign in to comment.