Skip to content

Commit

Permalink
only consider trials with data as target trial (facebook#3351)
Browse files Browse the repository at this point in the history
Summary:

See title. This ensures the target trial is one with data.

Reviewed By: mgarrard

Differential Revision: D69477011
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 14, 2025
1 parent 5265946 commit 3ec31da
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 151 deletions.
66 changes: 41 additions & 25 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ax.core.trial import Trial
from ax.core.types import ComparisonOp
from ax.utils.common.constants import Keys
from pyre_extensions import assert_is_instance, none_throws
from pyre_extensions import none_throws

TArmTrial = tuple[str, int]

Expand Down Expand Up @@ -431,8 +431,11 @@ def extend_pending_observations(
def get_target_trial_index(experiment: Experiment) -> int | None:
"""Get the index of the target trial in the ``Experiment``.
Find the target trial giving priority in the following order:
1. a running long-run trial
Find the target trial (among those with data) giving priority in the following
order:
1. a running long-run trial. Note that if there is a running long-run trial on
the experiment without data, or if there is no data on the experiment, then
this will return None.
2. Most recent trial expecting data, with running trials being considered the
most recent
Expand All @@ -450,47 +453,60 @@ def get_target_trial_index(experiment: Experiment) -> int | None:
# TODO: @mgarrard improve logic to include trial_obsolete_threshold that
# takes into account the age of the trial, and consider more heavily weighting
# long run trials.
df = experiment.lookup_data().df
if df.empty:
return None
trial_indices_with_data = set(df.trial_index.unique())
# only consider running trials with data
running_trials = [
assert_is_instance(trial, BatchTrial)
trial
for trial in experiment.trials_by_status[TrialStatus.RUNNING]
if trial.index in trial_indices_with_data
]
sorted_running_trials = _sort_trials(trials=running_trials, trials_are_running=True)
# Priority 1: Any running long-run trial
target_trial_idx = next(
(
trial.index
for trial in sorted_running_trials
if trial.trial_type == Keys.LONG_RUN
),
None,
has_running_long_run_trial = any(
trial.trial_type == Keys.LONG_RUN
for trial in experiment.trials_by_status[TrialStatus.RUNNING]
)
if target_trial_idx is not None:
return target_trial_idx
if has_running_long_run_trial:
# This returns the index of a running long-run trial with data, or None
# if there are running long-run trials on the experiment but none of
# them has data
return next(
(
trial.index
for trial in sorted_running_trials
if trial.trial_type == Keys.LONG_RUN
),
None,
)

# Priority 2: longest running currently running trial
# Priority 2: longest running currently running trial with data
if len(sorted_running_trials) > 0:
return sorted_running_trials[0].index

# Priority 3: the longest running trial expecting data, discounting running trials
# Priority 3: the longest running trial with data, discounting running trials
# as we handled those above
trials_expecting_data = [
assert_is_instance(trial, BatchTrial)
for trial in experiment.trials_expecting_data
if trial.status != TrialStatus.RUNNING
non_running_trial_indices_with_data = trial_indices_with_data - {
t.index for t in running_trials
}
non_running_trials_with_data = [
experiment.trials[i] for i in non_running_trial_indices_with_data
]
sorted_trials_expecting_data = _sort_trials(
trials=trials_expecting_data, trials_are_running=False
sorted_non_running_trials_with_data = _sort_trials(
trials=non_running_trials_with_data, trials_are_running=False
)
if len(sorted_trials_expecting_data) > 0:
return sorted_trials_expecting_data[0].index
if len(sorted_non_running_trials_with_data) > 0:
return sorted_non_running_trials_with_data[0].index

return None


def _sort_trials(
trials: list[BatchTrial],
trials: list[BaseTrial],
trials_are_running: bool,
) -> list[BatchTrial]:
) -> list[BaseTrial]:
"""Sort a list of trials by (1) duration of trial, (2) number of arms in trial.
Args:
Expand Down
6 changes: 4 additions & 2 deletions ax/generation_strategy/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def get_status_quo(
raise AxGenerationException(
f"Attempting to construct status quo input into {next_node} but couldn't "
"identify the target trial. Often this could be due to no trials on the "
f"experiment that are in status {STATUSES_EXPECTING_DATA}. The trials on "
f"this experiment are: {experiment.trials}."
f"experiment that are in status {STATUSES_EXPECTING_DATA} "
f"and have data. The trials on this experiment are: "
f"{experiment.trials} and trials with data are: "
f"{experiment.lookup_data().df.trial_index.unique()}."
)
if experiment.status_quo is None:
raise AxGenerationException(
Expand Down
Loading

0 comments on commit 3ec31da

Please sign in to comment.