From ea299ac28bb4d7a18675ca28a23e6ce93f26b32f Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Fri, 7 Mar 2025 18:17:05 +0100 Subject: [PATCH] add few control flow and simple prompting nodes --- tapeagents/nodes.py | 51 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/tapeagents/nodes.py b/tapeagents/nodes.py index 54be1928..1ea47cd0 100644 --- a/tapeagents/nodes.py +++ b/tapeagents/nodes.py @@ -5,7 +5,7 @@ import json import logging import re -from typing import Annotated, Any, Generator, Type, Union +from typing import Annotated, Any, Callable, Generator, Type, Union from litellm import ChatCompletionMessageToolCall from pydantic import Field, TypeAdapter, ValidationError @@ -28,7 +28,7 @@ from tapeagents.steps import BranchStep, ReasoningThought from tapeagents.tool_calling import as_openai_tool from tapeagents.tools.code_executor import PythonCodeAction -from tapeagents.utils import FatalError, class_for_name, sanitize_json_completion +from tapeagents.utils import FatalError, class_for_name, sanitize_json_completion, step_schema from tapeagents.view import Call, Respond, TapeViewStack logger = logging.getLogger(__name__) @@ -76,11 +76,10 @@ class StandardNode(Node): _step_classes: list[type[Step]] | None = None def model_post_init(self, __context: Any) -> None: - self.prepare_step_types() super().model_post_init(__context) - def prepare_step_types(self, actions: list[type[Step]] = None): - actions = actions or [] + def prepare_step_types(self, agent: Agent): + actions = agent.known_actions if self.use_known_actions else [] step_classes_or_str = actions + (self.steps if isinstance(self.steps, list) else [self.steps]) if not step_classes_or_str: return @@ -93,11 +92,7 @@ def prepare_step_types(self, actions: list[type[Step]] = None): self._name_to_cls = {c.__name__: c for c in self._step_classes} self._steps_type = Annotated[Union[tuple(self._step_classes)], Field(discriminator="kind")] - def add_known_actions(self, actions: list[type[Step]]): - if self.use_known_actions: - self.prepare_step_types(actions) - - def make_prompt(self, agent: Any, tape: Tape) -> Prompt: + def make_prompt(self, agent: Agent, tape: Tape) -> Prompt: """Create a prompt from tape interactions. This method constructs a prompt by processing the tape content and agent steps description @@ -120,6 +115,7 @@ def make_prompt(self, agent: Any, tape: Tape) -> Prompt: 4. Checks token count and trims if needed 5. Reconstructs messages if trimming occurred """ + self.prepare_step_types(agent) steps = self.get_steps(tape, agent) steps_description = self.get_steps_description(agent) messages = self.steps_to_messages(steps, steps_description) @@ -415,6 +411,29 @@ def trim_tape(self, tape: Tape) -> Tape: return tape +class ViewNode(StandardNode): + system_prompt: str + view: Any = None + prompt: str + + def get_steps(self, tape: Tape, agent: Agent) -> list[Step]: + view_cls = class_for_name(self.view) + kwargs = view_cls(tape).as_dict() if self.view else {} + content = self.prompt.format(**kwargs) + return [UserStep(content=content)] + + +class AsStep(StandardNode): + def make_prompt(self, agent: Agent, tape: Tape) -> Prompt: + text = tape[-1].reasoning + schema = step_schema(self._step_classes[0]) + response_format = self._step_classes[0] if self.structured_output else None + msg = f"Convert the following paragraph into a structured JSON object:\n\n{text}" + if not self.structured_output: + msg += f"\n\nThe JSON object should match the following schema:\n\n{schema}" + return Prompt(messages=[{"role": "user", "content": msg}], response_format=response_format) + + class ControlFlowNode(Node): """ A node that controls the flow of execution by selecting the next node based on tape content. @@ -478,6 +497,14 @@ def select_node(self, tape: Tape) -> str: return self.next_node if isinstance(tape[-1], self.step_class) else None +class If(ControlFlowNode): + predicate: Callable[[Tape], bool] + next_node: str + + def select_node(self, tape: Tape) -> str: + return self.next_node if self.predicate(tape) else None + + class ObservationControlNode(ControlFlowNode): """ A control flow node that selects the next node based on the last observation in the tape. @@ -548,6 +575,10 @@ def generate_steps( yield step +class Return(FixedStepsNode): + steps: list[Step] = [Respond(copy_output=True)] + + class GoTo(Node): next_node: str