diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py
index 25fd5483455..c93445d3f48 100644
--- a/ax/analysis/plotly/cross_validation.py
+++ b/ax/analysis/plotly/cross_validation.py
@@ -14,6 +14,7 @@
 from ax.core.experiment import Experiment
 from ax.exceptions.core import UserInputError
 from ax.generation_strategy.generation_strategy import GenerationStrategy
+from ax.modelbridge.base import Adapter
 from ax.modelbridge.cross_validation import cross_validate
 from plotly import express as px, graph_objects as go
 from pyre_extensions import none_throws
@@ -89,18 +90,108 @@ def compute(
         metric_name = self.metric_name or select_metric(
             experiment=generation_strategy.experiment
         )
+        # If model is not fit already, fit it
+        if generation_strategy.model is None:
+            generation_strategy._fit_current_model(None)
 
-        df = _prepare_data(
-            generation_strategy=generation_strategy,
+        return self._construct_plot(
+            adapter=none_throws(generation_strategy.model),
             metric_name=metric_name,
             folds=self.folds,
             untransform=self.untransform,
             trial_index=self.trial_index,
+            experiment=experiment,
         )
 
-        fig = _prepare_plot(df=df)
-        k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out"
+    def _compute_adhoc(
+        self,
+        adapter: Adapter,
+        metric_name: str,
+        experiment: Experiment | None = None,
+        folds: int = -1,
+        untransform: bool = True,
+    ) -> PlotlyAnalysisCard:
+        """
+        Helper method to expose adhoc cross validation plotting. This overrides the
+        default assumption that the adapter from the generation strategy should be
+        used. Only for advanced users in a notebook setting.
+
+        Args:
+            adapter: The adapter that will be assessed during cross validation.
+            metric_name: The name of the metric to plot. Must be provided for adhoc
+                plotting.
+            experiment: Experiment associated with this analysis. Used to determine
+                the priority of the analysis based on the metric importance in the
+                optimization config.
+            folds: Number of subsamples to partition observations into. Use -1 for
+                leave-one-out cross validation.
+            untransform: Whether to untransform the model predictions before cross
+                validating. Generators are trained on transformed data, and candidate
+                generation is performed in the transformed space. Computing the model
+                quality metric based on the cross-validation results in the
+                untransformed space may not be representative of the model that
+                is actually used for candidate generation in case of non-invertible
+                transforms, e.g., Winsorize or LogY. While the model in the
+                transformed space may not be representative of the original data in
+                regions where outliers have been removed, we have found it to better
+                reflect how good the model used for candidate generation actually
+                is.
+        """
+        return self._construct_plot(
+            adapter=adapter,
+            metric_name=metric_name,
+            folds=folds,
+            untransform=untransform,
+            # trial_index is only used with a generation strategy; since this is
+            # an adhoc plot call, it is None.
+            trial_index=None,
+            experiment=experiment,
+        )
+
+    def _construct_plot(
+        self,
+        adapter: Adapter,
+        metric_name: str,
+        folds: int,
+        untransform: bool,
+        trial_index: int | None,
+        experiment: Experiment | None = None,
+    ) -> PlotlyAnalysisCard:
+        """
+        Args:
+            adapter: The adapter that will be assessed during cross validation.
+            metric_name: The name of the metric to plot.
+            folds: Number of subsamples to partition observations into. Use -1 for
+                leave-one-out cross validation.
+            untransform: Whether to untransform the model predictions before cross
+                validating. Generators are trained on transformed data, and candidate
+                generation is performed in the transformed space. Computing the model
+                quality metric based on the cross-validation results in the
+                untransformed space may not be representative of the model that
+                is actually used for candidate generation in case of non-invertible
+                transforms, e.g., Winsorize or LogY. While the model in the
+                transformed space may not be representative of the original data in
+                regions where outliers have been removed, we have found it to better
+                reflect how good the model used for candidate generation actually
+                is.
+            trial_index: Optional index of the trial whose candidates the model from
+                generation_strategy was used to generate. The plot should therefore
+                only include observations from trials prior to this trial index; if
+                that does not hold, we should error out.
+            experiment: Optional Experiment associated with this analysis. Used to
+                set the priority of the analysis based on the metric importance in
+                the optimization config.
+        """
+        df = _prepare_data(
+            adapter=adapter,
+            metric_name=metric_name,
+            folds=folds,
+            untransform=untransform,
+            trial_index=trial_index,
+        )
+
+        fig = _prepare_plot(df=df)
+        k_folds_substring = f"{folds}-fold" if folds > 0 else "leave-one-out"
 
         # Nudge the priority if the metric is important to the experiment
         if (
             experiment is not None
@@ -129,18 +220,14 @@
 
 
 def _prepare_data(
-    generation_strategy: GenerationStrategy,
+    adapter: Adapter,
     metric_name: str,
     folds: int,
     untransform: bool,
     trial_index: int | None,
 ) -> pd.DataFrame:
-    # If model is not fit already, fit it
-    if generation_strategy.model is None:
-        generation_strategy._fit_current_model(None)
-
     cv_results = cross_validate(
-        model=none_throws(generation_strategy.model),
+        model=adapter,
        folds=folds,
        untransform=untransform,
    )
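For anyone trying the new adhoc path in a notebook: a minimal sketch follows, assuming the analysis class in this file is CrossValidationPlot and that a fitted GenerationStrategy named generation_strategy is in scope (both assumptions, not shown in this diff).

    from ax.analysis.plotly.cross_validation import CrossValidationPlot
    from pyre_extensions import none_throws

    # Reuse the fitted adapter from the generation strategy; any Adapter
    # (e.g. one constructed adhoc around a custom surrogate) would work here.
    adapter = none_throws(generation_strategy.model)

    card = CrossValidationPlot()._compute_adhoc(
        adapter=adapter,
        metric_name="branin",  # hypothetical metric name
        experiment=generation_strategy.experiment,  # only nudges card priority
        folds=-1,  # -1 selects leave-one-out cross validation
        untransform=True,
    )

Note that _compute_adhoc pins trial_index=None, so no observations are filtered by trial; the generation-strategy-driven compute path remains the right entry point when that filtering matters.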