Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ergonomics around early stopping #3446

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,16 @@ def mark_early_stopped(self, unsafe: bool = False) -> BaseTrial:
Returns:
The trial instance.
"""
if self._status != TrialStatus.RUNNING:
raise ValueError("Can only early stop trial that is currently running.")
if not unsafe:
if self._status != TrialStatus.RUNNING:
raise ValueError("Can only early stop trial that is currently running.")

if self.lookup_data().df.empty:
raise UnsupportedError(
"Cannot mark trial early stopped without data. Please mark trial "
"abandoned instead."
)

self._status = TrialStatus.EARLY_STOPPED
self._time_completed = datetime.now()
return self
Expand Down Expand Up @@ -822,4 +830,4 @@ def _update_trial_attrs_on_clone(
if self.status == TrialStatus.FAILED:
new_trial.mark_failed(reason=self.failed_reason)
return
new_trial.mark_as(self.status)
new_trial.mark_as(self.status, unsafe=True)
6 changes: 6 additions & 0 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ def test_FailedBatchTrial(self) -> None:
def test_EarlyStoppedBatchTrial(self) -> None:
self.batch.runner = SyntheticRunner()
self.batch.run()
self.batch.attach_batch_trial_data(
raw_data={
self.batch.arms[0].name: {"m1": 1.0, "m2": 2.0},
self.batch.arms[1].name: {"m1": 3.0, "m2": 4.0},
}
)
self.batch.mark_early_stopped()

self.assertEqual(self.batch.status, TrialStatus.EARLY_STOPPED)
Expand Down
4 changes: 3 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,9 @@ def test_trial_indices(self) -> None:
self.assertEqual(experiment.trial_indices_expecting_data, {2, 4})
experiment.trials[4].mark_failed()
self.assertEqual(experiment.trial_indices_expecting_data, {2})
experiment.trials[5].mark_running(no_runner_required=True).mark_early_stopped()
experiment.trials[5].mark_running(no_runner_required=True).mark_early_stopped(
unsafe=True
)
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})

def test_stop_trial(self) -> None:
Expand Down
6 changes: 6 additions & 0 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ def test_mark_as(self) -> None:
kwargs["reason"] = "test_reason_abandon"
if status == TrialStatus.FAILED:
kwargs["reason"] = "test_reason_failed"

# Trial must have data before it can be marked EARLY_STOPPED
if status == TrialStatus.EARLY_STOPPED:
self.trial.update_trial_data(raw_data={"m1": 1.0, "m2": 2.0})

self.trial.mark_as(status=status, **kwargs)
self.assertTrue(self.trial.status == status)

Expand Down Expand Up @@ -254,6 +259,7 @@ def stop(self, trial, reason):
self.trial._runner = DummyStopRunner()
self.trial.mark_running()
self.assertEqual(self.trial.status, TrialStatus.RUNNING)
self.trial.update_trial_data(raw_data={"m1": 1.0, "m2": 2.0})
self.trial.stop(new_status=new_status, reason=reason)
self.assertEqual(self.trial.status, new_status)
self.assertEqual(
Expand Down
19 changes: 6 additions & 13 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,23 +560,16 @@ def mark_trial_abandoned(self, trial_index: int) -> None:
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

def mark_trial_early_stopped(
self, trial_index: int, raw_data: TOutcome, progression: int | None = None
) -> None:
def mark_trial_early_stopped(self, trial_index: int) -> None:
"""
Manually mark a trial as EARLY_STOPPED while attaching the most recent data.
This is used when the user has decided (with or without Ax's recommendation) to
stop the trial early. EARLY_STOPPED trials will not be re-suggested by
get_next_trials.
Manually mark a trial as EARLY_STOPPED. This is used when the user has decided
(with or without Ax's recommendation) to stop the trial after some data has
been attached but before the trial is completed. Note that if data has not been
attached for the trial yet, users should instead call ``mark_trial_abandoned``.
EARLY_STOPPED trials will not be re-suggested by ``get_next_trials``.

Saves to database on completion if storage_config is present.
"""

# First attach the new data
self.attach_data(
trial_index=trial_index, raw_data=raw_data, progression=progression
)

self._experiment.trials[trial_index].mark_early_stopped()

self._save_or_update_trial_in_db_if_possible(
Expand Down
9 changes: 8 additions & 1 deletion ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,16 @@ def test_mark_trial_early_stopped(self) -> None:
client.configure_optimization(objective="foo")

trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
client.mark_trial_early_stopped(

with self.assertRaisesRegex(
UnsupportedError, "Cannot mark trial early stopped"
):
client.mark_trial_early_stopped(trial_index=trial_index)

client.attach_data(
trial_index=trial_index, raw_data={"foo": 0.0}, progression=1
)
client.mark_trial_early_stopped(trial_index=trial_index)
self.assertEqual(
client._experiment.trials[trial_index].status,
TrialStatus.EARLY_STOPPED,
Expand Down
57 changes: 45 additions & 12 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def get_branin_currin_optimization_with_N_sobol_trials(
def get_branin_optimization(
generation_strategy: GenerationStrategy | None = None,
torch_device: torch.device | None = None,
support_intermediate_data: bool = False,
) -> AxClient:
ax_client = AxClient(
generation_strategy=generation_strategy, torch_device=torch_device
Expand All @@ -199,6 +200,7 @@ def get_branin_optimization(
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
objectives={"branin": ObjectiveProperties(minimize=True)},
support_intermediate_data=support_intermediate_data,
)
return ax_client

Expand Down Expand Up @@ -1700,28 +1702,44 @@ def test_trial_completion(self) -> None:
self.assertTrue(math.isnan(best_trial_values[1]["branin"]["branin"]))

def test_update_trial_data(self) -> None:
ax_client = get_branin_optimization()
ax_client = get_branin_optimization(support_intermediate_data=True)
params, idx = ax_client.get_next_trial()
# Can't update before completing.
with self.assertRaisesRegex(ValueError, ".* not in a terminal state"):
ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (0, 0.0)})
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})
ax_client.update_trial_data(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})]
)
ax_client.complete_trial(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})]
)
# Cannot complete a trial twice; use `update_trial_data` instead.
with self.assertRaisesRegex(UnsupportedError, ".* already been completed"):
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})
ax_client.complete_trial(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (0, 0.0)})]
)
# Check that the update changes the data.
ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (1, 0.0)})
ax_client.update_trial_data(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (1, 0.0)})]
)
df = ax_client.experiment.lookup_data_for_trial(idx)[0].df
self.assertEqual(len(df), 1)
self.assertEqual(df["mean"].item(), 1.0)
self.assertTrue(ax_client.get_trial(idx).status.is_completed)

# With early stopped trial.
params, idx = ax_client.get_next_trial()
ax_client.update_running_trial_with_intermediate_data(
idx,
# pyre-fixme[6]: For 2nd argument expected `Union[floating[typing...
raw_data=[({"t": 0}, {"branin": (branin(*params.values()), 0.0)})],
)

ax_client.stop_trial_early(trial_index=idx)
df = ax_client.experiment.lookup_data_for_trial(idx)[0].df
self.assertEqual(len(df), 0)
ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (2, 0.0)})
self.assertEqual(len(df), 1)
ax_client.update_trial_data(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (2, 0.0)})]
)
df = ax_client.experiment.lookup_data_for_trial(idx)[0].df
self.assertEqual(len(df), 1)
self.assertEqual(df["mean"].item(), 2.0)
Expand All @@ -1730,13 +1748,17 @@ def test_update_trial_data(self) -> None:
# Failed trial.
params, idx = ax_client.get_next_trial()
ax_client.log_trial_failure(trial_index=idx)
ax_client.update_trial_data(trial_index=idx, raw_data={"branin": (3, 0.0)})
ax_client.update_trial_data(
trial_index=idx, raw_data=[({"t": 0}, {"branin": (3, 0.0)})]
)
df = ax_client.experiment.lookup_data_for_trial(idx)[0].df
self.assertEqual(df["mean"].item(), 3.0)

# Incomplete trial fails
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"missing_metric": (1, 0.0)})
ax_client.complete_trial(
trial_index=idx, raw_data=[({"t": 0}, {"missing_metric": (0, 0.0)})]
)
self.assertTrue(ax_client.get_trial(idx).status.is_failed)

def test_incomplete_multi_fidelity_trial(self) -> None:
Expand Down Expand Up @@ -2104,13 +2126,14 @@ def test_sqa_storage(self) -> None:
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
support_intermediate_data=True,
)
for _ in range(5):
parameters, trial_index = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=trial_index,
# pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U...
raw_data=branin(*parameters.values()),
raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})],
)
gs = ax_client.generation_strategy
ax_client = AxClient(db_settings=db_settings)
Expand Down Expand Up @@ -2144,6 +2167,11 @@ def test_sqa_storage(self) -> None:

# Attach an early stopped trial.
parameters, trial_index = ax_client.get_next_trial()
ax_client.update_running_trial_with_intermediate_data(
trial_index=trial_index,
# pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U...
raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})],
)
ax_client.stop_trial_early(trial_index=trial_index)

# Reload experiment and check that trial status is accurate.
Expand Down Expand Up @@ -2910,7 +2938,12 @@ def test_stop_trial_early(self) -> None:
],
support_intermediate_data=True,
)
_, idx = ax_client.get_next_trial()
parameters, idx = ax_client.get_next_trial()
ax_client.update_running_trial_with_intermediate_data(
idx,
# pyre-fixme[6]: For 2nd argument expected `Union[floating[typing...
raw_data=[({"t": 0}, {"branin": (branin(*parameters.values()), 0.0)})],
)
ax_client.stop_trial_early(idx)
trial = ax_client.get_trial(idx)
self.assertTrue(trial.status.is_early_stopped)
Expand All @@ -2925,7 +2958,7 @@ def test_estimate_early_stopping_savings(self) -> None:
support_intermediate_data=True,
)
_, idx = ax_client.get_next_trial()
ax_client.stop_trial_early(idx)
ax_client.experiment.trials[idx].mark_early_stopped(unsafe=True)

self.assertEqual(ax_client.estimate_early_stopping_savings(), 0)

Expand Down
4 changes: 2 additions & 2 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_best_raw_objective_point(self) -> None:
with self.subTest("Only early-stopped trials"):
exp = get_experiment_with_map_data()
exp.trials[0].mark_running(no_runner_required=True)
exp.trials[0].mark_early_stopped()
exp.trials[0].mark_early_stopped(unsafe=True)
with self.assertRaisesRegex(
ValueError, "Cannot identify best point if no trials are completed."
):
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_extract_Y_from_data(self) -> None:
no_runner_required=True
)
if i in [3, 8, 10]:
trial.mark_early_stopped()
trial.mark_early_stopped(unsafe=True)
else:
trial.mark_completed()

Expand Down