Skip to content

Commit

Permalink
feat(platform): Support multiple credentials inputs on blocks (#8932)
Browse files Browse the repository at this point in the history
- Resolves #8930
- Depends on #8725

### Changes 🏗️

- feat(platform): Support multiple credentials inputs on blocks

Aside from `credentials`, fields within the name pattern `*_credentials`
are now also supported!

- Update docs with info on multi credentials support

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Ask @aarushik93 to test
Pwuts authored Jan 3, 2025
1 parent 84af37a commit 1375a0f
Showing 10 changed files with 207 additions and 127 deletions.
70 changes: 36 additions & 34 deletions autogpt_platform/backend/backend/data/block.py
Original file line number Diff line number Diff line change
@@ -22,10 +22,10 @@
from backend.util.settings import Config

from .model import (
CREDENTIALS_FIELD_NAME,
ContributorDetails,
Credentials,
CredentialsMetaInput,
is_credentials_field_name,
)

app_config = Config()
@@ -138,17 +138,38 @@ def get_required_fields(cls) -> set[str]:
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Only one `CredentialsMetaInput` field may be present.
- This field MUST be called `credentials`.
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)

# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}

credentials_fields = [
field_name
credentials_fields = cls.get_credentials_fields()

for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)

credentials_fields[field_name].validate_credentials_field_schema(cls)

elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)

@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
@@ -157,32 +178,7 @@ def __pydantic_init_subclass__(cls, **kwargs):
CredentialsMetaInput,
)
)
]
if len(credentials_fields) > 1:
raise ValueError(
f"{cls.__qualname__} can only have one CredentialsMetaInput field"
)
elif (
len(credentials_fields) == 1
and credentials_fields[0] != CREDENTIALS_FIELD_NAME
):
raise ValueError(
f"CredentialsMetaInput field on {cls.__qualname__} "
"must be named 'credentials'"
)
elif (
len(credentials_fields) == 0
and CREDENTIALS_FIELD_NAME in cls.model_fields.keys()
):
raise TypeError(
f"Field 'credentials' on {cls.__qualname__} "
f"must be of type {CredentialsMetaInput.__name__}"
)
if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
credentials_input_type = cast(
CredentialsMetaInput, credentials_field.annotation
)
credentials_input_type.validate_credentials_field_schema(cls)
}


BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
@@ -255,7 +251,7 @@ def __init__(
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockData | list[BlockData] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials] = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
@@ -297,10 +293,16 @@ def __init__(
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if CREDENTIALS_FIELD_NAME not in self.input_schema.model_fields:
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)

self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
44 changes: 27 additions & 17 deletions autogpt_platform/backend/backend/data/model.py
Original file line number Diff line number Diff line change
@@ -245,7 +245,8 @@ class UserIntegrations(BaseModel):
CT = TypeVar("CT", bound=CredentialsType)


CREDENTIALS_FIELD_NAME = "credentials"
def is_credentials_field_name(field_name: str) -> bool:
return field_name == "credentials" or field_name.endswith("_credentials")


class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
@@ -254,21 +255,21 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT

@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = get_args(
cls.model_fields["provider"].annotation
)
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
@classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...]:
return get_args(cls.model_fields["provider"].annotation)

model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
@classmethod
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
return get_args(cls.model_fields["type"].annotation)

@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a `credentials` field"""
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
"""Validates the schema of a credentials input field"""
field_name = next(
name for name, type in model.get_credentials_fields().items() if type is cls
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
@@ -282,11 +283,20 @@ def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
f"{field_schema}"
) from e

if (
len(schema_extra.credentials_provider) > 1
and not schema_extra.discriminator
):
raise TypeError("Multi-provider CredentialsField requires discriminator!")
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
raise TypeError(
f"Multi-provider CredentialsField '{field_name}' "
"requires discriminator!"
)

@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = cls.allowed_providers()
schema["credentials_types"] = cls.allowed_cred_types()

model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
78 changes: 41 additions & 37 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

from pydantic import BaseModel
from redis.lock import Lock as RedisLock

if TYPE_CHECKING:
@@ -20,7 +19,14 @@

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
@@ -31,7 +37,6 @@
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
@@ -170,10 +175,11 @@ def update_execution(status: ExecutionStatus) -> ExecutionResult:
# one (running) block at a time; simultaneous execution of blocks using same
# credentials is not supported.
creds_lock = None
if CREDENTIALS_FIELD_NAME in input_data:
credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
input_model = cast(type[BlockSchema], node_block.input_schema)
for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name])
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
extra_exec_kwargs["credentials"] = credentials
extra_exec_kwargs[field_name] = credentials

output_size = 0
end_status = ExecutionStatus.COMPLETED
@@ -893,41 +899,39 @@ def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")

# Find any fields of type CredentialsMetaInput
model_fields = cast(type[BaseModel], block.input_schema).model_fields
if CREDENTIALS_FIELD_NAME not in model_fields:
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue

field = model_fields[CREDENTIALS_FIELD_NAME]

# The BlockSchema class enforces that a `credentials` field is always a
# `CredentialsMetaInput`, so we can safely assume this here.
credentials_meta_type = cast(CredentialsMetaInput, field.annotation)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[CREDENTIALS_FIELD_NAME]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id}"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)


# ------- UTILITIES ------- #
Loading

0 comments on commit 1375a0f

Please sign in to comment.