Skip to content

Commit

Permalink
update tape browser
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Mar 4, 2025
1 parent a7f2605 commit 0a0b435
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
26 changes: 21 additions & 5 deletions examples/gaia_agent/scripts/tape_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
from tapeagents.tape_browser import TapeBrowser

from ..eval import calculate_accuracy, get_exp_config_dict, tape_correct
from ..eval import calculate_accuracy, get_exp_config_dict, load_dataset, tape_correct
from ..steps import GaiaStep, GaiaTape

logging.basicConfig(level=logging.INFO)
Expand All @@ -25,6 +25,18 @@ def __init__(self, tapes_folder: str, renderer):

def load_tapes(self, name: str) -> list:
_, exp_dir, postfix = name.split("/", maxsplit=2)
try:
cfg_dir = os.path.join(self.tapes_folder, exp_dir)
cfg = get_exp_config_dict(cfg_dir)
try:
tasks = load_dataset(cfg["split"])
tasks_list = [task for level in tasks.values() for task in level]
self.task_id_to_num = {task["task_id"]: i + 1 for i, task in enumerate(tasks_list)}
except Exception:
self.task_id_to_num = {}
except Exception as e:
logger.exception(f"Failed to load config from {cfg_dir}: {e}")
self.task_id_to_num = {}
tapes_path = os.path.join(self.tapes_folder, exp_dir, "tapes")
image_dir = os.path.join(self.tapes_folder, exp_dir, "attachments", "images")
if not os.path.exists(image_dir):
Expand Down Expand Up @@ -98,11 +110,14 @@ def get_exp_label(self, filename: str, tapes: list[GaiaTape]) -> str:
if tape.metadata.terminated:
errors["terminated"] += 1
last_action = None
counted = set([])
for step in tape:
llm_call = self.llm_calls.get(step.metadata.prompt_id)
visible_prompt_tokens_num += llm_call.prompt_length_tokens if llm_call else 0
visible_output_tokens_num += llm_call.output_length_tokens if llm_call else 0
visible_cost += llm_call.cost if llm_call else 0
if llm_call and step.metadata.prompt_id not in counted:
counted.add(step.metadata.prompt_id)
visible_prompt_tokens_num += llm_call.prompt_length_tokens
visible_output_tokens_num += llm_call.output_length_tokens
visible_cost += llm_call.cost
if isinstance(step, Action):
actions[step.kind] += 1
last_action = step
Expand Down Expand Up @@ -168,8 +183,9 @@ def get_tape_name(self, i: int, tape: GaiaTape) -> str:
mark += f"[{error}]"
if mark:
mark += " "
n = self.task_id_to_num.get(tape.metadata.task.get("task_id"), "")
name = tape[0].content[:32] if hasattr(tape[0], "content") else tape[0].short_view()[:32]
return f"{i+1} {mark}({tokens: }t) {name}" # type: ignore
return f"{n} {mark}({tokens: }t) {name}" # type: ignore

def get_tape_label(self, tape: GaiaTape) -> str:
llm_calls_num = 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"langchain-community~=0.3",
"langchain-core~=0.1",
"levenshtein~=0.25",
"litellm~=1.51",
"litellm~=1.61",
"openai~=1.55",
"pathvalidate~=3.2",
"playwright~=1.42",
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0a0b435

Please sign in to comment.