Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub committed Jan 29, 2025
1 parent 77af440 commit bc41cb7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 55 deletions.
4 changes: 1 addition & 3 deletions dspy/predict/predict_with_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dspy.predict.predict import Predict
from dspy.primitives.prediction import Prediction


class PredictWithTools(Predict):
Expand Down Expand Up @@ -47,5 +46,4 @@ def __call__(self, tools=None, tool_choice="auto", **kwargs):
kwargs["tools"] = tools or self.tools
kwargs["tool_choice"] = tool_choice or self.tool_choice

pred = super().__call__(**kwargs).toDict()
return Prediction(**pred)
return super().__call__(**kwargs)
13 changes: 7 additions & 6 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ def __init__(self, signature, tools: list[Callable], max_iters=5, use_litellm_to
instruction = [f"{signature.instructions}\n"] if signature.instructions else []

instruction_custom_tool_calling = list(instruction)
instruction_native_tool_calling = list(instruction)
instruction_litellm_tool_calling = list(instruction)

instruction_custom_tool_calling.extend(
[
f"You will be given {inputs} and your goal is to finish with {outputs}.\n",
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n",
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting "
"Observation.\n",
"Thought can reason about the current situation, and Tool Name can be the following types:\n",
]
)
instruction_native_tool_calling.extend(
instruction_litellm_tool_calling.extend(
[
f"You will be given {inputs} and your goal is to finish with {outputs}.\n",
"To help you reach this goal, you will be given a list of tools, and you will need to think about "
Expand All @@ -67,11 +68,11 @@ def __init__(self, signature, tools: list[Callable], max_iters=5, use_litellm_to
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
)
native_react_signature = dspy.Signature(
{**signature.input_fields, **signature.output_fields}, "\n".join(instruction_native_tool_calling)
{**signature.input_fields, **signature.output_fields}, "\n".join(instruction_litellm_tool_calling)
).append("trajectory", dspy.InputField(), type_=str)

self.react_with_custom_tool_calling = dspy.Predict(custom_react_signature)
self.react_with_native_tool_calling = dspy.PredictWithTools(
self.react_with_litellm_tool_calling = dspy.PredictWithTools(
native_react_signature, tools=self.tools_in_litellm_format
)

Expand Down Expand Up @@ -124,7 +125,7 @@ def _forward_with_custom_tool_calling(self, **input_args):
def _forward_with_litellm_tool_calling(self, **input_args):
trajectory = {}
for idx in range(self.max_iters):
pred = self.react_with_native_tool_calling(
pred = self.react_with_litellm_tool_calling(
**input_args, trajectory=self._format_trajectory(trajectory, last_iteration=(idx == self.max_iters - 1))
)

Expand Down
6 changes: 4 additions & 2 deletions dspy/primitives/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def __init__(
desc (str): The description of the tool.
parameters (dict[str, Any]): The parameters of the tool, represented as a dictionary from parameter name to
parameter's json schema.
arg_types (dict[str, Any]): The argument types of the tool, represented as a dictionary from parameter name
to the type of the argument.
func (Callable): The actual function that is being wrapped by the tool.
"""
self.name = name
Expand All @@ -42,8 +44,8 @@ def _resolve_pydantic_schema(model: type[BaseModel]) -> dict:
"""Recursively resolve Pydantic model schema, expanding all references."""
schema = model.model_json_schema()

# If there are no definitions, return the main schema
if "$defs" not in schema:
# If there are no definitions to resolve, return the main schema
if "$defs" not in schema and "definitions" not in schema:
return schema

def resolve_refs(obj: Any) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def tool_2(query: str) -> str:

class MyModule(dspy.Module):
def __init__(self):
self.tools = [dspy.Tool(tool_1), dspy.Tool(tool_2)]
self.tools = [dspy.Tool.from_function(tool_1), dspy.Tool.from_function(tool_2)]

def forward(self, query: str) -> str:
query = self.tools[0](query)
return self.tools[1](query)
query = self.tools[0](query=query)
return self.tools[1](query=query)

module = MyModule()
result = module("query")
Expand Down
14 changes: 0 additions & 14 deletions tests/clients/test_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,6 @@ def test_lm_text_calls_are_retried_for_expected_failures(
assert len(request_logs) == expected_num_retries + 1 # 1 initial request + 1 retries


def test_tools_rejected_for_non_function_models(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

with mock.patch("dspy.clients.lm.litellm.supports_function_calling", return_value=False):
lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="chat",
)
with pytest.raises(ValueError):
lm("query", tools=[{"type": "function", "function": {}}])


@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
def test_lm_tool_calls_are_returned():
openai_lm = dspy.LM(model="openai/gpt-4o-mini")
Expand Down
96 changes: 69 additions & 27 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.predict import react
from pydantic import BaseModel

import os
import pytest

# def test_example_no_tools():
# # Create a simple dataset which the model will use with the Retrieve tool.
Expand Down Expand Up @@ -126,32 +127,6 @@
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")


def test_tool_from_function():
def foo(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_tool_from_class():
class Foo:
def __init__(self, user_id: str):
self.user_id = user_id

def foo(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(Foo("123").foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_tool_calling_with_pydantic_args():
class CalendarEvent(BaseModel):
name: str
Expand Down Expand Up @@ -228,3 +203,70 @@ class InvitationSignature(dspy.Signature):
"observation_1": "Completed.",
}
assert outputs.trajectory == expected_trajectory


@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="Openai API key is not set")
@pytest.mark.parametrize(
"use_litellm_tool_calling,test_name",
[
(True, "with_litellm_tool_calling"),
(False, "with_custom_tool_calling"),
],
)
def test_react_tool_calling(use_litellm_tool_calling: bool, test_name: str):
lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm)

class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]

def write_invitation_letter(participant_name: str, event_info: CalendarEvent):
if participant_name not in event_info.participants:
return None
return f"It's my honor to invite {participant_name} to event {event_info.name} on {event_info.date}"

class InvitationSignature(dspy.Signature):
participant_name: str = dspy.InputField(desc="The name of the participant to invite")
event_info: CalendarEvent = dspy.InputField(desc="The information about the event")
invitation_letter: str = dspy.OutputField(desc="The invitation letter to be sent to the participant")

react = dspy.ReAct(
InvitationSignature,
tools=[write_invitation_letter],
use_litellm_tool_calling=use_litellm_tool_calling,
)

outputs = react(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants=["Alice", "Bob"],
),
)
if use_litellm_tool_calling:
# Litellm tool calling returns a list of tool names and tool args
assert outputs.trajectory["tool_name_0"] == ["write_invitation_letter"]
assert outputs.trajectory["tool_args_0"] == [
{
"participant_name": "Alice",
"event_info": {
"name": "Science Fair",
"date": "Friday",
"participants": ["Alice", "Bob"],
},
}
]
else:
# Custom tool calling returns a single tool name and tool args
assert outputs.trajectory["tool_name_0"] == "write_invitation_letter"
assert outputs.trajectory["tool_args_0"] == {
"participant_name": "Alice",
"event_info": {
"name": "Science Fair",
"date": "Friday",
"participants": ["Alice", "Bob"],
},
}
15 changes: 15 additions & 0 deletions tests/primitives/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ def test_from_function():
assert tool.parameters["y"]["type"] == "string"


def test_tool_from_class():
class Foo:
def __init__(self, user_id: str):
self.user_id = user_id

def __call__(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = Tool.from_function(Foo("123"))
assert tool.name == "Foo"
assert tool.desc == "Add two numbers."
assert tool.parameters == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_from_function_with_pydantic():
tool = Tool.from_function(dummy_with_pydantic)

Expand Down

0 comments on commit bc41cb7

Please sign in to comment.