Skip to content

Commit

Permalink
separate script to validate submission
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Feb 26, 2025
1 parent 26fd9a4 commit 63eb469
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
22 changes: 2 additions & 20 deletions examples/gaia_agent/scripts/prepare_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import os
import sys

from examples.gaia_agent.scorer import question_scorer
from tapeagents.io import load_tapes

from ..eval import get_exp_config_dict, load_dataset, tape_correct
from ..steps import GaiaTape
from .validate_submission import validate_submission

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -57,25 +57,7 @@ def main(exp_path: str):
return submission_file


def validate_submission(exp_path: str, submission_file: str) -> None:
    """Score a submission file against the reference answers and print the accuracy.

    Args:
        exp_path: Experiment directory; its config supplies the dataset split.
        submission_file: Path to a JSON Lines file where every non-blank line is
            an object with "task_id" and "model_answer" keys.

    Tasks that have no entry in the submission are reported with a warning and
    counted as incorrect rather than aborting the run with a bare KeyError.
    """
    assert os.path.isdir(exp_path), f"Directory {exp_path} does not exist or is not a directory"
    cfg = get_exp_config_dict(exp_path)
    # load_dataset is iterated as a mapping whose values are lists of task dicts
    # (presumably keyed by difficulty level -- confirm against load_dataset).
    tasks = load_dataset(cfg["split"])
    answers = {task["task_id"]: task["Final answer"] for level_tasks in tasks.values() for task in level_tasks}

    model_answers = {}
    with open(submission_file) as f:
        for line in f:
            # A trailing newline or blank separator line is not a JSON document.
            if not line.strip():
                continue
            task = json.loads(line)
            model_answers[task["task_id"]] = task["model_answer"]

    missing = [task_id for task_id in answers if task_id not in model_answers]
    if missing:
        logging.getLogger(__name__).warning(
            "%d task(s) have no answer in %s and are counted as incorrect: %s",
            len(missing),
            submission_file,
            ", ".join(str(task_id) for task_id in missing),
        )

    correct = sum(
        int(question_scorer(model_answers[task_id], answer))
        for task_id, answer in answers.items()
        if task_id in model_answers
    )
    total = len(answers)
    # Guard the ratio so an empty dataset reports 0.000 instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print(f"Submission accuracy: {accuracy:.3f} ({correct} of {total})")


# Script entry point: requires exactly one command-line argument (the experiment
# directory), passes it to main(), and then scores the submission file main() returns.
if __name__ == "__main__":
assert len(sys.argv) == 2, "Usage: examples.gaia_agent.scripts.prepare_submission <exp_dir>"
submission_file = main(sys.argv[1])
# NOTE(review): the two calls below look like the removed and the added side of a
# single diff line (two-argument call before the change, one-argument call after).
# Presumably only the one-argument call remains in the file -- confirm.
validate_submission(sys.argv[1], submission_file)
validate_submission(submission_file)
32 changes: 32 additions & 0 deletions examples/gaia_agent/scripts/validate_submission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
import logging
import sys

from examples.gaia_agent.scorer import question_scorer

from ..eval import load_dataset

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


def validate_submission(submission_file: str, split: str = "validation") -> None:
    """Score a submission file against the reference answers and print the accuracy.

    Args:
        submission_file: Path to a JSON Lines file where every non-blank line is
            an object with "task_id" and "model_answer" keys.
        split: Dataset split handed to ``load_dataset``. Defaults to "validation".

    Tasks that have no entry in the submission are reported with a warning and
    counted as incorrect rather than aborting the run with a bare KeyError.
    """
    # load_dataset is iterated as a mapping whose values are lists of task dicts
    # (presumably keyed by difficulty level -- confirm against load_dataset).
    tasks = load_dataset(split)
    answers = {task["task_id"]: task["Final answer"] for level_tasks in tasks.values() for task in level_tasks}

    model_answers = {}
    with open(submission_file) as f:
        for line in f:
            # A trailing newline or blank separator line is not a JSON document.
            if not line.strip():
                continue
            task = json.loads(line)
            model_answers[task["task_id"]] = task["model_answer"]

    missing = [task_id for task_id in answers if task_id not in model_answers]
    if missing:
        logging.getLogger(__name__).warning(
            "%d task(s) have no answer in %s and are counted as incorrect: %s",
            len(missing),
            submission_file,
            ", ".join(str(task_id) for task_id in missing),
        )

    correct = sum(
        int(question_scorer(model_answers[task_id], answer))
        for task_id, answer in answers.items()
        if task_id in model_answers
    )
    total = len(answers)
    # Guard the ratio so an empty dataset reports 0.000 instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print(f"\nSubmission accuracy: {accuracy:.3f} ({correct} of {total})\n")


if __name__ == "__main__":
    # Exactly one argument is expected: the path of the submission file to score.
    # sys.exit (not assert) so the check survives `python -O`; the message now
    # names this script and its real argument instead of prepare_submission's.
    if len(sys.argv) != 2:
        sys.exit("Usage: examples.gaia_agent.scripts.validate_submission <submission_file>")
    validate_submission(sys.argv[1])

0 comments on commit 63eb469

Please sign in to comment.