Skip to content

Commit

Permalink
remove unnecessary data validation for trial index (facebook#3433)
Browse files Browse the repository at this point in the history
Summary:

previously in this diff stack we realized that this check is actually not that helpful, we remove it here.

Frankly, I don't fully understand how trial_index is being leveraged so could be helpful to get some UI ptrs for that

Differential Revision: D70301660
  • Loading branch information
mgarrard authored and facebook-github-bot committed Feb 27, 2025
1 parent f5a7a26 commit c886c6e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 24 deletions.
20 changes: 4 additions & 16 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ def __init__(
reflect the how good the model used for candidate generation actually
is.
trial_index: Optional trial index that the model from generation_strategy
was used to generate. We should therefore only have observations from
trials prior to this trial index in our plot. If this is not True, we
should error out.
was used to generate. Useful card attribute to filter to only specific
trial.
"""

self.metric_name = metric_name
Expand Down Expand Up @@ -191,9 +190,8 @@ def _construct_plot(
reflect the how good the model used for candidate generation actually
is.
trial_index: Optional trial index that the model from generation_strategy
was used to generate. We should therefore only have observations from
trials prior to this trial index in our plot. If this is not True, we
should error out.
was used to generate. Useful card attribute to filter to only specific
trial.
experiment: Optional Experiment associated with this analysis. Used to set
the priority of the analysis based on the metric importance in the
optimization config.
Expand Down Expand Up @@ -255,16 +253,6 @@ def _prepare_data(

records = []
for observed, predicted in cv_results:
if trial_index is not None:
if (
observed.features.trial_index is not None
and observed.features.trial_index >= trial_index
):
raise UserInputError(
"CrossValidationPlot was specified to be for the generation of "
f"trial {trial_index}, but has observations from trial "
f"{observed.features.trial_index}."
)
# Find the index of the metric in observed and predicted
observed_i = next(
(
Expand Down
8 changes: 0 additions & 8 deletions ax/analysis/plotly/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ def test_compute(self) -> None:
card.df["arm_name"].unique(),
)

def test_it_can_only_contain_observation_prior_to_the_trial_index(self) -> None:
analysis = CrossValidationPlot(metric_name="bar", trial_index=7)
with self.assertRaisesRegex(
UserInputError,
"CrossValidationPlot was specified to be for the generation of trial 7",
):
analysis.compute(generation_strategy=self.client.generation_strategy)

def test_it_can_specify_trial_index_correctly(self) -> None:
analysis = CrossValidationPlot(metric_name="bar", trial_index=9)
card = analysis.compute(generation_strategy=self.client.generation_strategy)
Expand Down

0 comments on commit c886c6e

Please sign in to comment.