From 3d53b8907abfa287c56c222123bb5cbdfc7ba75d Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Tue, 4 Mar 2025 16:02:36 -0500
Subject: [PATCH] trying stochastic stuff

---
 src/agentlab/experiments/custom_benchmark.py | 88 ++++++++++++++++++--
 1 file changed, 83 insertions(+), 5 deletions(-)

diff --git a/src/agentlab/experiments/custom_benchmark.py b/src/agentlab/experiments/custom_benchmark.py
index 6fc79dcf..35561beb 100644
--- a/src/agentlab/experiments/custom_benchmark.py
+++ b/src/agentlab/experiments/custom_benchmark.py
@@ -4,9 +4,11 @@ from typing import Optional
 
 
 import benchmarks
+import numpy as np
 import pandas as pd
 from bgym import Benchmark, EnvArgs, HighLevelActionSetArgs
 from browsergym.experiments.benchmark.base import BenchmarkBackend
+from browsergym.experiments.benchmark.utils import make_env_args_list_from_repeat_tasks
 from dataclasses_json import DataClassJsonMixin, config
 from torch import threshold
 
@@ -66,10 +68,10 @@ def select(self, values, env_args_list):
 class AllTasksBenchmark(ResampleBenchmark):
     def evaluate(self, study, env_args_list):
         return [0] * len(env_args_list)
-    
+
     def select(self, values, env_args_list):
         return env_args_list
-    
+
 
 @dataclass
 class HighVarianceBenchmark(ResampleBenchmark):
@@ -87,8 +89,84 @@ def select(self, values, env_args_list):
         return selected_env_args
 
 
+@dataclass
+class StochasticHighVarianceBenchmark(ResampleBenchmark):
+    """Resample tasks at random, favouring those whose reward varied most.
+
+    `evaluate` turns the per-task variance of `cum_reward` in a finished study into
+    sampling probabilities; `select` then draws a Poisson-distributed number of
+    fresh task seeds for each task from those probabilities.
+
+    Attributes:
+        regulation_threshold: Constant added to every task's variance before
+            normalising, so zero-variance tasks keep a non-zero probability.
+        total_seeds: Expected number of seeds across all tasks; each task's
+            Poisson rate is its probability times this value.
+        min_seeds: Lower bound on the number of seeds drawn for any task.
+        random_seed: Seed of the generator used by `select`, making the
+            selection reproducible.
+    """
+
+    regulation_threshold: float = 0.1
+    total_seeds: int = 600
+    min_seeds: int = 2
+    random_seed: int = 42
+
+    def evaluate(self, study: Study, env_args_list):
+        """Return a `{task_name: probability}` dict built from the study's results.
+
+        Probabilities are proportional to `variance + regulation_threshold` and
+        sum to 1. `env_args_list` is not used here.
+        """
+        result_df = load_result_df(study.dir)
+        # The sample variance of a task with a single run is NaN; count it as zero so
+        # the task keeps the regulation floor instead of making `poisson` raise later.
+        var = result_df.groupby("env.task_name")["cum_reward"].var().fillna(0.0)
+        weights = var + self.regulation_threshold
+        return (weights / weights.sum()).to_dict()
+
+    def select(self, values, env_args_list: list[EnvArgs]):
+        """Build the resampled list of `EnvArgs` from the probabilities in `values`.
+
+        Args:
+            values: Mapping of task name to probability, as returned by `evaluate`.
+            env_args_list: Original experiment arguments; only `max_steps` of the
+                first entry is read, so the list must not be empty.
+
+        Returns:
+            One `EnvArgs` per drawn (task, seed) pair, at least `min_seeds` per task.
+        """
+        # Build the generator once. Re-seeding it for every task and every draw made
+        # all tasks consume the same random stream: the Poisson counts were not
+        # independent and every task received the same prefix of task seeds.
+        rng = np.random.RandomState(self.random_seed)
+        max_steps = env_args_list[0].max_steps
+        selected_env_args = []
+        for task_name, p in values.items():
+            n_seeds = max(int(rng.poisson(p * self.total_seeds)), self.min_seeds)
+            # Explicit 64-bit dtype: the platform default integer can be 32-bit, for
+            # which an upper bound of 2**32 is rejected.
+            for seed in rng.randint(0, 2**32, n_seeds, dtype=np.int64):
+                selected_env_args.append(
+                    EnvArgs(
+                        task_name=task_name,
+                        task_seed=int(seed),
+                        max_steps=max_steps,
+                        headless=True,
+                        record_video=False,
+                        wait_for_user_message=False,
+                        viewport=None,
+                        slow_mo=None,
+                        storage_state=None,
+                        task_kwargs=None,
+                    )
+                )
+        return selected_env_args
+
+
 if __name__ == "__main__":
-    exp_dir = Path("/home/t/agentlab_results/2025-02-26_10-15-04_genericagent-gpt-4o-mini-2024-07-18-on-miniwob-tiny-test")
-    benchmark = HighVarianceBenchmark(exp_dir=exp_dir)
+    exp_dir = Path(
+        "/Users/t.lesellierdechezell/agentlab_results/2025-03-04_14-43-48_genericagent-gpt-4o-mini-2024-07-18-on-miniwob"
+    )
+    benchmark = StochasticHighVarianceBenchmark(exp_dir=exp_dir)
     print(benchmark.env_args_list)
-