Skip to content

Commit

Permalink
extracted FunctionSchema to public module tools
Browse files Browse the repository at this point in the history
  • Loading branch information
jonchun committed Feb 12, 2025
1 parent 17df4fc commit c30b05f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
32 changes: 10 additions & 22 deletions pydantic_ai_slim/pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations as _annotations

from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
from typing import TYPE_CHECKING, Any, Callable, cast, get_origin

from pydantic import ConfigDict
from pydantic._internal import _decorators, _generate_schema, _typing_extra
Expand All @@ -20,24 +20,12 @@
from ._utils import check_object_json_schema, is_model_like

if TYPE_CHECKING:
from .tools import DocstringFormat, ObjectJsonSchema
from .tools import DocstringFormat, FunctionSchema


__all__ = ('function_schema',)


class FunctionSchema(TypedDict):
"""Internal information about a function schema."""

description: str
validator: SchemaValidator
json_schema: ObjectJsonSchema
# if not None, the function takes a single by that name (besides potentially `info`)
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: str | None


def function_schema( # noqa: C901
function: Callable[..., Any],
takes_ctx: bool,
Expand Down Expand Up @@ -161,14 +149,14 @@ def function_schema( # noqa: C901
# and set it on the tool
description = json_schema.pop('description', None)

return FunctionSchema(
description=description,
validator=schema_validator,
json_schema=check_object_json_schema(json_schema),
single_arg_name=single_arg_name,
positional_fields=positional_fields,
var_positional_field=var_positional_field,
)
return {
'description': description,
'validator': schema_validator,
'json_schema': check_object_json_schema(json_schema),
'single_arg_name': single_arg_name,
'positional_fields': positional_fields,
'var_positional_field': var_positional_field,
}


def takes_ctx(function: Callable[..., Any]) -> bool:
Expand Down
15 changes: 14 additions & 1 deletion pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypedDict, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
Expand All @@ -19,6 +19,7 @@
__all__ = (
'AgentDepsT',
'DocstringFormat',
'FunctionSchema',
'RunContext',
'SystemPromptFunc',
'ToolFuncContext',
Expand All @@ -35,6 +36,18 @@
"""Type variable for agent dependencies."""


class FunctionSchema(TypedDict):
"""Internal information about a function schema."""

description: str
validator: SchemaValidator
json_schema: ObjectJsonSchema
# if not None, the function takes a single by that name (besides potentially `info`)
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: str | None


@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
"""Information about the current call."""
Expand Down

0 comments on commit c30b05f

Please sign in to comment.