Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use native tool calling for dspy.ReAct #3921

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

chenmoneygithub
Copy link
Collaborator

@chenmoneygithub chenmoneygithub commented Jan 25, 2025

Support native tool calling in DSPy. This is a relatively big change that involves multiple modules:

  • dspy.LM: Add an explicit tools arg, and parse the LLM response's tool_calls part as a special field in the LM output.
  • Adapter: adapter is modified accordingly to handle the tool calls.
  • dspy.ReAct: In addition to the custom tool calling (as of dspy==2.5.42), we support native tool calling by allowing users to define the tools, and inside the implementation we do automatic input formatting and output parsing in order to execute the tools.
  • dspy.Tool: we improve the dspy.Tool abstraction to facilitate the tool calling process.

Custom testing/benchmarking script:

import dspy

lm_4o_mini = dspy.LM("openai/gpt-4o-mini")
lm_4o = dspy.LM("openai/gpt-4o")
dspy.configure(lm=lm_4o_mini)


import litellm

from dspy.datasets import DataLoader

litellm.cache = None

kwargs = dict(fields=("claim", "supporting_facts", "hpqa_id", "num_hops"), input_keys=("claim",))
hover = DataLoader().from_huggingface(dataset_name="hover-nlp/hover", split="train", trust_remote_code=True, **kwargs)

hpqa_ids = set()
filtered_hover = []
for x in hover:
    if x["num_hops"] == 3 and x["hpqa_id"] not in hpqa_ids:
        hpqa_ids.add(x["hpqa_id"])
        filtered_hover.append(
            dspy.Example(claim=x.claim, titles=list(set([y["key"] for y in x.supporting_facts]))).with_inputs("claim")
        )
hover = filtered_hover

trainset, devset, testset = hover[:100], hover[100:200], hover[650:]

example = trainset[0]

print("Claim:", example.claim)
print("Pages that must be retrieved:", example.titles)

DOCS = {}


def search(query: str, k: int) -> list[str]:
    results = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")(query, k=k)
    results = [x["text"] for x in results]

    for result in results:
        title, text = result.split(" | ", 1)
        DOCS[title] = text

    return results


def search_wikipedia(query: str) -> list[str]:
    """Returns top-5 results and then the titles of the top-5 to top-30 results."""

    topK = search(query, 30)
    titles, topK = [f"`{x.split(' | ')[0]}`" for x in topK[5:30]], topK[:5]
    return topK + [f"Other retrieved pages have titles: {', '.join(titles)}."]


def lookup_wikipedia(title: str) -> str:
    """Returns the text of the Wikipedia page, if it exists."""

    if title in DOCS:
        return DOCS[title]

    results = [x for x in search(title, 10) if x.startswith(title + " | ")]
    if not results:
        return f"No Wikipedia page found for title: {title}"
    return results[0]


instructions = "Find all Wikipedia titles relevant to verifying (or refuting) the claim."
signature = dspy.Signature("claim -> titles: list[str]", instructions)
tools = [dspy.Tool.from_function(search_wikipedia), dspy.Tool.from_function(lookup_wikipedia)]

react = dspy.ReAct(signature, tools=tools, max_iters=20, use_litellm_tool_calling=True)

output = react(claim="David Gregory was born in 1625.")
print(output)

dspy.inspect_history(n=2)


def top5_recall(example, pred, trace=None):
    gold_titles = example.titles
    recall = sum(x in pred.titles[:5] for x in gold_titles) / len(gold_titles)

    # If we're "bootstrapping" for optimization, return True if and only if the recall is perfect.
    if trace is not None:
        return recall >= 1.0

    # If we're just doing inference, just measure the recall.
    return recall


evaluate = dspy.Evaluate(devset=devset[:10], metric=top5_recall, num_threads=10, display_progress=True, display_table=5)


def safe_react(claim: str):
    try:
        return react(claim=claim)
    except Exception:
        return dspy.Prediction(titles=[])


evaluate(safe_react)

@chenmoneygithub chenmoneygithub marked this pull request as draft January 25, 2025 03:03
@chenmoneygithub chenmoneygithub changed the title Use native tool calling for dspy.ReAct [WIP] Use native tool calling for dspy.ReAct Jan 25, 2025
@chenmoneygithub chenmoneygithub changed the title [WIP] Use native tool calling for dspy.ReAct Use native tool calling for dspy.ReAct Jan 27, 2025
@chenmoneygithub chenmoneygithub marked this pull request as ready for review January 27, 2025 06:52
@chenmoneygithub chenmoneygithub changed the title Use native tool calling for dspy.ReAct [WIP] Use native tool calling for dspy.ReAct Jan 27, 2025
@chenmoneygithub chenmoneygithub changed the title [WIP] Use native tool calling for dspy.ReAct Use native tool calling for dspy.ReAct Jan 29, 2025
@chenmoneygithub chenmoneygithub requested a review from okhat January 29, 2025 05:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant