diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index e37988c7819..594c06a4288 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -621,8 +621,16 @@ def mark_early_stopped(self, unsafe: bool = False) -> BaseTrial: Returns: The trial instance. """ - if self._status != TrialStatus.RUNNING: - raise ValueError("Can only early stop trial that is currently running.") + if not unsafe: + if self._status != TrialStatus.RUNNING: + raise ValueError("Can only early stop trial that is currently running.") + + if self.lookup_data().df.empty: + raise UnsupportedError( + "Cannot mark trial early stopped without data. Please mark trial " + "abandoned instead." + ) + self._status = TrialStatus.EARLY_STOPPED self._time_completed = datetime.now() return self @@ -822,4 +830,4 @@ def _update_trial_attrs_on_clone( if self.status == TrialStatus.FAILED: new_trial.mark_failed(reason=self.failed_reason) return - new_trial.mark_as(self.status) + new_trial.mark_as(self.status, unsafe=True) diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 951f49e0326..e028a53bf56 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -404,6 +404,12 @@ def test_FailedBatchTrial(self) -> None: def test_EarlyStoppedBatchTrial(self) -> None: self.batch.runner = SyntheticRunner() self.batch.run() + self.batch.attach_batch_trial_data( + raw_data={ + self.batch.arms[0].name: {"m1": 1.0, "m2": 2.0}, + self.batch.arms[1].name: {"m1": 3.0, "m2": 4.0}, + } + ) self.batch.mark_early_stopped() self.assertEqual(self.batch.status, TrialStatus.EARLY_STOPPED) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 3fb517b06aa..6abdfa4817d 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -1298,7 +1298,9 @@ def test_trial_indices(self) -> None: self.assertEqual(experiment.trial_indices_expecting_data, {2, 4}) experiment.trials[4].mark_failed() self.assertEqual(experiment.trial_indices_expecting_data, {2}) - experiment.trials[5].mark_running(no_runner_required=True).mark_early_stopped() + experiment.trials[5].mark_running(no_runner_required=True).mark_early_stopped( + unsafe=True + ) self.assertEqual(experiment.trial_indices_expecting_data, {2, 5}) def test_stop_trial(self) -> None: diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 5786e28395f..508c4815169 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -205,6 +205,11 @@ def test_mark_as(self) -> None: kwargs["reason"] = "test_reason_abandon" if status == TrialStatus.FAILED: kwargs["reason"] = "test_reason_failed" + + # Trial must have data before it can be marked EARLY_STOPPED + if status == TrialStatus.EARLY_STOPPED: + self.trial.update_trial_data(raw_data={"m1": 1.0, "m2": 2.0}) + self.trial.mark_as(status=status, **kwargs) self.assertTrue(self.trial.status == status) @@ -254,6 +259,7 @@ def stop(self, trial, reason): self.trial._runner = DummyStopRunner() self.trial.mark_running() self.assertEqual(self.trial.status, TrialStatus.RUNNING) + self.trial.update_trial_data(raw_data={"m1": 1.0, "m2": 2.0}) self.trial.stop(new_status=new_status, reason=reason) self.assertEqual(self.trial.status, new_status) self.assertEqual( diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 99140177d05..207f6e38853 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -560,23 +560,16 @@ def mark_trial_abandoned(self, trial_index: int) -> None: experiment=self._experiment, trial=self._experiment.trials[trial_index] ) - def mark_trial_early_stopped( - self, trial_index: int, raw_data: TOutcome, progression: int | None = None - ) -> None: + def mark_trial_early_stopped(self, trial_index: int) -> None: """ - Manually mark a trial as EARLY_STOPPED while attaching the most recent data. - This is used when the user has decided (with or without Ax's recommendation) to - stop the trial early. EARLY_STOPPED trials will not be re-suggested by - get_next_trials. + Manually mark a trial as EARLY_STOPPED. This is used when the user has decided + (with or without Ax's recommendation) to stop the trial after some data has + been attached but before the trial is completed. Note that if data has not been + attached for the trial yet users should instead call ``mark_trial_abandoned``. + EARLY_STOPPED trials will not be re-suggested by ``get_next_trials``. Saves to database on completion if storage_config is present. """ - - # First attach the new data - self.attach_data( - trial_index=trial_index, raw_data=raw_data, progression=progression - ) - self._experiment.trials[trial_index].mark_early_stopped() self._save_or_update_trial_in_db_if_possible( diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 4a378dd549a..77af5ab5998 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -687,9 +687,16 @@ def test_mark_trial_early_stopped(self) -> None: client.configure_optimization(objective="foo") trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] - client.mark_trial_early_stopped( + + with self.assertRaisesRegex( + UnsupportedError, "Cannot mark trial early stopped" + ): + client.mark_trial_early_stopped(trial_index=trial_index) + + client.attach_data( trial_index=trial_index, raw_data={"foo": 0.0}, progression=1 ) + client.mark_trial_early_stopped(trial_index=trial_index) self.assertEqual( client._experiment.trials[trial_index].status, TrialStatus.EARLY_STOPPED, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 7988563b6b8..0a86a9db0d1 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -188,6 +188,7 @@ def get_branin_currin_optimization_with_N_sobol_trials( def get_branin_optimization( generation_strategy: GenerationStrategy | None = None, torch_device: torch.device | None = None, + support_intermediate_data: bool = False, ) -> AxClient: ax_client = AxClient( generation_strategy=generation_strategy, torch_device=torch_device @@ -199,6 +200,7 @@ def get_branin_optimization( {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], objectives={"branin": ObjectiveProperties(minimize=True)}, + support_intermediate_data=support_intermediate_data, ) return ax_client @@ -1700,17 +1702,25 @@ def test_trial_completion(self) -> None: self.assertTrue(math.isnan(best_trial_values[1]["branin"]["branin"])) def test_update_trial_data(self) -> None: - ax_client = get_branin_optimization() + ax_client = get_branin_optimization(support_intermediate_data=True) params, idx = ax_client.get_next_trial() # Can't update before completing. with self.assertRaisesRegex(ValueError, ".* not in a terminal state"): - ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (0, 0.0)}) - ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)}) + ax_client.update_trial_data( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})] + ) + ax_client.complete_trial( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})] + ) # Cannot complete a trial twice, should use `update_trial_data`. with self.assertRaisesRegex(UnsupportedError, ".* already been completed"): - ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)}) + ax_client.complete_trial( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})] + ) # Check that the update changes the data. - ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (1, 0.0)}) + ax_client.update_trial_data( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (1, 0.0)})] + ) df = ax_client.experiment.lookup_data_for_trial(idx)[0].df self.assertEqual(len(df), 1) self.assertEqual(df["mean"].item(), 1.0) @@ -1718,10 +1728,18 @@ def test_update_trial_data(self) -> None: # With early stopped trial. params, idx = ax_client.get_next_trial() + ax_client.update_running_trial_with_intermediate_data( + idx, + # pyre-fixme[6]: For 2nd argument expected `Union[floating[typing... + raw_data=[({"t": 0}, {"branin": (branin(*params.values()), 0.0)})], + ) + ax_client.stop_trial_early(trial_index=idx) df = ax_client.experiment.lookup_data_for_trial(idx)[0].df - self.assertEqual(len(df), 0) - ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (2, 0.0)}) + self.assertEqual(len(df), 1) + ax_client.update_trial_data( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (2, 0.0)})] + ) df = ax_client.experiment.lookup_data_for_trial(idx)[0].df self.assertEqual(len(df), 1) self.assertEqual(df["mean"].item(), 2.0) @@ -1730,13 +1748,17 @@ def test_update_trial_data(self) -> None: # Failed trial. params, idx = ax_client.get_next_trial() ax_client.log_trial_failure(trial_index=idx) - ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (3, 0.0)}) + ax_client.update_trial_data( + trial_index=idx, raw_data=[({"t": 0}, {"branin": (3, 0.0)})] + ) df = ax_client.experiment.lookup_data_for_trial(idx)[0].df self.assertEqual(df["mean"].item(), 3.0) # Incomplete trial fails params, idx = ax_client.get_next_trial() - ax_client.complete_trial(trial_index=idx, raw_data={"missing_metric": (1, 0.0)}) + ax_client.complete_trial( + trial_index=idx, raw_data=[({"t": 0}, {"missing_metric": (0, 0.0)})] + ) self.assertTrue(ax_client.get_trial(idx).status.is_failed) def test_incomplete_multi_fidelity_trial(self) -> None: @@ -2104,13 +2126,14 @@ def test_sqa_storage(self) -> None: {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], + support_intermediate_data=True, ) for _ in range(5): parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial( trial_index=trial_index, # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U... - raw_data=branin(*parameters.values()), + raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})], ) gs = ax_client.generation_strategy ax_client = AxClient(db_settings=db_settings) @@ -2144,6 +2167,11 @@ def test_sqa_storage(self) -> None: # Attach an early stopped trial. parameters, trial_index = ax_client.get_next_trial() + ax_client.update_running_trial_with_intermediate_data( + trial_index=trial_index, + # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U... + raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})], + ) ax_client.stop_trial_early(trial_index=trial_index) # Reload experiment and check that trial status is accurate. @@ -2910,7 +2938,12 @@ def test_stop_trial_early(self) -> None: ], support_intermediate_data=True, ) - _, idx = ax_client.get_next_trial() + parameters, idx = ax_client.get_next_trial() + ax_client.update_running_trial_with_intermediate_data( + idx, + # pyre-fixme[6]: For 2nd argument expected `Union[floating[typing... + raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})], + ) ax_client.stop_trial_early(idx) trial = ax_client.get_trial(idx) self.assertTrue(trial.status.is_early_stopped) @@ -2925,7 +2958,7 @@ def test_estimate_early_stopping_savings(self) -> None: support_intermediate_data=True, ) _, idx = ax_client.get_next_trial() - ax_client.stop_trial_early(idx) + ax_client.experiment.trials[idx].mark_early_stopped(unsafe=True) self.assertEqual(ax_client.estimate_early_stopping_savings(), 0) diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index 1527cdb894c..14dbf1e2dde 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -157,7 +157,7 @@ def test_best_raw_objective_point(self) -> None: with self.subTest("Only early-stopped trials"): exp = get_experiment_with_map_data() exp.trials[0].mark_running(no_runner_required=True) - exp.trials[0].mark_early_stopped() + exp.trials[0].mark_early_stopped(unsafe=True) with self.assertRaisesRegex( ValueError, "Cannot identify best point if no trials are completed." ): @@ -414,7 +414,7 @@ def test_extract_Y_from_data(self) -> None: no_runner_required=True ) if i in [3, 8, 10]: - trial.mark_early_stopped() + trial.mark_early_stopped(unsafe=True) else: trial.mark_completed()