adding resample benchmark objects
TLSDC committed Feb 26, 2025
1 parent 67f186d commit 987608e
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions src/agentlab/experiments/custom_benchmark.py
@@ -0,0 +1,94 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import pandas as pd
from bgym import Benchmark, EnvArgs, HighLevelActionSetArgs
from browsergym.experiments.benchmark.base import BenchmarkBackend
from dataclasses_json import DataClassJsonMixin, config

from agentlab.analyze.inspect_results import load_result_df
from agentlab.experiments.study import Study


@dataclass
class ResampleBenchmark(Benchmark):
    """Abstract benchmark that rebuilds its task list from a previous study's results."""

    exp_dir: Optional[Path] = None
    name: Optional[str] = None
    high_level_action_set_args: Optional[HighLevelActionSetArgs] = None
    is_multi_tab: Optional[bool] = None
    supports_parallel_seeds: Optional[bool] = None
    env_args_list: Optional[list[EnvArgs]] = None
    backends: Optional[list[BenchmarkBackend]] = None
    task_metadata: Optional[pd.DataFrame] = field(
        default_factory=lambda: None,
        metadata=config(
            encoder=lambda df: df.to_dict(orient="records") if df is not None else None,
            decoder=lambda items: pd.DataFrame(items) if items is not None else None,
        ),
    )
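    # Serialization note: with the encoder/decoder above, dataclasses_json
    # round-trips task_metadata as a plain list of row dicts, e.g.
    #   pd.DataFrame([{"task_name": "t1"}]) <-> [{"task_name": "t1"}]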

    def __post_init__(self):
        assert self.exp_dir is not None, "exp_dir must point to a previous study"
        study = Study.load(self.exp_dir)
        benchmark = study.benchmark

        self.name = f"resample-{benchmark.name}"
        self.high_level_action_set_args = benchmark.high_level_action_set_args
        self.is_multi_tab = benchmark.is_multi_tab
        self.supports_parallel_seeds = benchmark.supports_parallel_seeds
        self.backends = benchmark.backends
        # task_metadata is deliberately left as None; Benchmark.__post_init__
        # will recreate it for the resampled task list.

        values = self.evaluate(study, benchmark.env_args_list)
        selected_env_args = self.select(values, benchmark.env_args_list)

        if len(selected_env_args) == 0:
            raise ValueError("No env_args selected; consider loosening the selection criteria.")

        self.env_args_list = selected_env_args

        super().__post_init__()

    @abstractmethod
    def evaluate(self, study, env_args_list):
        """Compute a value per task (e.g. a dict keyed by task name) from the study's results."""

    @abstractmethod
    def select(self, values, env_args_list):
        """Return the subset of env_args to keep, based on the values from evaluate()."""


@dataclass
class AllTasksBenchmark(ResampleBenchmark):
    """Keeps every task from the original study."""

    def evaluate(self, study, env_args_list):
        return [0] * len(env_args_list)

    def select(self, values, env_args_list):
        return env_args_list


@dataclass
class HighVarianceBenchmark(ResampleBenchmark):
    """Keeps only tasks whose cum_reward standard deviation exceeds `threshold`."""

    threshold: float = 0.2

    def evaluate(self, study: Study, env_args_list):
        result_df = load_result_df(study.dir)
        return dict(result_df.groupby("env.task_name")["cum_reward"].std())

    def select(self, values, env_args_list):
        selected_env_args = []
        for env_args in env_args_list:
            if values[env_args.task_name] > self.threshold:
                selected_env_args.append(env_args)
        return selected_env_args
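
# A hypothetical further subclass (illustration only, not part of this commit):
# it reuses the same evaluate/select hooks to keep tasks the agent never solved,
# using only helpers already imported in this file.
@dataclass
class FailedTasksBenchmark(ResampleBenchmark):
    def evaluate(self, study: Study, env_args_list):
        result_df = load_result_df(study.dir)
        # highest reward each task ever achieved across seeds
        return dict(result_df.groupby("env.task_name")["cum_reward"].max())

    def select(self, values, env_args_list):
        # keep tasks whose best run still earned no reward
        return [ea for ea in env_args_list if values.get(ea.task_name, 0) == 0]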


if __name__ == "__main__":
    exp_dir = Path(
        "/home/t/agentlab_results/2025-02-26_10-15-04_genericagent-gpt-4o-mini-2024-07-18-on-miniwob-tiny-test"
    )
    benchmark = HighVarianceBenchmark(exp_dir=exp_dir)
    print(benchmark.env_args_list)
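    # Hedged usage sketch: the resampled benchmark should be usable anywhere a
    # bgym Benchmark is accepted. The Study constructor arguments below are an
    # assumption for illustration, not part of this commit:
    #   study = Study(agent_args=[my_agent_args], benchmark=benchmark)
    #   study.run(n_jobs=4)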
