fix(rnd): Disable unused prisma connection on pyro API Server process (Significant-Gravitas#7641)

### Background

The Pyro thread for the API Server does not use Prisma, but it still holds a Prisma connection.
The FastAPI thread also holds a Prisma connection, leaving Prisma connected on two different event loops within a single process.
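
A minimal, self-contained sketch of why this is a problem: asyncio objects are bound to the event loop they were created on, so a client connected on one loop cannot safely be driven from another. The sketch below uses a plain `asyncio.Future` as a stand-in for the Prisma client's loop-bound state, not the actual library:

```python
import asyncio

# Create a pending Future on one event loop...
async def make_pending_future() -> asyncio.Future:
    return asyncio.get_running_loop().create_future()

fut = asyncio.run(make_pending_future())  # loop #1 (closed once asyncio.run returns)

# ...and await it from a second loop in the same process.
async def wait_for_it() -> None:
    await fut  # RuntimeError: "... got Future ... attached to a different loop"

try:
    asyncio.run(wait_for_it())  # loop #2
except RuntimeError as exc:
    print(f"cross-loop await failed: {exc}")
```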

### Changes 🏗️

Disable the Prisma connection on the Pyro thread of the API Server process.
Fix test flakiness caused by a concurrency issue.
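
A condensed sketch of the opt-out pattern behind the first change: the `AppService` base class only opens the shared Prisma connection when `use_db` is set, and the API Server subclass turns it off because FastAPI manages its own connection in its lifespan handler. The `db` object below is a stand-in for the real `autogpt_server.data.db` module; the `AppService`/`AgentServer` diffs further down show the actual change.

```python
import asyncio

class db:
    """Stand-in for the real autogpt_server.data.db module."""
    @staticmethod
    async def connect() -> None:
        print("prisma connected")

class AppService:
    """Simplified base class for long-running service processes."""
    use_db: bool = True  # services opt *out* of the DB connection; default stays on

    def run(self) -> None:
        self.shared_event_loop = asyncio.new_event_loop()
        # Only hold a Prisma connection if this service actually needs one.
        if self.use_db:
            self.shared_event_loop.run_until_complete(db.connect())

class AgentServer(AppService):
    # The Pyro-exposed server thread never queries Prisma itself;
    # FastAPI opens and closes its own connection in lifespan().
    use_db = False

class ExecutionManager(AppService):
    pass  # keeps the default and connects on startup

AgentServer().run()       # no output: no Prisma connection on the Pyro loop
ExecutionManager().run()  # prints "prisma connected"
```
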
majdyz authored Jul 30, 2024
1 parent 29ba4c2 commit 122f544
Showing 4 changed files with 44 additions and 39 deletions.
73 changes: 38 additions & 35 deletions rnd/autogpt_server/autogpt_server/executor/manager.py
@@ -160,10 +160,12 @@ def register_next_executions(node_link: Link) -> list[NodeExecution]:
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return enqueued_executions

# Upserting execution input involves reading the node's existing input pins and
# then either updating an existing execution input or creating a new one.
# While reading, we must prevent any other process from adding input to the same node.
# Multiple nodes can register the same next node, so this has to be atomic
# to avoid the same execution being enqueued multiple times
# or the same input being consumed multiple times.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):

# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
node_id=next_node_id,
@@ -173,40 +175,41 @@ def register_next_executions(node_link: Link) -> list[NodeExecution]:
)
)

# Complete missing static input pin data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := wait(get_latest_execution(next_node_id, graph_exec_id))
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)

next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = (
f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
)

if not next_node_input:
logger.warning(f"{prefix} Skipped queueing {suffix}")
return enqueued_executions

# Input is complete, enqueue the execution.
logger.warning(f"{prefix} Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)
# Complete missing static input pin data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := wait(
get_latest_execution(next_node_id, graph_exec_id)
)
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)

# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"

# Incomplete input data, skip queueing the execution.
if not next_node_input:
logger.warning(f"{prefix} Skipped queueing {suffix}")
return enqueued_executions

# Input is complete, enqueue the execution.
logger.warning(f"{prefix} Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)

if not node_link.is_static:
return enqueued_executions
# Next execution stops here if the link is not static.
if not node_link.is_static:
return enqueued_executions

# If the link is static, there could be some incomplete executions waiting for it.
# Load and complete the missing input data, and try to re-enqueue them.
# While reading, we must prevent any other process from re-enqueuing the same node.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
# If the link is static, there could be some incomplete executions waiting for it.
# Load and complete the missing input data, and try to re-enqueue them.
for iexec in wait(get_incomplete_executions(next_node_id, graph_exec_id)):
idata = iexec.input_data
ineid = iexec.node_exec_id
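
The comments added in manager.py above describe a read-modify-write that must be atomic per (node, graph execution) pair. A rough, in-process sketch of the kind of keyed locking that `synchronized(api_client, key)` provides — the real helper routes acquire/release through the AgentServer's `KeyedMutex` over the service API, whereas this stand-in keeps everything local:

```python
import threading
from contextlib import contextmanager
from typing import Hashable

class KeyedMutex:
    """One lock per key, so unrelated nodes never block each other (simplified)."""

    def __init__(self) -> None:
        self._locks: dict[Hashable, threading.Lock] = {}
        self._guard = threading.Lock()  # protects the dict itself

    def acquire(self, key: Hashable) -> None:
        with self._guard:
            lock = self._locks.setdefault(key, threading.Lock())
        lock.acquire()

    def release(self, key: Hashable) -> None:
        self._locks[key].release()

_mutex = KeyedMutex()

@contextmanager
def synchronized(key: Hashable):
    _mutex.acquire(key)
    try:
        yield
    finally:
        _mutex.release(key)

# Two executors racing to register the same next node serialize here, so the
# upsert -> validate -> enqueue sequence cannot interleave and double-enqueue.
with synchronized(("upsert_input", "next-node-id", "graph-exec-id")):
    pass  # upsert_execution_input(...), validate_exec(...), enqueue
```
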
5 changes: 3 additions & 2 deletions rnd/autogpt_server/autogpt_server/server/server.py
@@ -44,6 +44,7 @@ class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False

async def event_broadcaster(self):
while True:
@@ -53,8 +54,8 @@ async def event_broadcaster(self):
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
self.run_and_wait(block.initialize_blocks())
self.run_and_wait(graph_db.import_packaged_templates())
await block.initialize_blocks()
await graph_db.import_packaged_templates()
asyncio.create_task(self.event_broadcaster())
yield
await db.disconnect()
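
The lifespan change above is the piece that keeps FastAPI's Prisma connection on a single loop: instead of dispatching initialization through `run_and_wait` onto the service's shared Pyro loop, the coroutines are awaited directly on the loop that `db.connect()` already ran on. A hedged sketch of the difference — the `run_and_wait` below mirrors the cross-loop dispatch idea, not the project's exact implementation:

```python
import asyncio
import threading

# A long-lived second loop in another thread, like the shared Pyro loop.
shared_loop = asyncio.new_event_loop()
threading.Thread(target=shared_loop.run_forever, daemon=True).start()

async def initialize_blocks() -> str:
    return f"ran on loop {id(asyncio.get_running_loop()):#x}"

def run_and_wait(coro):
    # Cross-loop dispatch: the coroutine runs on shared_loop while the
    # caller's own loop sits blocked waiting for the result.
    return asyncio.run_coroutine_threadsafe(coro, shared_loop).result()

async def lifespan() -> None:
    print("old pattern:", run_and_wait(initialize_blocks()))  # hops to the other loop
    print("new pattern:", await initialize_blocks())          # stays on this loop

asyncio.run(lifespan())  # the two printed loop ids differ
```
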
4 changes: 3 additions & 1 deletion rnd/autogpt_server/autogpt_server/util/service.py
@@ -40,6 +40,7 @@ def run(self):

class AppService(AppProcess):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = True

@classmethod
@property
@@ -60,7 +61,8 @@ def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:

def run(self):
self.shared_event_loop = asyncio.get_event_loop()
self.shared_event_loop.run_until_complete(db.connect())
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())

# Initialize the async loop.
async_thread = threading.Thread(target=self.__start_async_loop)
1 change: 0 additions & 1 deletion rnd/autogpt_server/test/executor/test_manager.py
@@ -1,7 +1,6 @@
import pytest

from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
from autogpt_server.blocks.if_block import ComparisonOperator, ConditionBlock
from autogpt_server.blocks.maths import MathsBlock, Operation
from autogpt_server.data import execution, graph
from autogpt_server.executor import ExecutionManager
