Skip to content

Commit

Permalink
trying stochastic stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
TLSDC committed Mar 4, 2025
1 parent 987608e commit 3d53b89
Showing 1 changed file with 46 additions and 5 deletions.
51 changes: 46 additions & 5 deletions src/agentlab/experiments/custom_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import Optional

import benchmarks
import numpy as np
import pandas as pd
from bgym import Benchmark, EnvArgs, HighLevelActionSetArgs
from browsergym.experiments.benchmark.base import BenchmarkBackend
from browsergym.experiments.benchmark.utils import make_env_args_list_from_repeat_tasks
from dataclasses_json import DataClassJsonMixin, config
from torch import threshold

Expand Down Expand Up @@ -66,10 +68,10 @@ def select(self, values, env_args_list):
class AllTasksBenchmark(ResampleBenchmark):
    """Resampling strategy that keeps every task it is given."""

    def evaluate(self, study, env_args_list):
        """Return a placeholder score of ``0`` for each entry of ``env_args_list``.

        ``study`` is not consulted.
        """
        n_entries = len(env_args_list)
        return [0] * n_entries

    def select(self, values, env_args_list):
        """Return ``env_args_list`` untouched; ``values`` is ignored."""
        return env_args_list


@dataclass
class HighVarianceBenchmark(ResampleBenchmark):
Expand All @@ -87,8 +89,47 @@ def select(self, values, env_args_list):
return selected_env_args


@dataclass
class StochasticHighVarianceBenchmark(ResampleBenchmark):
    """Randomly re-sample tasks, favouring those whose reward varied the most.

    ``evaluate`` converts the per-task variance of ``cum_reward`` into a
    probability distribution over task names; ``select`` then draws a
    Poisson-distributed number of task seeds for each task from it.
    """

    # Added to every task's variance before normalising, so that tasks with
    # zero variance still get a non-zero sampling probability.
    regulation_threshold: float = 0.1
    # The three attributes below carry no type annotation, so they are plain
    # class attributes rather than dataclass fields.
    # Expected number of seeds across all tasks (before the per-task floor).
    total_seeds = 600
    # Lower bound on the number of seeds drawn for any single task.
    min_seeds = 2
    # Seed of the random generator used by ``select``.
    random_seed = 42

    def evaluate(self, study: Study, env_args_list):
        """Compute a sampling probability for each task of ``study``.

        Args:
            study: study whose results are loaded from ``study.dir``.
            env_args_list: unused here.

        Returns:
            A dict mapping task name to its probability. Each probability is
            proportional to ``variance + regulation_threshold``, and they sum
            to 1.
        """
        result_df = load_result_df(study.dir)
        var = result_df.groupby("env.task_name")["cum_reward"].var()
        # A task with a single run has an undefined (NaN) sample variance.
        # Treat it as zero variance so the task keeps the baseline weight
        # instead of propagating a NaN probability into the Poisson draw.
        weights = var.fillna(0.0) + self.regulation_threshold
        probs = dict(weights / weights.sum())
        return probs

    def select(self, values, env_args_list: list[EnvArgs]):
        """Draw the environment arguments to run, task by task.

        For each task, the number of seeds is drawn from a Poisson
        distribution with mean ``p * total_seeds`` and floored at
        ``min_seeds``. Since the probabilities sum to 1, about
        ``total_seeds`` runs are expected overall (more when the floor
        applies).

        Args:
            values: dict of task name to probability, as returned by
                ``evaluate``.
            env_args_list: existing environment arguments; only the
                ``max_steps`` of the first entry is reused.

        Returns:
            A list of freshly built ``EnvArgs``, one per (task, seed) pair.
        """
        max_steps = env_args_list[0].max_steps
        # One generator for the whole selection. Re-creating it from the same
        # seed for every draw would hand each task the identical random
        # stream: perfectly correlated seed counts and the same task seeds
        # everywhere. A single seeded generator stays reproducible.
        rng = np.random.RandomState(self.random_seed)
        selected_env_args = []
        for task_name, p in values.items():
            n_seeds = max(int(rng.poisson(p * self.total_seeds)), self.min_seeds)
            # The exclusive upper bound 2**32 does not fit a 32-bit integer,
            # so request 64-bit output explicitly rather than relying on the
            # platform default integer type.
            task_seeds = rng.randint(0, 2**32, size=n_seeds, dtype=np.int64)
            for seed in task_seeds:
                selected_env_args.append(
                    EnvArgs(
                        task_name=task_name,
                        task_seed=int(seed),
                        max_steps=max_steps,
                        headless=True,
                        record_video=False,
                        wait_for_user_message=False,
                        viewport=None,
                        slow_mo=None,
                        storage_state=None,
                        task_kwargs=None,
                    )
                )
        return selected_env_args


if __name__ == "__main__":
    # Manual check: build the stochastic benchmark from a previous study's
    # results directory and print the environment arguments it selected.
    # The superseded path and the HighVarianceBenchmark instance that were
    # immediately overwritten have been dropped.
    exp_dir = Path(
        "/Users/t.lesellierdechezell/agentlab_results/2025-03-04_14-43-48_genericagent-gpt-4o-mini-2024-07-18-on-miniwob"
    )
    benchmark = StochasticHighVarianceBenchmark(exp_dir=exp_dir)
    print(benchmark.env_args_list)

0 comments on commit 3d53b89

Please sign in to comment.