Skip to content

Commit

Permalink
only consider trials with data as target trial
Browse files Browse the repository at this point in the history
Summary: See title. This ensures the target trial is one with data.

Differential Revision: D69477011
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 12, 2025
1 parent b0d30c2 commit 0b4f437
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 141 deletions.
39 changes: 23 additions & 16 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ax.core.trial import Trial
from ax.core.types import ComparisonOp
from ax.utils.common.constants import Keys
from pyre_extensions import assert_is_instance, none_throws
from pyre_extensions import none_throws

TArmTrial = tuple[str, int]

Expand Down Expand Up @@ -431,7 +431,8 @@ def extend_pending_observations(
def get_target_trial_index(experiment: Experiment) -> int | None:
"""Get the index of the target trial in the ``Experiment``.
Find the target trial giving priority in the following order:
Find the target trial (among those with data) giving priority in the following
order:
1. a running long-run trial
2. Most recent trial expecting data with running trials be considered the most
recent
Expand All @@ -450,12 +451,17 @@ def get_target_trial_index(experiment: Experiment) -> int | None:
# TODO: @mgarrard improve logic to include trial_obsolete_threshold that
# takes into account the age of the trial, and consider more heavily weighting
# long run trials.
df = experiment.lookup_data().df
if df.empty:
return None
trial_indices_with_data = set(df.trial_index.unique())
running_trials = [
assert_is_instance(trial, BatchTrial)
trial
for trial in experiment.trials_by_status[TrialStatus.RUNNING]
if trial.index in trial_indices_with_data
]
sorted_running_trials = _sort_trials(trials=running_trials, trials_are_running=True)
# Priority 1: Any running long-run trial
# Priority 1: Any running long-run trial with data
target_trial_idx = next(
(
trial.index
Expand All @@ -467,30 +473,31 @@ def get_target_trial_index(experiment: Experiment) -> int | None:
if target_trial_idx is not None:
return target_trial_idx

# Priority 2: longest running currently running trial
# Priority 2: longest running currently running trial with data
if len(sorted_running_trials) > 0:
return sorted_running_trials[0].index

# Priortiy 3: the longest running trial expecting data, discounting running trials
# Priortiy 3: the longest running trial with data, discounting running trials
# as we handled those above
trials_expecting_data = [
assert_is_instance(trial, BatchTrial)
for trial in experiment.trials_expecting_data
if trial.status != TrialStatus.RUNNING
non_running_trial_indices_with_data = trial_indices_with_data - {
t.index for t in running_trials
}
non_running_trials_with_data = [
experiment.trials[i] for i in non_running_trial_indices_with_data
]
sorted_trials_expecting_data = _sort_trials(
trials=trials_expecting_data, trials_are_running=False
sorted_non_running_trials_with_data = _sort_trials(
trials=non_running_trials_with_data, trials_are_running=False
)
if len(sorted_trials_expecting_data) > 0:
return sorted_trials_expecting_data[0].index
if len(sorted_non_running_trials_with_data) > 0:
return sorted_non_running_trials_with_data[0].index

return None


def _sort_trials(
trials: list[BatchTrial],
trials: list[BaseTrial],
trials_are_running: bool,
) -> list[BatchTrial]:
) -> list[BaseTrial]:
"""Sort a list of trials by (1) duration of trial, (2) number of arms in trial.
Args:
Expand Down
3 changes: 2 additions & 1 deletion ax/generation_strategy/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def get_status_quo(
f"Attempting to construct status quo input into {next_node} but couldn't "
"identify the target trial. Often this could be due to no trials on the "
f"experiment that are in status {STATUSES_EXPECTING_DATA} on the "
f"experiment. The trials on this experiment are: {experiment.trials}."
f"experiment and have data. The trials on this experiment are: "
f"{experiment.trials}."
)
if experiment.status_quo is None:
raise AxGenerationException(
Expand Down
Loading

0 comments on commit 0b4f437

Please sign in to comment.