Skip to content

Commit

Permalink
Expose compute_adhoc method for cross validation plot (#3428)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3428

There was a push last half from folks on the platform team to revamp our analysis code, and it is much crispier than before(yay!). As part of this migration effort we want to exclusively use these new plots moving forward

However, previously these plots did not expose a way to call them directly for adhoc computation. I will be adding this for the plots that need it, and starting with CV.

Reviewed By: sdaulton, mpolson64

Differential Revision: D69754828

fbshipit-source-id: bddb206c07a6c641a3e57d0ceccd894710c8ab98
  • Loading branch information
mgarrard authored and facebook-github-bot committed Feb 28, 2025
1 parent 5f11c70 commit 26080b7
Showing 1 changed file with 97 additions and 10 deletions.
107 changes: 97 additions & 10 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.cross_validation import cross_validate
from plotly import express as px, graph_objects as go
from pyre_extensions import none_throws
Expand Down Expand Up @@ -89,18 +90,108 @@ def compute(
metric_name = self.metric_name or select_metric(
experiment=generation_strategy.experiment
)
# If model is not fit already, fit it
if generation_strategy.model is None:
generation_strategy._fit_current_model(None)

df = _prepare_data(
generation_strategy=generation_strategy,
return self._construct_plot(
adapter=none_throws(generation_strategy.model),
metric_name=metric_name,
folds=self.folds,
untransform=self.untransform,
trial_index=self.trial_index,
experiment=experiment,
)
fig = _prepare_plot(df=df)

k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out"
def _compute_adhoc(
self,
adapter: Adapter,
metric_name: str,
experiment: Experiment | None = None,
folds: int = -1,
untransform: bool = True,
) -> PlotlyAnalysisCard:
"""
Helper method to expose adhoc cross validation plotting. This overrides the
default assumption that the adapter from the generation strategy should be
used. Only for advanced users in a notebook setting.
Args:
adapter: The adapter that will be assessed during cross validation.
metric_name: The name of the metric to plot. Must be provided for adhoc
plotting.
experiment: Experiment associated with this analysis. Used to determine
the priority of the analysis based on the metric importance in the
optimization config.
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Generators are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
is actually used for candidate generation in case of non-invertible
transforms, e.g., Winsorize or LogY. While the model in the
transformed space may not be representative of the original data in
regions where outliers have been removed, we have found it to better
reflect the how good the model used for candidate generation actually
is.
"""
return self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# trial_index argument is used with generation strategy since this is an
# adhoc plot call, this will be None.
trial_index=None,
experiment=experiment,
)

def _construct_plot(
self,
adapter: Adapter,
metric_name: str,
folds: int,
untransform: bool,
trial_index: int | None,
experiment: Experiment | None = None,
) -> PlotlyAnalysisCard:
"""
Args:
adapter: The adapter that will be assessed during cross validation.
metric_name: The name of the metric to plot.
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Generators are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
is actually used for candidate generation in case of non-invertible
transforms, e.g., Winsorize or LogY. While the model in the
transformed space may not be representative of the original data in
regions where outliers have been removed, we have found it to better
reflect the how good the model used for candidate generation actually
is.
trial_index: Optional trial index that the model from generation_strategy
was used to generate. We should therefore only have observations from
trials prior to this trial index in our plot. If this is not True, we
should error out.
experiment: Optional Experiment associated with this analysis. Used to set
the priority of the analysis based on the metric importance in the
optimization config.
"""
df = _prepare_data(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
trial_index=trial_index,
)

fig = _prepare_plot(df=df)
k_folds_substring = f"{folds}-fold" if folds > 0 else "leave-one-out"
# Nudge the priority if the metric is important to the experiment
if (
experiment is not None
Expand Down Expand Up @@ -129,18 +220,14 @@ def compute(


def _prepare_data(
generation_strategy: GenerationStrategy,
adapter: Adapter,
metric_name: str,
folds: int,
untransform: bool,
trial_index: int | None,
) -> pd.DataFrame:
# If model is not fit already, fit it
if generation_strategy.model is None:
generation_strategy._fit_current_model(None)

cv_results = cross_validate(
model=none_throws(generation_strategy.model),
model=adapter,
folds=folds,
untransform=untransform,
)
Expand Down

0 comments on commit 26080b7

Please sign in to comment.