remove unnecessary data validation for trial index #3433

Closed · wants to merge 5 commits

Changes from all commits
212 changes: 170 additions & 42 deletions ax/analysis/plotly/cross_validation.py
@@ -14,8 +14,9 @@
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.cross_validation import cross_validate
from plotly import express as px, graph_objects as go
from plotly import graph_objects as go
from pyre_extensions import none_throws


@@ -68,9 +69,8 @@ def __init__(
reflect how good the model used for candidate generation actually
is.
trial_index: Optional trial index that the model from generation_strategy
was used to generate. We should therefore only have observations from
trials prior to this trial index in our plot. If this is not True, we
should error out.
was used to generate. Useful card attribute for filtering to a specific
trial.
"""

self.metric_name = metric_name
@@ -89,18 +89,125 @@ def compute(
metric_name = self.metric_name or select_metric(
experiment=generation_strategy.experiment
)
# If model is not fit already, fit it
if generation_strategy.model is None:
generation_strategy._fit_current_model(None)

df = _prepare_data(
generation_strategy=generation_strategy,
return self._construct_plot(
adapter=none_throws(generation_strategy.model),
metric_name=metric_name,
folds=self.folds,
untransform=self.untransform,
trial_index=self.trial_index,
experiment=experiment,
)
fig = _prepare_plot(df=df)

k_folds_substring = f"{self.folds}-fold" if self.folds > 0 else "leave-one-out"
def _compute_adhoc(
self,
adapter: Adapter,
metric_names: list[str],
experiment: Experiment | None = None,
folds: int = -1,
untransform: bool = True,
metric_name_mapping: dict[str, str] | None = None,
) -> list[PlotlyAnalysisCard]:
"""
Helper method to expose adhoc cross validation plotting. This overrides the
default assumption that the adapter from the generation strategy should be
used. Only for advanced users in a notebook setting.

Args:
adapter: The adapter that will be assessed during cross validation.
metric_names: A list of all the metrics to perform cross validation on.
Must be provided for adhoc plotting.
experiment: Experiment associated with this analysis. Used to determine
the priority of the analysis based on the metric importance in the
optimization config.
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Generators are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
is actually used for candidate generation in case of non-invertible
transforms, e.g., Winsorize or LogY. While the model in the
transformed space may not be representative of the original data in
regions where outliers have been removed, we have found it to better
reflect how good the model used for candidate generation actually
is.
metric_name_mapping: Optional mapping from default metric names to more
readable metric names.
"""
plots = []
for metric_name in metric_names:
# Replace the metric name with a human-readable name if a mapping is provided
refined_metric_name = (
metric_name_mapping.get(metric_name, metric_name)
if metric_name_mapping
else metric_name
)
plots.append(
self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# The trial_index argument is only used with a generation strategy;
# since this is an adhoc plot call, it is None.
trial_index=None,
experiment=experiment,
refined_metric_name=refined_metric_name,
)
)
return plots

def _construct_plot(
self,
adapter: Adapter,
metric_name: str,
folds: int,
untransform: bool,
trial_index: int | None,
experiment: Experiment | None = None,
refined_metric_name: str | None = None,
) -> PlotlyAnalysisCard:
"""
Args:
adapter: The adapter that will be assessed during cross validation.
metric_name: The name of the metric to plot.
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Generators are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
is actually used for candidate generation in case of non-invertible
transforms, e.g., Winsorize or LogY. While the model in the
transformed space may not be representative of the original data in
regions where outliers have been removed, we have found it to better
reflect how good the model used for candidate generation actually
is.
trial_index: Optional trial index that the model from generation_strategy
was used to generate. Useful card attribute for filtering to a specific
trial.
experiment: Optional Experiment associated with this analysis. Used to set
the priority of the analysis based on the metric importance in the
optimization config.
refined_metric_name: Optional human-readable metric name to use in place
of the default metric name, e.g. in the card title.
"""
df = _prepare_data(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
trial_index=trial_index,
)

fig = _prepare_plot(df=df)
k_folds_substring = f"{folds}-fold" if folds > 0 else "leave-one-out"
# Nudge the priority if the metric is important to the experiment
if (
experiment is not None
@@ -118,8 +225,11 @@ def compute(
else:
nudge = 0

# If a human-readable metric name is provided, use it in the title
metric_title = refined_metric_name if refined_metric_name else metric_name

return self._create_plotly_analysis_card(
title=f"Cross Validation for {metric_name}",
title=f"Cross Validation for {metric_title}",
subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
@@ -129,34 +239,20 @@ def _prepare_data(


def _prepare_data(
generation_strategy: GenerationStrategy,
adapter: Adapter,
metric_name: str,
folds: int,
untransform: bool,
trial_index: int | None,
) -> pd.DataFrame:
# If model is not fit already, fit it
if generation_strategy.model is None:
generation_strategy._fit_current_model(None)

cv_results = cross_validate(
model=none_throws(generation_strategy.model),
model=adapter,
folds=folds,
untransform=untransform,
)

records = []
for observed, predicted in cv_results:
if trial_index is not None:
if (
observed.features.trial_index is not None
and observed.features.trial_index >= trial_index
):
raise UserInputError(
"CrossValidationPlot was specified to be for the generation of "
f"trial {trial_index}, but has observations from trial "
f"{observed.features.trial_index}."
)
# Find the index of the metric in observed and predicted
observed_i = next(
(
@@ -184,34 +280,63 @@ def _prepare_data(
return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame) -> go.Figure:
fig = px.scatter(
df,
x="observed",
y="predicted",
error_x="observed_sem",
error_y="predicted_sem",
hover_data=["arm_name", "observed", "predicted"],
def _prepare_plot(
df: pd.DataFrame,
) -> go.Figure:
# Create a scatter plot using Plotly Graph Objects for more control
fig = go.Figure()
# Add scatter trace with error bars
fig.add_trace(
go.Scatter(
x=df["observed"],
y=df["predicted"],
mode="markers",
marker={
"color": "rgba(0, 0, 255, 0.3)", # partially transparent blue
},
error_x={
"type": "data",
"array": df["observed_sem"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
error_y={
"type": "data",
"array": df["predicted_sem"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
text=df["arm_name"],
hovertemplate=(
# "<b>Details</b><br>"
"<b>Arm Name: %{text}</b><br>"
+ "Predicted: %{y}<br>"
+ "Observed: %{x}<br>"
+ "<extra></extra>" # Removes the trace name from the hover
),
hoverlabel={
"bgcolor": "rgba(0, 0, 255, 0.2)", # partially transparent blue
"font": {"color": "black"},
},
)
)

# Add a gray dashed line at y=x starting and ending just outside of the region of
# interest for reference. A well fit model should have points clustered around this
# line.
# interest for reference. A well fit model should have points clustered around
# this line.
lower_bound = (
min(
(df["observed"] - df["observed_sem"].fillna(0)).min(),
(df["predicted"] - df["predicted_sem"].fillna(0)).min(),
)
* 0.99
* 0.999 # tight autozoom
)
upper_bound = (
max(
(df["observed"] + df["observed_sem"].fillna(0)).max(),
(df["predicted"] + df["predicted_sem"].fillna(0)).max(),
)
* 1.01
* 1.001 # tight autozoom
)

fig.add_shape(
type="line",
x0=lower_bound,
@@ -221,11 +346,14 @@ def _prepare_plot(df: pd.DataFrame) -> go.Figure:
line={"color": "gray", "dash": "dot"},
)

# Force plot to display as a square
fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain")
# Update axes with tight autozoom that remains square
fig.update_xaxes(
range=[lower_bound, upper_bound], constrain="domain", title="Actual Outcome"
)
fig.update_yaxes(
range=[lower_bound, upper_bound],
scaleanchor="x",
scaleratio=1,
title="Predicted Outcome",
)

return fig
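
For readers less familiar with the Graph Objects API that replaces `px.scatter` above, here is a minimal, self-contained sketch of the pattern the new `_prepare_plot` uses: a scatter trace with data error bars, a dashed y=x reference line, and square, matched axes. The dataframe values are hypothetical and not part of this PR.

```python
import pandas as pd
from plotly import graph_objects as go

# Hypothetical cross-validation results, shaped like the df built by _prepare_data.
df = pd.DataFrame(
    {
        "arm_name": ["0_0", "0_1", "1_0"],
        "observed": [0.9, 1.4, 2.1],
        "observed_sem": [0.05, 0.04, 0.06],
        "predicted": [1.0, 1.3, 2.0],
        "predicted_sem": [0.08, 0.07, 0.05],
    }
)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df["observed"],
        y=df["predicted"],
        mode="markers",
        error_x={"type": "data", "array": df["observed_sem"], "visible": True},
        error_y={"type": "data", "array": df["predicted_sem"], "visible": True},
        text=df["arm_name"],
        hovertemplate="Arm: %{text}<br>Observed: %{x}<br>Predicted: %{y}<extra></extra>",
    )
)

# Dashed y=x reference line; a well-calibrated model clusters points around it.
lower = min(df["observed"].min(), df["predicted"].min()) * 0.999
upper = max(df["observed"].max(), df["predicted"].max()) * 1.001
fig.add_shape(
    type="line", x0=lower, y0=lower, x1=upper, y1=upper, line={"color": "gray", "dash": "dot"}
)

# Matching ranges plus scaleanchor keep the plot square.
fig.update_xaxes(range=[lower, upper], constrain="domain", title="Actual Outcome")
fig.update_yaxes(range=[lower, upper], scaleanchor="x", scaleratio=1, title="Predicted Outcome")
fig.show()
```
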
26 changes: 18 additions & 8 deletions ax/analysis/plotly/tests/test_cross_validation.py
@@ -9,6 +9,7 @@
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Generators
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import mock_botorch_optimize
@@ -77,14 +78,6 @@ def test_compute(self) -> None:
card.df["arm_name"].unique(),
)

def test_it_can_only_contain_observation_prior_to_the_trial_index(self) -> None:
analysis = CrossValidationPlot(metric_name="bar", trial_index=7)
with self.assertRaisesRegex(
UserInputError,
"CrossValidationPlot was specified to be for the generation of trial 7",
):
analysis.compute(generation_strategy=self.client.generation_strategy)

def test_it_can_specify_trial_index_correctly(self) -> None:
analysis = CrossValidationPlot(metric_name="bar", trial_index=9)
card = analysis.compute(generation_strategy=self.client.generation_strategy)
@@ -98,3 +91,20 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
arm_name,
card.df["arm_name"].unique(),
)

@mock_botorch_optimize
def test_compute_adhoc(self) -> None:
metrics = ["bar"]
metric_mapping = {"bar": "spunky"}
data = self.client.experiment.lookup_data()
adapter = Generators.BOTORCH_MODULAR(
experiment=self.client.experiment, data=data
)
analysis = CrossValidationPlot()._compute_adhoc(
adapter=adapter, metric_names=metrics, metric_name_mapping=metric_mapping
)
self.assertEqual(len(analysis), 1)
card = analysis[0]
self.assertEqual(card.name, "CrossValidationPlot")
# Validate that the metric name replacement occurred
self.assertEqual(card.title, "Cross Validation for spunky")
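
For context, the new adhoc entry point is exercised roughly like the test above. The sketch below assumes an existing `client` (an `AxClient` whose experiment already has attached data); only identifiers that appear in this diff are used, and the final check mirrors `test_compute_adhoc`.

```python
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.modelbridge.registry import Generators

# `client` is assumed to be a pre-existing AxClient with completed trials.
data = client.experiment.lookup_data()
adapter = Generators.BOTORCH_MODULAR(experiment=client.experiment, data=data)

# Adhoc plotting bypasses the generation strategy and uses the fitted adapter
# directly; intended for advanced, notebook-style use.
cards = CrossValidationPlot()._compute_adhoc(
    adapter=adapter,
    metric_names=["bar"],
    metric_name_mapping={"bar": "spunky"},  # optional human-readable names
)

card = cards[0]
print(card.title)  # "Cross Validation for spunky"
```
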