diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index a5945b68043..e37988c7819 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -17,7 +17,7 @@ from ax.core.arm import Arm from ax.core.data import Data from ax.core.formatting_utils import data_and_evaluations_from_raw_data -from ax.core.generator_run import GeneratorRun +from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.map_data import MapData from ax.core.map_metric import MapMetric from ax.core.metric import Metric, MetricFetchResult @@ -26,6 +26,7 @@ from ax.core.types import TCandidateMetadata, TEvaluationOutcome from ax.exceptions.core import UnsupportedError from ax.utils.common.base import SortableBase +from ax.utils.common.constants import Keys from pyre_extensions import none_throws @@ -33,6 +34,10 @@ # import as module to make sphinx-autodoc-typehints happy from ax import core # noqa F401 +MANUAL_GENERATION_METHOD_STR = "Manual" +UNKNOWN_GENERATION_METHOD_STR = "Unknown" +STATUS_QUO_GENERATION_METHOD_STR = "Status Quo" + def immutable_once_run(func: Callable) -> Callable: """Decorator for methods that should throw Error when @@ -659,6 +664,43 @@ def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTr "Use `trial.mark_abandoned` if applicable." ) + @property + def generation_method_str(self) -> str: + """Returns the generation method(s) used to generate this trial's arms, + as a human-readable string (e.g. 'Sobol', 'BoTorch', 'Manual', etc.). + Returns a comma-delimited string if multiple generation methods were used. + """ + # Use model key provided during warm-starting if present, since the + # generator run may not be present on warm-started trials. + if ( + warm_start_model_key := self._properties.get(Keys.WARMSTART_TRIAL_MODEL_KEY) + ) is not None: + return warm_start_model_key + + generation_methods = { + none_throws(generator_run._model_key) + for generator_run in self.generator_runs + if generator_run._model_key is not None + } + + # Add generator-run-type strings for non-ModelBridge generator runs. + gr_type_name_to_str = { + GeneratorRunType.MANUAL.name: MANUAL_GENERATION_METHOD_STR, + GeneratorRunType.STATUS_QUO.name: STATUS_QUO_GENERATION_METHOD_STR, + } + generation_methods |= { + gr_type_name_to_str[generator_run.generator_run_type] + for generator_run in self.generator_runs + if generator_run.generator_run_type in gr_type_name_to_str + } + + return ( + # Sort for deterministic output + ", ".join(sorted(generation_methods)) + if generation_methods + else UNKNOWN_GENERATION_METHOD_STR + ) + def _mark_failed_if_past_TTL(self) -> None: """If trial has TTL set and is running, check if the TTL has elapsed and mark the trial failed if so. diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 1f625735cee..951f49e0326 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -13,12 +13,16 @@ import numpy as np from ax.core.arm import Arm +from ax.core.base_trial import ( + MANUAL_GENERATION_METHOD_STR, + TrialStatus, + UNKNOWN_GENERATION_METHOD_STR, +) from ax.core.batch_trial import BatchTrial, GeneratorRunStruct from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.parameter import FixedParameter, ParameterType from ax.core.search_space import SearchSpace -from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase @@ -76,9 +80,12 @@ def test_BasicProperties(self) -> None: self.batch.generator_run_structs[0].generator_run.generator_run_type, GeneratorRunType.MANUAL.name, ) + self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR) - # Test empty arms - self.assertEqual(len(self.experiment.new_batch_trial().abandoned_arms), 0) + # Test empty trial + t = self.experiment.new_batch_trial() + self.assertEqual(len(t.abandoned_arms), 0) + self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR) def test_UndefinedSetters(self) -> None: with self.assertRaises(NotImplementedError): @@ -480,6 +487,9 @@ def test_clone_to(self, _) -> None: # test cloning with clear_trial_type=True new_batch_trial = batch.clone_to(clear_trial_type=True) self.assertIsNone(new_batch_trial.trial_type) + self.assertEqual( + new_batch_trial.generation_method_str, MANUAL_GENERATION_METHOD_STR + ) def test_Runner(self) -> None: # Verify BatchTrial without runner will fail @@ -667,9 +677,12 @@ def test_TTL(self) -> None: self.assertIn(2, self.experiment.trial_indices_by_status[TrialStatus.FAILED]) def test_get_candidate_metadata_from_all_generator_runs(self) -> None: + self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR) gr_1 = get_generator_run() gr_2 = get_generator_run2() self.batch.add_generator_run(gr_1) + self.assertEqual(self.batch.generation_method_str, "Manual, Sobol") + # Arms are named when adding GR to trial, so reassign to have a GR that has # names arms. gr_1 = self.batch._generator_run_structs[-1].generator_run @@ -719,6 +732,7 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None: cand_metadata_expected[arm.name], self.batch._get_candidate_metadata(arm.name), ) + self.assertEqual(self.batch.generation_method_str, "Manual, Sobol") def test_Sortable(self) -> None: new_batch_trial = self.experiment.new_batch_trial() diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 56e73cf87b0..5786e28395f 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -12,7 +12,12 @@ from unittest.mock import Mock, patch import pandas as pd -from ax.core.base_trial import BaseTrial, TrialStatus +from ax.core.base_trial import ( + BaseTrial, + MANUAL_GENERATION_METHOD_STR, + TrialStatus, + UNKNOWN_GENERATION_METHOD_STR, +) from ax.core.data import Data from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.runner import Runner @@ -87,10 +92,13 @@ def test_basic_properties(self) -> None: self.assertEqual( self.trial.generator_run.generator_run_type, GeneratorRunType.MANUAL.name ) + self.assertEqual(self.trial.generation_method_str, MANUAL_GENERATION_METHOD_STR) # Test empty arms + t = self.experiment.new_trial() with self.assertRaises(AttributeError): - self.experiment.new_trial().arm_weights + t.arm_weights + self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR) self.trial.mark_running(no_runner_required=True) self.assertTrue(self.trial.status.is_running) diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index f0ff54c367c..dd578ef743e 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -93,4 +93,5 @@ class Keys(str, Enum): TASK_FEATURES = "task_features" TRIAL_COMPLETION_TIMESTAMP = "trial_completion_timestamp" WARM_START_REFITTING = "warm_start_refitting" + WARMSTART_TRIAL_MODEL_KEY = "generation_model_key" X_BASELINE = "X_baseline"