Skip to content

Commit

Permalink
add few control flow and simple prompting nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Mar 7, 2025
1 parent 7378e4c commit ea299ac
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions tapeagents/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import logging
import re
from typing import Annotated, Any, Generator, Type, Union
from typing import Annotated, Any, Callable, Generator, Type, Union

from litellm import ChatCompletionMessageToolCall
from pydantic import Field, TypeAdapter, ValidationError
Expand All @@ -28,7 +28,7 @@
from tapeagents.steps import BranchStep, ReasoningThought
from tapeagents.tool_calling import as_openai_tool
from tapeagents.tools.code_executor import PythonCodeAction
from tapeagents.utils import FatalError, class_for_name, sanitize_json_completion
from tapeagents.utils import FatalError, class_for_name, sanitize_json_completion, step_schema
from tapeagents.view import Call, Respond, TapeViewStack

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,11 +76,10 @@ class StandardNode(Node):
_step_classes: list[type[Step]] | None = None

def model_post_init(self, __context: Any) -> None:
self.prepare_step_types()
super().model_post_init(__context)

def prepare_step_types(self, actions: list[type[Step]] = None):
actions = actions or []
def prepare_step_types(self, agent: Agent):
actions = agent.known_actions if self.use_known_actions else []
step_classes_or_str = actions + (self.steps if isinstance(self.steps, list) else [self.steps])
if not step_classes_or_str:
return
Expand All @@ -93,11 +92,7 @@ def prepare_step_types(self, actions: list[type[Step]] = None):
self._name_to_cls = {c.__name__: c for c in self._step_classes}
self._steps_type = Annotated[Union[tuple(self._step_classes)], Field(discriminator="kind")]

def add_known_actions(self, actions: list[type[Step]]):
if self.use_known_actions:
self.prepare_step_types(actions)

def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
def make_prompt(self, agent: Agent, tape: Tape) -> Prompt:
"""Create a prompt from tape interactions.
This method constructs a prompt by processing the tape content and agent steps description
Expand All @@ -120,6 +115,7 @@ def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
4. Checks token count and trims if needed
5. Reconstructs messages if trimming occurred
"""
self.prepare_step_types(agent)
steps = self.get_steps(tape, agent)
steps_description = self.get_steps_description(agent)
messages = self.steps_to_messages(steps, steps_description)
Expand Down Expand Up @@ -415,6 +411,29 @@ def trim_tape(self, tape: Tape) -> Tape:
return tape


class ViewNode(StandardNode):
system_prompt: str
view: Any = None
prompt: str

def get_steps(self, tape: Tape, agent: Agent) -> list[Step]:
view_cls = class_for_name(self.view)
kwargs = view_cls(tape).as_dict() if self.view else {}
content = self.prompt.format(**kwargs)
return [UserStep(content=content)]


class AsStep(StandardNode):
def make_prompt(self, agent: Agent, tape: Tape) -> Prompt:
text = tape[-1].reasoning
schema = step_schema(self._step_classes[0])
response_format = self._step_classes[0] if self.structured_output else None
msg = f"Convert the following paragraph into a structured JSON object:\n\n{text}"
if not self.structured_output:
msg += f"\n\nThe JSON object should match the following schema:\n\n{schema}"
return Prompt(messages=[{"role": "user", "content": msg}], response_format=response_format)


class ControlFlowNode(Node):
"""
A node that controls the flow of execution by selecting the next node based on tape content.
Expand Down Expand Up @@ -478,6 +497,14 @@ def select_node(self, tape: Tape) -> str:
return self.next_node if isinstance(tape[-1], self.step_class) else None


class If(ControlFlowNode):
predicate: Callable[[Tape], bool]
next_node: str

def select_node(self, tape: Tape) -> str:
return self.next_node if self.predicate(tape) else None


class ObservationControlNode(ControlFlowNode):
"""
A control flow node that selects the next node based on the last observation in the tape.
Expand Down Expand Up @@ -548,6 +575,10 @@ def generate_steps(
yield step


class Return(FixedStepsNode):
steps: list[Step] = [Respond(copy_output=True)]


class GoTo(Node):
next_node: str

Expand Down

0 comments on commit ea299ac

Please sign in to comment.