From b732178679417aee9c2b06538aab9b92cd76c542 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 27 Feb 2025 12:37:02 -0800 Subject: [PATCH] Inference trace should be present with NaNs rather than length-zero (#3432) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3432 When inference values are not available, as is the case for MOO, MF, and MT problems, the inference trace should still be present as NaNs, for compatibility with past behavior and downstream utilities. D67775930 made its length zero when there are no recommended "best parameters." This diff: * Pulls out a function for computing the inference trace for readability * Sets the inference trace to be all NaNs when Reviewed By: Balandat Differential Revision: D70326489 fbshipit-source-id: 68b2eb8fc37ff780e76dc3608eaf1e280af796a0 --- ax/benchmark/benchmark.py | 39 ++++++++++++++++++++++----- ax/benchmark/tests/test_benchmark.py | 40 ++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 3caa2b73c30..31be1b78cae 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -20,7 +20,7 @@ """ import warnings -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Mapping, Sequence from itertools import product from logging import Logger, WARNING from time import monotonic, time @@ -282,6 +282,34 @@ def _get_oracle_trace_from_arms( return np.array(oracle_trace) +def _get_inference_trace_from_params( + best_params_list: Sequence[Mapping[str, TParamValue]], + problem: BenchmarkProblem, + n_elements: int, +) -> npt.NDArray: + """ + Get the inference value of each parameterization in ``best_params_list``. + + ``best_params_list`` can be empty, indicating that inference value is not + supported for this benchmark, in which case the returned array will be all + NaNs with length ``n_elements``. If it is not empty, it must have length + ``n_elements``. + """ + if len(best_params_list) == 0: + return np.full(n_elements, np.nan) + if len(best_params_list) != n_elements: + raise RuntimeError( + f"Expected {n_elements} elements in `best_params_list`, got " + f"{len(best_params_list)}." + ) + return np.array( + [ + _get_oracle_value_of_params(params=params, problem=problem) + for params in best_params_list + ] + ) + + def benchmark_replication( problem: BenchmarkProblem, method: BenchmarkMethod, @@ -426,11 +454,10 @@ def benchmark_replication( scheduler.summarize_final_result() - inference_trace = np.array( - [ - _get_oracle_value_of_params(params=params, problem=problem) - for params in best_params_list - ] + inference_trace = _get_inference_trace_from_params( + best_params_list=best_params_list, + problem=problem, + n_elements=len(cost_trace), ) oracle_trace = _get_oracle_trace_from_arms( evaluated_arms_list=evaluated_arms_list, problem=problem diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index ec4c207d7df..4a14f713d35 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -16,6 +16,7 @@ import numpy as np import torch from ax.benchmark.benchmark import ( + _get_inference_trace_from_params, benchmark_multiple_problems_methods, benchmark_one_method_problem, benchmark_replication, @@ -737,6 +738,10 @@ def test_replication_moo_sobol(self) -> None: ) self.assertTrue(np.all(res.score_trace <= 100)) + self.assertEqual(len(res.cost_trace), problem.num_trials) + self.assertEqual(len(res.inference_trace), problem.num_trials) + # since inference trace is not supported for MOO, it should be all NaN + self.assertTrue(np.isnan(res.inference_trace).all()) def test_benchmark_one_method_problem(self) -> None: problem = get_single_objective_benchmark_problem() @@ -1059,3 +1064,38 @@ def test_compute_baseline_value_from_sobol(self) -> None: ) # (5-0) * (5-0) self.assertEqual(result, 25) + + def test_get_inference_trace_from_params(self) -> None: + problem = get_single_objective_benchmark_problem() + with self.subTest("No params"): + n_elements = 4 + result = _get_inference_trace_from_params( + best_params_list=[], problem=problem, n_elements=n_elements + ) + self.assertEqual(len(result), n_elements) + self.assertTrue(np.isnan(result).all()) + + with self.subTest("Wrong number of params"): + n_elements = 4 + with self.assertRaisesRegex(RuntimeError, "Expected 4 elements"): + _get_inference_trace_from_params( + best_params_list=[{"x0": 0.0, "x1": 0.0}], + problem=problem, + n_elements=n_elements, + ) + + with self.subTest("Correct number of params"): + n_elements = 2 + best_params_list = [{"x0": 0.0, "x1": 0.0}, {"x0": 1.0, "x1": 1.0}] + result = _get_inference_trace_from_params( + best_params_list=best_params_list, + problem=problem, + n_elements=n_elements, + ) + self.assertEqual(len(result), n_elements) + self.assertFalse(np.isnan(result).any()) + expected_trace = [ + problem.test_function.evaluate_true(params=params).item() + for params in best_params_list + ] + self.assertEqual(result.tolist(), expected_trace)