
feat(backend): Make Redis connection Sync + Use Redis as Distributed Lock #8197

Merged
Changes from all commits (20 commits)
818ee13
feat(platform): Make Redis connection Sync + Use Redis as Distributed…
majdyz Sep 26, 2024
9bc3b83
Remove ignore
majdyz Sep 26, 2024
a17df76
add output on pytest
majdyz Sep 27, 2024
3687313
add redis
majdyz Sep 27, 2024
18a52ad
Rename password
majdyz Sep 27, 2024
9c4152e
Try redis without command
majdyz Sep 27, 2024
556f968
Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into…
majdyz Sep 27, 2024
0dae213
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
majdyz Sep 30, 2024
3ba80d7
Address comments
majdyz Oct 1, 2024
8a1d249
Address comments
majdyz Oct 1, 2024
3cfccb1
Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into…
majdyz Oct 1, 2024
aa3d292
Fix failing test on block id change
majdyz Oct 1, 2024
53dfc78
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
aarushik93 Oct 2, 2024
bd8b84b
Master sync
majdyz Oct 2, 2024
28ccade
Merge remote-tracking branch 'origin/zamilmajdy/secrt-866-we-currentl…
majdyz Oct 2, 2024
d3d749e
Fix pid & service name extraction in connection acquisition
majdyz Oct 2, 2024
8f21c62
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
majdyz Oct 4, 2024
fff8cb3
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
majdyz Oct 4, 2024
53a6520
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
majdyz Oct 7, 2024
2bfe62b
Merge branch 'master' into zamilmajdy/secrt-866-we-currently-lock-the…
aarushik93 Oct 7, 2024
16 changes: 14 additions & 2 deletions .github/workflows/platform-backend-ci.yml
@@ -32,6 +32,14 @@ jobs:
python-version: ["3.10"]
runs-on: ubuntu-latest

services:
redis:
image: bitnami/redis:6.2
env:
REDIS_PASSWORD: testpassword
ports:
- 6379:6379

steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -96,9 +104,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
else
poetry run pytest -vv test
poetry run pytest -s -vv test
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -107,6 +115,10 @@ jobs:
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'

env:
CI: true
PLAIN_OUTPUT: True
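For local debugging, a minimal reachability check against the same Redis settings the workflow injects (illustrative only; the defaults mirror the CI config above and are not part of the PR):

```python
# Illustrative helper, not part of the PR: confirm the Redis service the CI
# job exposes on localhost:6379 (password "testpassword") is reachable.
import os

from redis import Redis


def redis_is_reachable() -> bool:
    client = Redis(
        host=os.getenv("REDIS_HOST", "localhost"),
        port=int(os.getenv("REDIS_PORT", "6379")),
        password=os.getenv("REDIS_PASSWORD", "testpassword"),
        decode_responses=True,
    )
    try:
        return client.ping()  # True when the server answers PONG
    finally:
        client.close()


if __name__ == "__main__":
    print("Redis reachable:", redis_is_reachable())
```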
35 changes: 12 additions & 23 deletions autogpt_platform/backend/backend/data/db.py
@@ -1,4 +1,3 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager
@@ -8,40 +7,30 @@
from prisma import Prisma
from pydantic import BaseModel, Field, field_validator

from backend.util.retry import conn_retry

load_dotenv()

PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA

prisma, conn_id = Prisma(auto_register=True), ""
prisma = Prisma(auto_register=True)

logger = logging.getLogger(__name__)


async def connect(call_count=0):
global conn_id
if not conn_id:
conn_id = str(uuid4())

try:
logger.info(f"[Prisma-{conn_id}] Acquiring connection..")
if not prisma.is_connected():
await prisma.connect()
logger.info(f"[Prisma-{conn_id}] Connection acquired!")
except Exception as e:
if call_count <= 5:
logger.info(f"[Prisma-{conn_id}] Connection failed: {e}. Retrying now..")
await asyncio.sleep(2**call_count)
await connect(call_count + 1)
else:
raise e
@conn_retry("Prisma", "Acquiring connection")
async def connect():
if prisma.is_connected():
return
await prisma.connect()


@conn_retry("Prisma", "Releasing connection")
async def disconnect():
if prisma.is_connected():
logger.info(f"[Prisma-{conn_id}] Releasing connection.")
await prisma.disconnect()
logger.info(f"[Prisma-{conn_id}] Connection released.")
if not prisma.is_connected():
return
await prisma.disconnect()


@asynccontextmanager
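The hand-rolled retry loop is replaced by a `conn_retry` decorator from `backend.util.retry`, whose implementation is not part of this diff. Below is a hedged sketch of what such a decorator could look like, reusing the exponential backoff (up to 5 retries, `2**attempt` seconds) from the code removed above; the parameter names and log format are assumptions, not the actual implementation:

```python
# Hedged sketch of a conn_retry-style decorator. The real backend.util.retry
# code is not shown in this diff; the prefix/action arguments and the
# exponential backoff are assumptions based on the retry loop removed above.
import asyncio
import functools
import logging
import time

logger = logging.getLogger(__name__)


def conn_retry(prefix: str, action: str, max_retry: int = 5):
    def decorator(func):
        if asyncio.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                for attempt in range(max_retry + 1):
                    try:
                        logger.info(f"[{prefix}] {action}..")
                        return await func(*args, **kwargs)
                    except Exception as e:
                        if attempt == max_retry:
                            raise
                        logger.warning(f"[{prefix}] {action} failed: {e}. Retrying..")
                        await asyncio.sleep(2**attempt)
            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            for attempt in range(max_retry + 1):
                try:
                    logger.info(f"[{prefix}] {action}..")
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_retry:
                        raise
                    logger.warning(f"[{prefix}] {action} failed: {e}. Retrying..")
                    time.sleep(2**attempt)
        return sync_wrapper

    return decorator
```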
48 changes: 16 additions & 32 deletions autogpt_platform/backend/backend/data/queue.py
@@ -1,11 +1,9 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime

from redis.asyncio import Redis

from backend.data import redis
from backend.data.execution import ExecutionResult

logger = logging.getLogger(__name__)
@@ -18,60 +16,46 @@ def default(self, o):
return super().default(o)


class AsyncEventQueue(ABC):
class AbstractEventQueue(ABC):
@abstractmethod
async def connect(self):
def connect(self):
pass

@abstractmethod
async def close(self):
def close(self):
pass

@abstractmethod
async def put(self, execution_result: ExecutionResult):
def put(self, execution_result: ExecutionResult):
pass

@abstractmethod
async def get(self) -> ExecutionResult | None:
def get(self) -> ExecutionResult | None:
pass


class AsyncRedisEventQueue(AsyncEventQueue):
class RedisEventQueue(AbstractEventQueue):
def __init__(self):
self.host = os.getenv("REDIS_HOST", "localhost")
self.port = int(os.getenv("REDIS_PORT", "6379"))
self.password = os.getenv("REDIS_PASSWORD", "password")
self.queue_name = os.getenv("REDIS_QUEUE", "execution_events")
self.connection = None
self.queue_name = redis.QUEUE_NAME

async def connect(self):
if not self.connection:
self.connection = Redis(
host=self.host,
port=self.port,
password=self.password,
decode_responses=True,
)
await self.connection.ping()
logger.info(f"Connected to Redis on {self.host}:{self.port}")
def connect(self):
self.connection = redis.connect()

async def put(self, execution_result: ExecutionResult):
def put(self, execution_result: ExecutionResult):
if self.connection:
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
await self.connection.lpush(self.queue_name, message) # type: ignore
self.connection.lpush(self.queue_name, message)

async def get(self) -> ExecutionResult | None:
def get(self) -> ExecutionResult | None:
if self.connection:
message = await self.connection.rpop(self.queue_name) # type: ignore
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
return None

async def close(self):
if self.connection:
await self.connection.close()
self.connection = None
logger.info("Closed connection to Redis")
def close(self):
redis.disconnect()
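A minimal usage sketch of the now-synchronous queue, assuming a reachable Redis and an already-built `ExecutionResult` named `result` (not taken verbatim from the PR):

```python
# Usage sketch (assumes Redis is reachable and `result` is an ExecutionResult).
from backend.data.queue import RedisEventQueue

queue = RedisEventQueue()
queue.connect()      # establishes the shared backend.data.redis connection

queue.put(result)    # LPUSH the JSON-serialized ExecutionResult

event = queue.get()  # RPOP; returns an ExecutionResult or None when the queue is empty
if event is not None:
    print(event)     # a plain pydantic model; inspect or forward it

queue.close()        # releases the shared Redis connection
```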
48 changes: 48 additions & 0 deletions autogpt_platform/backend/backend/data/redis.py
@@ -0,0 +1,48 @@
import logging
import os

from dotenv import load_dotenv
from redis import Redis

from backend.util.retry import conn_retry

load_dotenv()

HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")

logger = logging.getLogger(__name__)
connection: Redis | None = None


@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
global connection
if connection:
return connection

c = Redis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
c.ping()
connection = c
return connection


@conn_retry("Redis", "Releasing connection")
def disconnect():
global connection
if connection:
connection.close()
connection = None


def get_redis() -> Redis:
if not connection:
raise RuntimeError("Redis connection is not established")
return connection
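A short sketch of how this module-level connection is meant to be used over a process's lifetime, based on the file above (the key names are illustrative):

```python
# Usage sketch for the shared module-level connection (key names are illustrative).
from backend.data import redis

redis.connect()                    # idempotent: returns the existing connection if present

r = redis.get_redis()              # raises RuntimeError if connect() was never called
r.set("healthcheck", "ok", ex=10)  # plain redis-py client with decode_responses=True

with r.lock("lock:example", timeout=60):  # the same primitive manager.py uses
    pass                                  # critical section guarded across processes

redis.disconnect()                 # closes and clears the shared connection
```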
14 changes: 9 additions & 5 deletions autogpt_platform/backend/backend/executor/manager.py
@@ -17,7 +17,7 @@
if TYPE_CHECKING:
from backend.server.rest_api import AgentServer

from backend.data import db
from backend.data import db, redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
@@ -216,12 +216,13 @@ def update_execution(status: ExecutionStatus) -> ExecutionResult:


@contextmanager
def synchronized(api_client: "AgentServer", key: Any):
api_client.acquire_lock(key)
def synchronized(key: str, timeout: int = 60):
lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
try:
lock.acquire()
yield
finally:
api_client.release_lock(key)
lock.release()


def _enqueue_next_nodes(
@@ -268,7 +269,7 @@ def register_next_executions(node_link: Link) -> list[NodeExecution]:
# Multiple nodes can register the same next node; this must be atomic
# to avoid the same execution being enqueued multiple times
# or the same input being consumed multiple times.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
@@ -437,6 +438,7 @@ def on_node_executor_start(cls):
cls.loop = asyncio.new_event_loop()
cls.pid = os.getpid()

redis.connect()
cls.loop.run_until_complete(db.connect())
cls.agent_server_client = get_agent_server_client()

@@ -454,6 +456,8 @@ def on_node_executor_stop(cls):

logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")

@classmethod
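The `synchronized` helper now leans on redis-py's `Lock` instead of the in-process `KeyedMutex` RPC, so the critical section is protected across all executor processes that share the Redis instance. A sketch of the underlying semantics (the key name is illustrative):

```python
# Sketch of the lock semantics synchronized() relies on (key name is illustrative).
from backend.data import redis

redis.connect()
lock = redis.get_redis().lock("lock:upsert_input-node-1-exec-1", timeout=60)

lock.acquire()      # blocks until the lock is free; timeout=60 means a crashed
try:                # holder cannot wedge other executors indefinitely
    pass            # critical section: upsert the execution input atomically
finally:
    lock.release()  # raises a LockError if the lock has already expired
```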
1 change: 0 additions & 1 deletion autogpt_platform/backend/backend/executor/scheduler.py
@@ -24,7 +24,6 @@ def __init__(self, refresh_interval=10):
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval
self.use_redis = False

@property
def execution_manager_client(self) -> ExecutionManager:
24 changes: 7 additions & 17 deletions autogpt_platform/backend/backend/server/rest_api.py
@@ -17,11 +17,10 @@
from backend.data import user as user_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from backend.data.queue import RedisEventQueue
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.lock import KeyedMutex
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config, Settings

@@ -32,24 +31,23 @@


class AgentServer(AppService):
mutex = KeyedMutex()
use_redis = True
use_queue = True
_test_dependency_overrides = {}
_user_credit_model = get_user_credit_model()

def __init__(self, event_queue: AsyncEventQueue | None = None):
def __init__(self):
super().__init__(port=Config().agent_server_port)
self.event_queue = event_queue or AsyncRedisEventQueue()
self.event_queue = RedisEventQueue()

@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
self.run_and_wait(self.event_queue.connect())
self.event_queue.connect()
await block.initialize_blocks()
if await user_db.create_default_user(settings.config.enable_auth):
await graph_db.import_packaged_templates()
yield
await self.event_queue.close()
self.event_queue.close()
await db.disconnect()

def run_service(self):
@@ -616,15 +614,7 @@ def get_execution_schedules(
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))

@expose
def acquire_lock(self, key: Any):
self.mutex.lock(key)

@expose
def release_lock(self, key: Any):
self.mutex.unlock(key)
self.event_queue.put(execution_result)

@classmethod
def update_configuration(
15 changes: 9 additions & 6 deletions autogpt_platform/backend/backend/server/ws_api.py
@@ -7,7 +7,7 @@
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from backend.data.queue import AsyncRedisEventQueue
from backend.data.queue import RedisEventQueue
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
@@ -20,15 +20,16 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
await event_queue.connect()
event_queue.connect()
manager = get_connection_manager()
asyncio.create_task(event_broadcaster(manager))
fut = asyncio.create_task(event_broadcaster(manager))
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
yield
await event_queue.close()
event_queue.close()


app = FastAPI(lifespan=lifespan)
event_queue = AsyncRedisEventQueue()
event_queue = RedisEventQueue()
_connection_manager = None

logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -50,9 +51,11 @@ def get_connection_manager():

async def event_broadcaster(manager: ConnectionManager):
while True:
event = await event_queue.get()
event = event_queue.get()
if event is not None:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)


async def authenticate_websocket(websocket: WebSocket) -> str:
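Because `get()` is now a synchronous redis-py call, the broadcaster polls and backs off 100 ms when the queue is empty. An illustrative variant (not in the PR) that keeps the asyncio event loop free during the Redis round-trip by offloading the call to a worker thread:

```python
# Illustrative alternative, not part of the PR: offload the synchronous
# event_queue.get() to a thread so the event loop is never blocked by the
# Redis round-trip. event_queue and ConnectionManager are the objects
# defined in ws_api.py above.
import asyncio


async def event_broadcaster(manager: ConnectionManager):
    while True:
        event = await asyncio.to_thread(event_queue.get)
        if event is not None:
            await manager.send_execution_result(event)
        else:
            await asyncio.sleep(0.1)  # still back off briefly when the queue is empty
```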