Skip to content

Commit

Permalink
mapping tool call name back to step
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Mar 7, 2025
1 parent 518818b commit 3144698
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tapeagents/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class StandardNode(Node):
structured_output: bool = False
_steps_type: Any = None
_step_classes: list[type[Step]] | None = None
_tools: dict[str, dict] | None = None
_tool_name_to_cls: dict[str, type[Step]] | None = None

def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
Expand All @@ -89,8 +91,10 @@ def prepare_step_types(self, agent: Agent):
self._step_classes = [c for c in self._step_classes if c != PythonCodeAction]
if self.structured_output:
assert len(self._step_classes) == 1, "Structured output requires exactly one output step class"
self._name_to_cls = {c.__name__: c for c in self._step_classes}
self._steps_type = Annotated[Union[tuple(self._step_classes)], Field(discriminator="kind")]
if self.use_function_calls:
self._tools = {step_cls: as_openai_tool(step_cls) for step_cls in self._step_classes}
self._tool_name_to_cls = {tool["function"]["name"]: step_cls for step_cls, tool in self._tools.items()}

def make_prompt(self, agent: Agent, tape: Tape) -> Prompt:
"""Create a prompt from tape interactions.
Expand Down Expand Up @@ -126,7 +130,7 @@ def make_prompt(self, agent: Agent, tape: Tape) -> Prompt:
self.trim_obs_except_last_n = old_trim

response_format = self._step_classes[0] if self.structured_output else None
tools = [as_openai_tool(s) for s in self._step_classes] if self.use_function_calls else None
tools = list(self._tools.values()) if self.use_function_calls else None
prompt = Prompt(messages=messages, tools=tools, response_format=response_format)
return prompt

Expand Down Expand Up @@ -284,10 +288,10 @@ def generate_steps(
yield SetNextNode(next_node=self.next_node)

def tool_call_to_step(self, tool_call: ChatCompletionMessageToolCall) -> Step:
step_cls = self._name_to_cls.get(tool_call.function.name)
step_cls = self._tool_name_to_cls.get(tool_call.function.name)
if step_cls is None:
return LLMOutputParsingFailureAction(
error=f"Unknown tool call: {tool_call.function.name}", llm_output=tool_call
error=f"Unknown tool call: {tool_call.function.name}", llm_output=tool_call.model_dump_json(indent=2)
)
args = tool_call.function.arguments
return step_cls.model_validate_json(args) if args else step_cls()
Expand Down

0 comments on commit 3144698

Please sign in to comment.