Skip to content

Commit

Permalink
feature(backend): Add ability to execute store agents without agent o…
Browse files Browse the repository at this point in the history
…wnership (#9179)

This PR enables the execution of store agents even if they are not owned
by the user. Key changes include handling store-listed agents in the
`get_graph` logic, improving execution flow, and ensuring
version-specific handling. These updates support more flexible agent
execution.

- **Graph Retrieval:** Updated `get_graph` to check store listings for
agents not owned by the user.
- **Version Handling:** Added `graph_version` to execution methods for
consistent version-specific execution.
- **Execution Flow:** Refactored `scheduler.py`, `rest_api.py`, and
other modules for clearer logic and better maintainability.
- **Testing:** Updated `test_manager.py` and other test cases to
validate execution of store-listed agents added test for accessing graph

---------

Co-authored-by: Reinier van der Leer <[email protected]>
Co-authored-by: Zamil Majdy <[email protected]>
  • Loading branch information
3 people committed Jan 14, 2025
1 parent fe84cbe commit 1ecc073
Show file tree
Hide file tree
Showing 17 changed files with 416 additions and 53 deletions.
40 changes: 29 additions & 11 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from typing import Any, Literal, Optional, Type

import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field

Expand Down Expand Up @@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
Expand All @@ -543,21 +548,36 @@ async def get_graph(
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}

if version is not None:
where_clause["version"] = version
elif not template:
else:
where_clause["isActive"] = True

# TODO: Fix hack workaround to get adding store agents to work
if user_id is not None and not template:
where_clause["userId"] = user_id

graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return GraphModel.from_db(graph, for_export) if graph else None

# The Graph has to be owned by the user or a store listing.
if (
graph is None
or graph.userId != user_id
and not (
await StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
agentId=graph_id,
agentVersion=version or graph.version,
isDeleted=False,
StoreListing={"is": {"isApproved": True}},
)
)
)
):
return None

return GraphModel.from_db(graph, for_export)


async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
Expand Down Expand Up @@ -611,9 +631,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)

if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph

raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def add_execution(
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
graph_version: int,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
Expand Down
5 changes: 4 additions & 1 deletion autogpt_platform/backend/backend/executor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def execute_graph(**kwargs):
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
args.graph_id, args.input_data, args.user_id
graph_id=args.graph_id,
data=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
except Exception as e:
logger.exception(f"Error executing graph {args.graph_id}: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ async def webhook_ingress_generic(
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
node.graph_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
)
Expand Down
34 changes: 32 additions & 2 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import typing

import autogpt_libs.auth.models
import fastapi
import fastapi.responses
import starlette.middleware.cors
Expand All @@ -17,6 +18,7 @@
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
Expand Down Expand Up @@ -117,9 +119,24 @@ def run(self):

@staticmethod
async def test_execute_graph(
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
graph_id: str,
graph_version: int,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
return backend.server.routers.v1.execute_graph(
graph_id, graph_version, node_input, user_id
)

@staticmethod
async def test_get_graph(
graph_id: str,
graph_version: int,
user_id: str,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version
)

@staticmethod
async def test_create_graph(
Expand Down Expand Up @@ -149,5 +166,18 @@ async def test_get_graph_run_node_execution_results(
async def test_delete_graph(graph_id: str, user_id: str):
return await backend.server.routers.v1.delete_graph(graph_id, user_id)

@staticmethod
async def test_create_store_listing(
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
):
return await backend.server.v2.store.routes.create_submission(request, user_id)

@staticmethod
async def test_review_store_listing(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: autogpt_libs.auth.models.User,
):
return await backend.server.v2.store.routes.review_submission(request, user)

def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)
13 changes: 5 additions & 8 deletions autogpt_platform/backend/backend/server/routers/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,11 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)


async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
Expand All @@ -217,7 +216,6 @@ async def do_create_graph(
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
Expand All @@ -230,8 +228,6 @@ async def do_create_graph(
status_code=400, detail="Either graph or template_id must be provided."
)

graph.is_template = is_template
graph.is_active = not is_template
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)

graph = await graph_db.create_graph(graph, user_id=user_id)
Expand Down Expand Up @@ -368,12 +364,13 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
)
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
Expand Down Expand Up @@ -452,7 +449,7 @@ async def get_templates(
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
graph = await graph_db.get_graph(graph_id, version)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
Expand All @@ -466,7 +463,7 @@ async def get_template(
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)


########################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def add_agent_to_library(

# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True, user_id=user_id
agent.id, agent.version, user_id=user_id
)

if not graph:
Expand Down
95 changes: 87 additions & 8 deletions autogpt_platform/backend/backend/server/v2/store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,10 @@ async def get_store_submissions(
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
# Query submissions from database
submissions = await prisma.models.StoreSubmission.prisma().find_many(
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
where=where,
skip=skip,
take=page_size,
order=[{"date_submitted": "desc"}],
)

# Get total count for pagination
Expand Down Expand Up @@ -405,9 +408,7 @@ async def delete_store_submission(
)

# Delete the submission
await prisma.models.StoreListing.prisma().delete(
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
)
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})

logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
Expand Down Expand Up @@ -504,7 +505,15 @@ async def create_store_submission(
"subHeading": sub_heading,
}
},
}
},
include={"StoreListingVersions": True},
)

slv_id = (
listing.StoreListingVersions[0].id
if listing.StoreListingVersions is not None
and len(listing.StoreListingVersions) > 0
else None
)

logger.debug(f"Created store listing for agent {agent_id}")
Expand All @@ -521,6 +530,7 @@ async def create_store_submission(
status=prisma.enums.SubmissionStatus.PENDING,
runs=0,
rating=0.0,
store_listing_version_id=slv_id,
)

except (
Expand Down Expand Up @@ -811,9 +821,7 @@ async def get_agent(

agent = store_listing_version.Agent

graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True
)
graph = await backend.data.graph.get_graph(agent.id, agent.version)

if not graph:
raise fastapi.HTTPException(
Expand All @@ -832,3 +840,74 @@ async def get_agent(
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e


async def review_store_submission(
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
) -> prisma.models.StoreListingSubmission:
"""Review a store listing submission."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={"StoreListing": True},
)
)

if not store_listing_version or not store_listing_version.StoreListing:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)

status = (
prisma.enums.SubmissionStatus.APPROVED
if is_approved
else prisma.enums.SubmissionStatus.REJECTED
)

create_data = prisma.types.StoreListingSubmissionCreateInput(
StoreListingVersion={"connect": {"id": store_listing_version_id}},
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
createdAt=datetime.now(),
updatedAt=datetime.now(),
)

update_data = prisma.types.StoreListingSubmissionUpdateInput(
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
updatedAt=datetime.now(),
)

if is_approved:
await prisma.models.StoreListing.prisma().update(
where={"id": store_listing_version.StoreListing.id},
data={"isApproved": True},
)

submission = await prisma.models.StoreListingSubmission.prisma().upsert(
where={"storeListingVersionId": store_listing_version_id},
data=prisma.types.StoreListingSubmissionUpsertInput(
create=create_data,
update=update_data,
),
)

if not submission:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing submission {store_listing_version_id} not found",
)

return submission

except Exception as e:
logger.error(f"Error reviewing store submission: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to review store submission"
) from e
7 changes: 7 additions & 0 deletions autogpt_platform/backend/backend/server/v2/store/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class StoreSubmission(pydantic.BaseModel):
status: prisma.enums.SubmissionStatus
runs: int
rating: float
store_listing_version_id: str | None = None


class StoreSubmissionsResponse(pydantic.BaseModel):
Expand Down Expand Up @@ -151,3 +152,9 @@ class StoreReviewCreate(pydantic.BaseModel):
store_listing_version_id: str
score: int
comments: str | None = None


class ReviewSubmissionRequest(pydantic.BaseModel):
store_listing_version_id: str
isApproved: bool
comments: str
Loading

0 comments on commit 1ecc073

Please sign in to comment.