fix EBTS bug where gen returned less than n total weight (#3316)
Summary:

This could happen if enough of the `weights` were zero here that the length of `weights` became less than `n`: https://www.internalfb.com/code/fbsource/[b9c286a9f21709eb4c964c3c24bb629b2b218c86]/fbcode/ax/models/discrete/thompson.py?lines=109-113.
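
A minimal sketch of the failure mode (a standalone illustration with made-up numbers, not the Ax source):

```python
# Sketch of the bug: with n=5 requested, the min_weight filter can leave
# fewer than n arms, and the old rescaling normalized the survivors to
# total len(top_weights) instead of the requested n.
n = 5
weights = [0.66, 0.25, 0.07, 0.02, 0.0]  # one arm sampled a zero weight
min_weight = 0.01
top_weights = [w for w in weights if w > min_weight]  # only 4 arms survive

old = [(x * len(top_weights)) / sum(top_weights) for x in top_weights]
assert round(sum(old), 6) == 4  # total weight 4, not the requested 5

new = [(x * n) / sum(top_weights) for x in top_weights]
assert round(sum(new), 6) == 5  # fixed: total weight equals n
```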

The `assertAlmostEqual` is needed to deal with numerical precision (e.g. 9.0000000002 rather than exactly 9.0).

This also makes `batch_size` optional in `BenchmarkMethod`. A `batch_size` of `None` means the batch size from the initial trial is used. For bandits, this means the cardinality of the search space becomes the batch size after the factorial node.
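
As a sketch of the new semantics (the helper below is hypothetical and just mirrors the updated condition in `benchmark.py`; it is not an Ax API):

```python
# With batch_size=None we can no longer assume single-arm trials, so the
# benchmark falls back to batch trials (mirrors get_benchmark_scheduler_options).
def infer_trial_type(batch_size: int | None, include_sq: bool) -> str:
    if batch_size is None or batch_size > 1 or include_sq:
        return "BATCH_TRIAL"
    return "TRIAL"

assert infer_trial_type(None, include_sq=False) == "BATCH_TRIAL"
assert infer_trial_type(1, include_sq=False) == "TRIAL"
```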

Differential Revision: D69253157
sdaulton authored and facebook-github-bot committed Feb 7, 2025
1 parent 01c6a73 commit 632333f
Showing 8 changed files with 57 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark.py
@@ -205,7 +205,7 @@ def get_benchmark_scheduler_options(
     Returns:
         ``SchedulerOptions``
     """
-    if method.batch_size > 1 or include_sq:
+    if method.batch_size is None or method.batch_size > 1 or include_sq:
         trial_type = TrialType.BATCH_TRIAL
     else:
         trial_type = TrialType.TRIAL
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_method.py
@@ -55,7 +55,7 @@ class BenchmarkMethod(Base):
     distribute_replications: bool = False
     use_model_predictions_for_best_point: bool = False

-    batch_size: int = 1
+    batch_size: int | None = 1
     run_trials_in_batches: bool = False
     max_pending_trials: int = 1
     early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
13 changes: 6 additions & 7 deletions ax/models/discrete/full_factorial.py
@@ -60,20 +60,19 @@ def gen(
         pending_observations: Sequence[Sequence[Sequence[TParamValue]]] | None = None,
         model_gen_options: TConfig | None = None,
     ) -> tuple[list[TParamValueList], list[float], TGenMetadata]:
-        if n != -1:
-            logger.warning(
-                "FullFactorialGenerator will ignore the specified value of n. "
-                "The generator automatically determines how many arms to "
-                "generate."
-            )
-
         if fixed_features:
             # Make a copy so as to not mutate it
             parameter_values = list(parameter_values)
             for fixed_feature_index, fixed_feature_value in fixed_features.items():
                 parameter_values[fixed_feature_index] = [fixed_feature_value]

         num_arms = reduce(mul, [len(values) for values in parameter_values], 1)
+        if n != num_arms:
+            logger.warning(
+                "FullFactorialGenerator will ignore the specified value of n. "
+                "The generator automatically determines how many arms to "
+                "generate."
+            )
         if self.check_cardinality and num_arms > self.max_cardinality:
             raise ValueError(
                 f"FullFactorialGenerator generated {num_arms} arms, "
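For reference, a sketch of the `num_arms` computation in the hunk above with example inputs (the parameter values are made up; passing `n=num_arms` avoids the warning):

```python
from functools import reduce
from operator import mul

# A full factorial emits one arm per element of the cross product, so the
# arm count is the product of the per-parameter value counts.
parameter_values = [[1, 2], ["foo", "bar"], [0.5, 1.0, 2.0]]
num_arms = reduce(mul, [len(values) for values in parameter_values], 1)
assert num_arms == 2 * 2 * 3 == 12
```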
9 changes: 4 additions & 5 deletions ax/models/discrete/thompson.py
@@ -87,6 +87,8 @@ def gen(
         pending_observations: Sequence[Sequence[Sequence[TParamValue]]] | None = None,
         model_gen_options: TConfig | None = None,
     ) -> tuple[list[Sequence[TParamValue]], list[float], TGenMetadata]:
+        if n <= 0:
+            raise ValueError("ThompsonSampler requires n > 0.")
         if objective_weights is None:
             raise ValueError("ThompsonSampler requires objective weights.")

@@ -104,7 +106,6 @@
             objective_weights=objective_weights, outcome_constraints=outcome_constraints
         )
         min_weight = self.min_weight if self.min_weight is not None else 2.0 / k
-
         # Second entry is used for tie-breaking
         weighted_arms = [
             (weights[i], np.random.random(), arms[i])
@@ -120,17 +121,15 @@
         )

         weighted_arms.sort(reverse=True)
-        top_weighted_arms = weighted_arms[:n] if n > 0 else weighted_arms
+        top_weighted_arms = weighted_arms[:n]
         top_arms = [arm for _, _, arm in top_weighted_arms]
         top_weights = [weight for weight, _, _ in top_weighted_arms]

         # N TS arms should have total weight N
         if self.uniform_weights:
             top_weights = [1.0 for _ in top_weights]
         else:
-            top_weights = [
-                (x * len(top_weights)) / sum(top_weights) for x in top_weights
-            ]
+            top_weights = [(x * n) / sum(top_weights) for x in top_weights]
         return (
             top_arms,
             top_weights,
4 changes: 2 additions & 2 deletions ax/models/tests/test_eb_thompson.py
@@ -75,7 +75,7 @@ def test_EmpiricalBayesThompsonSamplerGen(self) -> None:
         )
         self.assertEqual(arms, [[4, 4], [3, 3], [2, 2], [1, 1]])
         for weight, expected_weight in zip(
-            weights, [4 * i for i in [0.66, 0.25, 0.07, 0.02]]
+            weights, [5 * i for i in [0.66, 0.25, 0.07, 0.02]]
         ):
             self.assertAlmostEqual(weight, expected_weight, delta=0.1)

@@ -95,7 +95,7 @@ def test_EmpiricalBayesThompsonSamplerWarning(self) -> None:
         )
         self.assertEqual(arms, [[3, 3], [2, 2], [1, 1]])
         for weight, expected_weight in zip(
-            weights, [3 * i for i in [0.74, 0.21, 0.05]]
+            weights, [5 * i for i in [0.74, 0.21, 0.05]]
         ):
             self.assertAlmostEqual(weight, expected_weight, delta=0.1)

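A worked check of the updated test expectations (assuming the tests above request `n=5`, which the visible hunks do not show): after the fix, the returned weights must total `n` even when fewer than `n` arms survive the `min_weight` filter.

```python
# The sampled weight fractions in the first test sum to ~1.0, so rescaling
# by n / sum(...) makes the four returned weights total n == 5; the old code
# rescaled by len(top_weights) == 4, hence the 4 -> 5 change above.
fractions = [0.66, 0.25, 0.07, 0.02]
expected = [5 * f for f in fractions]  # matches the updated expectation
assert abs(sum(expected) - 5.0) < 1e-9
```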
4 changes: 2 additions & 2 deletions ax/models/tests/test_full_factorial.py
@@ -38,14 +38,14 @@ def test_FullFactorialValidation(self) -> None:
                 objective_weights=np.ones(1),
             )

-        # Raise error because n != -1
+        # Raise error because n != num_arms
         generator = FullFactorialGenerator()
         parameter_values = [[1, 2], ["foo", "bar"]]
         with self.assertLogs(
             FullFactorialGenerator.__module__, logging.WARNING
         ) as logger:
             generated_points, weights, _ = generator.gen(
-                n=5,
+                n=-1,
                 parameter_values=parameter_values,
                 objective_weights=np.ones(1),
             )
20 changes: 19 additions & 1 deletion ax/models/tests/test_thompson.py
@@ -90,6 +90,7 @@ def test_ThompsonSamplerValidation(self) -> None:
             generator.gen(5, self.parameter_values, objective_weights=None)

     def test_ThompsonSamplerMinWeight(self) -> None:
+        np.random.seed(0)
         generator = ThompsonSampler(min_weight=0.01)
         generator.fit(
             Xs=self.Xs,
@@ -99,7 +100,7 @@ def test_ThompsonSamplerMinWeight(self) -> None:
             outcome_names=self.outcome_names,
         )
         arms, weights, _ = generator.gen(
-            n=5,
+            n=3,
             parameter_values=self.parameter_values,
             objective_weights=np.ones(1),
         )
@@ -223,3 +224,20 @@ def test_ThompsonSamplerMultiObjectiveWarning(self) -> None:
             " not lead to a meaningful result.",
             str(warning_list[0].message),
         )
+
+    def test_ThompsonSamplerNonPositiveN(self) -> None:
+        generator = ThompsonSampler(min_weight=0.0)
+        generator.fit(
+            Xs=self.Xs,
+            Ys=self.Ys,
+            Yvars=self.Yvars,
+            parameter_values=self.parameter_values,
+            outcome_names=self.outcome_names,
+        )
+        for n in (-1, 0):
+            with self.assertRaisesRegex(ValueError, "ThompsonSampler requires n > 0"):
+                generator.gen(
+                    n=n,
+                    parameter_values=self.parameter_values,
+                    objective_weights=np.ones(1),
+                )
25 changes: 22 additions & 3 deletions tutorials/factorial/factorial.ipynb
@@ -319,6 +319,16 @@
     "We then generate a set of arms that covers the full space of the factorial design, including the status quo. There are three parameters, with two, three, and four values, respectively, so there are 24 possible arms."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from math import prod\n",
+    "n = prod(len(p.values) for p in search_space.parameters.values())"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -332,8 +342,10 @@
    "source": [
     "factorial = Generators.FACTORIAL(search_space=exp.search_space)\n",
     "factorial_run = factorial.gen(\n",
-    "    n=-1\n",
-    ")  # Number of arms to generate is derived from the search space.\n",
+    "    # Number of arms to generate is derived from the search space. \n",
+    "    # So any number can be passed here.\n",
+    "    n=n \n",
+    ") \n",
     "print(len(factorial_run.arms))"
    ]
   },
@@ -433,7 +445,7 @@
     "    trial.mark_completed()\n",
     "    thompson = Generators.THOMPSON(experiment=exp, data=trial.fetch_data(), min_weight=0.01)\n",
     "    models.append(thompson)\n",
-    "    thompson_run = thompson.gen(n=-1)\n",
+    "    thompson_run = thompson.gen(n=n)\n",
     "    trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(thompson_run)"
    ]
   },
@@ -626,6 +638,13 @@
     "\n",
     "render(plot_marginal_effects(models[0], \"success_metric\"))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
 ],
 "metadata": {