Skip to content

Commit

Permalink
improve dspy.Tool
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub committed Jan 29, 2025
1 parent 72c342b commit 1975352
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 39 deletions.
2 changes: 1 addition & 1 deletion dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5, use_litellm_to
self.signature = signature = ensure_signature(signature)
self.max_iters = max_iters
self.use_litellm_tool_calling = use_litellm_tool_calling
tools = [t if isinstance(t, Tool) else Tool.from_function(t) for t in tools]
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]

self.tools_in_litellm_format = [tool.convert_to_litellm_tool_format() for tool in tools]
self.tools = {tool.name: tool for tool in tools}
Expand Down
60 changes: 32 additions & 28 deletions dspy/primitives/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, get_origin, get_type_hints
from typing import Any, Callable, Optional, get_origin, get_type_hints

from jsonschema import ValidationError, validate
from pydantic import BaseModel, TypeAdapter
Expand All @@ -16,31 +16,36 @@ class Tool:

def __init__(
self,
name: str = None,
desc: str = None,
parameters: dict[str, Any] = None,
arg_types: dict[str, Any] = None,
func: Callable = None,
func: Callable,
name: Optional[str] = None,
desc: Optional[str] = None,
parameters: Optional[dict[str, Any]] = None,
arg_types: Optional[dict[str, Any]] = None,
parameter_desc: Optional[dict[str, str]] = None,
):
"""Initialize the Tool class.
Args:
name (str): The name of the tool.
desc (str): The description of the tool.
parameters (dict[str, Any]): The parameters of the tool, represented as a dictionary from parameter name to
parameter's json schema.
arg_types (dict[str, Any]): The argument types of the tool, represented as a dictionary from parameter name
to the type of the argument.
func (Callable): The actual function that is being wrapped by the tool.
name (Optional[str], optional): The name of the tool. Defaults to None.
desc (Optional[str], optional): The description of the tool. Defaults to None.
parameters (Optional[dict[str, Any]], optional): The parameters of the tool, represented as a dictionary
from parameter name to parameter's json schema. Defaults to None.
arg_types (Optional[dict[str, Any]], optional): The argument types of the tool, represented as a dictionary
from parameter name to the type of the argument. Defaults to None.
parameter_desc (Optional[dict[str, str]], optional): Descriptions for each parameter, represented as a
dictionary from parameter name to description string. Defaults to None.
"""
self.func = func
self.name = name
self.desc = desc
self.parameters = parameters or {}
self.arg_types = arg_types or {}
self.func = func
self.parameters = parameters
self.arg_types = arg_types
self.parameter_desc = parameter_desc

@staticmethod
def _resolve_pydantic_schema(model: type[BaseModel]) -> dict:
self._parse_function(func, parameter_desc)

def _resolve_pydantic_schema(self, model: type[BaseModel]) -> dict:
"""Recursively resolve Pydantic model schema, expanding all references."""
schema = model.model_json_schema()

Expand All @@ -67,18 +72,11 @@ def resolve_refs(obj: Any) -> Any:
resolved_schema.pop("$defs", None)
return resolved_schema

@classmethod
def from_function(cls, func: Callable):
"""Class method that converts a python function to a `Tool`.
def _parse_function(self, func: Callable, parameter_desc: dict[str, str] = None):
"""Helper method that parses a function to extract the name, description, and parameters.
This is a helper function that automatically infers the name, description, and parameters of the tool from the
provided function. In order to make the inference work, the function must have valid type hints.
Args:
func (Callable): The function to be wrapped by the tool.
Returns:
Tool: The tool object.
"""
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
name = getattr(func, "__name__", type(func).__name__)
Expand All @@ -89,13 +87,19 @@ def from_function(cls, func: Callable):
arg_types[k] = v
if k == "return":
continue

if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel):
schema = cls._resolve_pydantic_schema(origin)
schema = self._resolve_pydantic_schema(origin)
parameters[k] = schema
else:
parameters[k] = TypeAdapter(v).json_schema()
if parameter_desc and k in parameter_desc:
parameters[k]["description"] = parameter_desc[k]

return cls(name=name, desc=desc, parameters=parameters, arg_types=arg_types, func=func)
self.name = self.name or name
self.desc = self.desc or desc
self.parameters = self.parameters or parameters
self.arg_types = self.arg_types or arg_types

def convert_to_litellm_tool_format(self):
"""Converts the tool to the format required by litellm for tool calling."""
Expand Down
27 changes: 17 additions & 10 deletions tests/primitives/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def test_basic_initialization():
assert callable(tool.func)


def test_from_function():
tool = Tool.from_function(dummy_function)
def test_tool_from_function():
tool = Tool(dummy_function)

assert tool.name == "dummy_function"
assert "A dummy function for testing" in tool.desc
Expand All @@ -95,14 +95,14 @@ def __call__(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = Tool.from_function(Foo("123"))
tool = Tool(Foo("123"))
assert tool.name == "Foo"
assert tool.desc == "Add two numbers."
assert tool.parameters == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_from_function_with_pydantic():
tool = Tool.from_function(dummy_with_pydantic)
def test_tool_from_function_with_pydantic():
tool = Tool(dummy_with_pydantic)

assert tool.name == "dummy_with_pydantic"
assert "model" in tool.parameters
Expand All @@ -112,7 +112,7 @@ def test_from_function_with_pydantic():


def test_convert_to_litellm_tool_format():
tool = Tool.from_function(dummy_function)
tool = Tool(dummy_function)
litellm_format = tool.convert_to_litellm_tool_format()

assert litellm_format["type"] == "function"
Expand All @@ -123,32 +123,38 @@ def test_convert_to_litellm_tool_format():


def test_tool_callable():
tool = Tool.from_function(dummy_function)
tool = Tool(dummy_function)
result = tool(x=42, y="hello")
assert result == "hello 42"


def test_tool_with_pydantic_callable():
tool = Tool.from_function(dummy_with_pydantic)
tool = Tool(dummy_with_pydantic)
model = DummyModel(field1="test", field2=123)
result = tool(model=model)
assert result == "test 123"


def test_invalid_function_call():
tool = Tool.from_function(dummy_function)
tool = Tool(dummy_function)
with pytest.raises(ValueError):
tool(x="not an integer", y="hello")


def test_parameter_desc():
tool = Tool(dummy_function, parameter_desc={"x": "The x parameter"})
assert tool.parameters["x"]["description"] == "The x parameter"


def test_complex_nested_schema():
tool = Tool.from_function(complex_dummy_function)
tool = Tool(complex_dummy_function, parameter_desc={"profile": "The user ultimate profile"})

assert tool.name == "complex_dummy_function"
assert "profile" in tool.parameters

profile_schema = tool.parameters["profile"]
assert profile_schema["type"] == "object"
assert profile_schema["description"] == "The user ultimate profile"

# Check nested structure
properties = profile_schema["properties"]
Expand Down Expand Up @@ -214,6 +220,7 @@ def test_complex_nested_schema():
},
"required": ["user_id", "name", "contact"],
"title": "UserProfile",
"description": "The user ultimate profile",
"type": "object",
},
"priority": {"type": "integer"},
Expand Down

0 comments on commit 1975352

Please sign in to comment.