Inference trace should be present with NaNs rather than length-zero (#3432)

Summary:
Pull Request resolved: #3432

When inference values are not available, as is the case for MOO, MF, and MT problems, the inference trace should still be present and filled with NaNs, for compatibility with past behavior and downstream utilities. D67775930 made its length zero when there are no recommended "best parameters."

This diff:
* Pulls the inference-trace computation out into a helper function, for readability
* Sets the inference trace to all NaNs when no "best parameters" are recommended (see the sketch below)
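
A minimal sketch of the intended behavior (illustrative only, not code from this diff; it assumes a four-trial run in which the method recommends no "best parameters", and all variable names are hypothetical):

import numpy as np

# Hypothetical run of 4 trials where no "best parameters" are recommended
# (e.g. a MOO problem), so no inference values are available.
n_trials = 4
best_params_list = []  # empty: inference value is not supported here

# Behavior introduced by D67775930: the trace collapses to length zero.
old_trace = np.array([float("nan") for _ in best_params_list])  # shape (0,)

# Behavior after this diff: one entry per trial, all NaN.
new_trace = np.full(n_trials, np.nan)  # shape (4,)

assert len(old_trace) == 0
assert len(new_trace) == n_trials and np.isnan(new_trace).all()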

Reviewed By: Balandat

Differential Revision: D70326489

fbshipit-source-id: 68b2eb8fc37ff780e76dc3608eaf1e280af796a0
esantorella authored and facebook-github-bot committed Feb 27, 2025
1 parent 2339fb0 commit b732178
Showing 2 changed files with 73 additions and 6 deletions.
ax/benchmark/benchmark.py: 39 changes (33 additions & 6 deletions)
@@ -20,7 +20,7 @@
 """
 
 import warnings
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from logging import Logger, WARNING
 from time import monotonic, time
@@ -282,6 +282,34 @@ def _get_oracle_trace_from_arms(
     return np.array(oracle_trace)
 
 
+def _get_inference_trace_from_params(
+    best_params_list: Sequence[Mapping[str, TParamValue]],
+    problem: BenchmarkProblem,
+    n_elements: int,
+) -> npt.NDArray:
+    """
+    Get the inference value of each parameterization in ``best_params_list``.
+
+    ``best_params_list`` can be empty, indicating that inference value is not
+    supported for this benchmark, in which case the returned array will be all
+    NaNs with length ``n_elements``. If it is not empty, it must have length
+    ``n_elements``.
+    """
+    if len(best_params_list) == 0:
+        return np.full(n_elements, np.nan)
+    if len(best_params_list) != n_elements:
+        raise RuntimeError(
+            f"Expected {n_elements} elements in `best_params_list`, got "
+            f"{len(best_params_list)}."
+        )
+    return np.array(
+        [
+            _get_oracle_value_of_params(params=params, problem=problem)
+            for params in best_params_list
+        ]
+    )
+
+
 def benchmark_replication(
     problem: BenchmarkProblem,
     method: BenchmarkMethod,
@@ -426,11 +454,10 @@ def benchmark_replication(
 
     scheduler.summarize_final_result()
 
-    inference_trace = np.array(
-        [
-            _get_oracle_value_of_params(params=params, problem=problem)
-            for params in best_params_list
-        ]
+    inference_trace = _get_inference_trace_from_params(
+        best_params_list=best_params_list,
+        problem=problem,
+        n_elements=len(cost_trace),
     )
     oracle_trace = _get_oracle_trace_from_arms(
         evaluated_arms_list=evaluated_arms_list, problem=problem
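
Why the trace length matters for the "downstream utilities" mentioned in the summary (a hedged sketch of a hypothetical consumer, not code from this repository): utilities that aggregate results typically stack per-replication traces into a single array, which only works when every trace has one entry per trial.

import numpy as np

# Hypothetical aggregation across two replications; names are illustrative.
inference_traces = [np.full(4, np.nan), np.full(4, np.nan)]  # one trace per replication
stacked = np.vstack(inference_traces)  # shape (2, 4)

# A length-zero trace in the list would make np.vstack raise a ValueError,
# because the column counts (0 vs. 4) would no longer match.
print(stacked.shape)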
ax/benchmark/tests/test_benchmark.py: 40 changes (40 additions & 0 deletions)
@@ -16,6 +16,7 @@
 import numpy as np
 import torch
 from ax.benchmark.benchmark import (
+    _get_inference_trace_from_params,
     benchmark_multiple_problems_methods,
     benchmark_one_method_problem,
     benchmark_replication,
@@ -737,6 +738,10 @@ def test_replication_moo_sobol(self) -> None:
         )
 
         self.assertTrue(np.all(res.score_trace <= 100))
+        self.assertEqual(len(res.cost_trace), problem.num_trials)
+        self.assertEqual(len(res.inference_trace), problem.num_trials)
+        # since inference trace is not supported for MOO, it should be all NaN
+        self.assertTrue(np.isnan(res.inference_trace).all())
 
     def test_benchmark_one_method_problem(self) -> None:
         problem = get_single_objective_benchmark_problem()
@@ -1059,3 +1064,38 @@ def test_compute_baseline_value_from_sobol(self) -> None:
         )
         # (5-0) * (5-0)
         self.assertEqual(result, 25)
+
+    def test_get_inference_trace_from_params(self) -> None:
+        problem = get_single_objective_benchmark_problem()
+        with self.subTest("No params"):
+            n_elements = 4
+            result = _get_inference_trace_from_params(
+                best_params_list=[], problem=problem, n_elements=n_elements
+            )
+            self.assertEqual(len(result), n_elements)
+            self.assertTrue(np.isnan(result).all())
+
+        with self.subTest("Wrong number of params"):
+            n_elements = 4
+            with self.assertRaisesRegex(RuntimeError, "Expected 4 elements"):
+                _get_inference_trace_from_params(
+                    best_params_list=[{"x0": 0.0, "x1": 0.0}],
+                    problem=problem,
+                    n_elements=n_elements,
+                )
+
+        with self.subTest("Correct number of params"):
+            n_elements = 2
+            best_params_list = [{"x0": 0.0, "x1": 0.0}, {"x0": 1.0, "x1": 1.0}]
+            result = _get_inference_trace_from_params(
+                best_params_list=best_params_list,
+                problem=problem,
+                n_elements=n_elements,
+            )
+            self.assertEqual(len(result), n_elements)
+            self.assertFalse(np.isnan(result).any())
+            expected_trace = [
+                problem.test_function.evaluate_true(params=params).item()
+                for params in best_params_list
+            ]
+            self.assertEqual(result.tolist(), expected_trace)
