From 7c139efb1dbe16f5fb82c57efed0f5283c6f2674 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Fri, 14 Feb 2025 09:17:16 -0800 Subject: [PATCH] only consider trials with data as target trial (#3351) Summary: See title. This ensures the target trial is one with data. Reviewed By: mgarrard Differential Revision: D69477011 --- ax/core/utils.py | 66 +++-- .../generation_node_input_constructors.py | 6 +- ...test_generation_node_input_constructors.py | 228 ++++++++---------- .../tests/test_generation_strategy.py | 3 +- ax/utils/common/mock.py | 7 +- 5 files changed, 159 insertions(+), 151 deletions(-) diff --git a/ax/core/utils.py b/ax/core/utils.py index ec97811769b..9062fafec32 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -25,7 +25,7 @@ from ax.core.trial import Trial from ax.core.types import ComparisonOp from ax.utils.common.constants import Keys -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws TArmTrial = tuple[str, int] @@ -431,8 +431,11 @@ def extend_pending_observations( def get_target_trial_index(experiment: Experiment) -> int | None: """Get the index of the target trial in the ``Experiment``. - Find the target trial giving priority in the following order: - 1. a running long-run trial + Find the target trial (among those with data) giving priority in the following + order: + 1. a running long-run trial. Note if there is a running long-run trial on the + experiment without data, or if there is no data on the experiment, then + this will return None. 2. Most recent trial expecting data with running trials be considered the most recent @@ -450,47 +453,60 @@ def get_target_trial_index(experiment: Experiment) -> int | None: # TODO: @mgarrard improve logic to include trial_obsolete_threshold that # takes into account the age of the trial, and consider more heavily weighting # long run trials. + df = experiment.lookup_data().df + if df.empty: + return None + trial_indices_with_data = set(df.trial_index.unique()) + # only consider running trials with data running_trials = [ - assert_is_instance(trial, BatchTrial) + trial for trial in experiment.trials_by_status[TrialStatus.RUNNING] + if trial.index in trial_indices_with_data ] sorted_running_trials = _sort_trials(trials=running_trials, trials_are_running=True) # Priority 1: Any running long-run trial - target_trial_idx = next( - ( - trial.index - for trial in sorted_running_trials - if trial.trial_type == Keys.LONG_RUN - ), - None, + has_running_long_run_trial = any( + trial.trial_type == Keys.LONG_RUN + for trial in experiment.trials_by_status[TrialStatus.RUNNING] ) - if target_trial_idx is not None: - return target_trial_idx + if has_running_long_run_trial: + # This returns a running long-run trial with data or None + # if there are running long-run trials on the experiment, but + # no data for that trial + return next( + ( + trial.index + for trial in sorted_running_trials + if trial.trial_type == Keys.LONG_RUN + ), + None, + ) - # Priority 2: longest running currently running trial + # Priority 2: longest running currently running trial with data if len(sorted_running_trials) > 0: return sorted_running_trials[0].index - # Priortiy 3: the longest running trial expecting data, discounting running trials + # Priortiy 3: the longest running trial with data, discounting running trials # as we handled those above - trials_expecting_data = [ - assert_is_instance(trial, BatchTrial) - for trial in experiment.trials_expecting_data - if trial.status != TrialStatus.RUNNING + non_running_trial_indices_with_data = trial_indices_with_data - { + t.index for t in running_trials + } + non_running_trials_with_data = [ + experiment.trials[i] for i in non_running_trial_indices_with_data ] - sorted_trials_expecting_data = _sort_trials( - trials=trials_expecting_data, trials_are_running=False + sorted_non_running_trials_with_data = _sort_trials( + trials=non_running_trials_with_data, trials_are_running=False ) - if len(sorted_trials_expecting_data) > 0: - return sorted_trials_expecting_data[0].index + if len(sorted_non_running_trials_with_data) > 0: + return sorted_non_running_trials_with_data[0].index return None def _sort_trials( - trials: list[BatchTrial], + trials: list[BaseTrial], trials_are_running: bool, -) -> list[BatchTrial]: +) -> list[BaseTrial]: """Sort a list of trials by (1) duration of trial, (2) number of arms in trial. Args: diff --git a/ax/generation_strategy/generation_node_input_constructors.py b/ax/generation_strategy/generation_node_input_constructors.py index 87255921d84..96f5d8fb942 100644 --- a/ax/generation_strategy/generation_node_input_constructors.py +++ b/ax/generation_strategy/generation_node_input_constructors.py @@ -103,8 +103,10 @@ def get_status_quo( raise AxGenerationException( f"Attempting to construct status quo input into {next_node} but couldn't " "identify the target trial. Often this could be due to no trials on the " - f"experiment that are in status {STATUSES_EXPECTING_DATA}. The trials on " - f"this experiment are: {experiment.trials}." + f"experiment that are in status {STATUSES_EXPECTING_DATA} " + f"and have data. The trials on this experiment are: " + f"{experiment.trials} and trials with data are: " + f"{experiment.lookup_data().df.trial_index.unique()}." ) if experiment.status_quo is None: raise AxGenerationException( diff --git a/ax/generation_strategy/tests/test_generation_node_input_constructors.py b/ax/generation_strategy/tests/test_generation_node_input_constructors.py index 1a55c284e17..d4f04f697b8 100644 --- a/ax/generation_strategy/tests/test_generation_node_input_constructors.py +++ b/ax/generation_strategy/tests/test_generation_node_input_constructors.py @@ -281,18 +281,15 @@ def test_no_n_provided_remaining_n_with_exp_prop(self) -> None: self.assertEqual(num_to_gen, 4) def test_set_target_trial_long_run_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=3, - ) + for num_arms, trial_type in zip((1, 3), (Keys.LONG_RUN, Keys.SHORT_RUN)): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=trial_type, + complete=False, + num_arms=num_arms, + with_status_quo=True, + ) + self.experiment.fetch_data() target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -314,6 +311,7 @@ def test_status_quo_features_no_sq(self) -> None: complete=False, num_arms=1, ) + self.experiment.fetch_data() with self.assertRaisesRegex( AxGenerationException, "experiment has no status quo", @@ -326,20 +324,15 @@ def test_status_quo_features_no_sq(self) -> None: ) def test_status_quo_features(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - with_status_quo=True, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=3, - with_status_quo=True, - ) + for num_arms in (1, 3): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=num_arms, + with_status_quo=True, + ) + self.experiment.fetch_data() sq_ft = NodeInputConstructors.STATUS_QUO_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -352,18 +345,14 @@ def test_status_quo_features(self) -> None: ) def test_set_target_trial_most_arms_long_run_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=3, - ) + for num_arms in (1, 3): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=num_arms, + ) + self.experiment.fetch_data() # Test most arms should win target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, @@ -382,18 +371,14 @@ def test_set_target_trial_most_arms_long_run_wins(self) -> None: def test_set_target_trial_long_run_ties(self) -> None: # if all things are equal we should just pick the first one # in the sorted list - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) + for _ in range(2): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=1, + ) + self.experiment.fetch_data() target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -409,18 +394,14 @@ def test_set_target_trial_long_run_ties(self) -> None: ) def test_set_target_trial_longest_duration_long_run_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=False, - num_arms=1, - ) + for _ in range(2): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=1, + ) + self.experiment.fetch_data() self.experiment.trials[0]._time_run_started = datetime(2000, 1, 2) self.experiment.trials[1]._time_run_started = datetime(2000, 1, 1) target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( @@ -438,18 +419,14 @@ def test_set_target_trial_longest_duration_long_run_wins(self) -> None: ) def test_set_target_trial_running_short_trial_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.LONG_RUN, - complete=True, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=1, - ) + for trial_type, complete in zip((Keys.LONG_RUN, Keys.SHORT_RUN), (True, False)): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=trial_type, + complete=complete, + num_arms=1, + ) + self.experiment.fetch_data() target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -465,18 +442,14 @@ def test_set_target_trial_running_short_trial_wins(self) -> None: ) def test_set_target_trial_longest_short_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=1, - ) + for _ in range(2): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=False, + num_arms=1, + ) + self.experiment.fetch_data() self.experiment.trials[0]._time_run_started = datetime(2000, 1, 2) self.experiment.trials[1]._time_run_started = datetime(2000, 1, 1) target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( @@ -494,18 +467,14 @@ def test_set_target_trial_longest_short_wins(self) -> None: ) def test_set_target_trial_most_arms_short_running_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=False, - num_arms=3, - ) + for num_arms in (1, 3): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=False, + num_arms=num_arms, + ) + self.experiment.fetch_data() target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -521,18 +490,14 @@ def test_set_target_trial_most_arms_short_running_wins(self) -> None: ) def test_set_target_trial_most_arms_complete_short_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=True, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=True, - num_arms=3, - ) + for num_arms in (1, 3): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=False, + num_arms=num_arms, + ) + self.experiment.fetch_data() target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( previous_node=None, next_node=self.sobol_generation_node, @@ -548,18 +513,14 @@ def test_set_target_trial_most_arms_complete_short_wins(self) -> None: ) def test_set_target_trial_longest_short_complete_wins(self) -> None: - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=True, - num_arms=1, - ) - self._add_sobol_trial( - experiment=self.experiment, - trial_type=Keys.SHORT_RUN, - complete=True, - num_arms=1, - ) + for _ in range(2): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=True, + num_arms=1, + ) + self.experiment.fetch_data() self.experiment.trials[0]._time_run_started = datetime(2000, 1, 2) self.experiment.trials[1]._time_run_started = datetime(2000, 1, 1) target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( @@ -588,6 +549,29 @@ def test_target_trial_raises_error_if_none_found(self) -> None: experiment=self.experiment, ) + def test_target_trial_latest_trial_no_data(self) -> None: + for _ in range(2): + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=True, + num_arms=1, + ) + self.experiment.fetch_trials_data(trial_indices=[0]) + target_trial = NodeInputConstructors.TARGET_TRIAL_FIXED_FEATURES( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={}, + experiment=self.experiment, + ) + self.assertEqual( + target_trial, + ObservationFeatures( + parameters={}, + trial_index=0, + ), + ) + def _add_sobol_trial( self, experiment: Experiment, diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 2acd7ba0158..9792e2bc6ab 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -1866,7 +1866,8 @@ def test_node_gs_with_auto_transitions_three_phase(self) -> None: self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4") def test_gs_with_fixed_features_constructor(self) -> None: - exp = get_branin_experiment() + exp = get_branin_experiment(with_completed_batch=True) + exp.fetch_data() sobol_criterion = [ MinTrials( threshold=1, diff --git a/ax/utils/common/mock.py b/ax/utils/common/mock.py index 47cd29266d6..e4aa750fbde 100644 --- a/ax/utils/common/mock.py +++ b/ax/utils/common/mock.py @@ -34,5 +34,10 @@ def side_effect(self: C, *args: Any, **kwargs: Any) -> T: return original_method(self, *args, **kwargs) patcher = patch(mock_path, autospec=True, side_effect=side_effect) - yield patcher.start() + try: + yield patcher.start() + except Exception as e: + # tear down the patch if the `original_method` fails + patcher.stop() + raise e patcher.stop()