Skip to content

Commit

Permalink
Improve readability of CV plot (#3426)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3426

In the previous diff we exposed an adhoc comput method for CV, but the delta between the previous plot and our new plot is a degragation in UX. This diff fixes that by:
- tightening the autozoom
- making the points more transparent so they are more visible individually
- improving the hover
- adding x and y axis titles

Thanks for pointing some of these out Sam!

Differential Revision: D70195847
  • Loading branch information
mgarrard authored and facebook-github-bot committed Feb 27, 2025
1 parent 51e0aa7 commit fdee121
Showing 1 changed file with 47 additions and 17 deletions.
64 changes: 47 additions & 17 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.cross_validation import cross_validate
from plotly import express as px, graph_objects as go
from plotly import graph_objects as go
from pyre_extensions import none_throws


Expand Down Expand Up @@ -272,33 +272,60 @@ def _prepare_data(


def _prepare_plot(df: pd.DataFrame) -> go.Figure:
fig = px.scatter(
df,
x="observed",
y="predicted",
error_x="observed_sem",
error_y="predicted_sem",
hover_data=["arm_name", "observed", "predicted"],
# Create a scatter plot using Plotly Graph Objects for more control
fig = go.Figure()
# Add scatter trace with error bars
fig.add_trace(
go.Scatter(
x=df["observed"],
y=df["predicted"],
mode="markers",
marker={
"color": "rgba(0, 0, 255, 0.3)", # partially transparent blue
},
error_x={
"type": "data",
"array": df["observed_sem"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
error_y={
"type": "data",
"array": df["predicted_sem"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
text=df["arm_name"],
hovertemplate=(
# "<b>Details</b><br>"
"<b>Arm Name: %{text}</b><br>"
+ "Predicted: %{y}<br>"
+ "Observed: %{x}<br>"
+ "<extra></extra>" # Removes the trace name from the hover
),
hoverlabel={
"bgcolor": "rgba(0, 0, 255, 0.2)", # partially transparent blue
"font": {"color": "black"},
},
)
)

# Add a gray dashed line at y=x starting and ending just outside of the region of
# interest for reference. A well fit model should have points clustered around this
# line.
# interest for reference. A well fit model should have points clustered around
# this line.
lower_bound = (
min(
(df["observed"] - df["observed_sem"].fillna(0)).min(),
(df["predicted"] - df["predicted_sem"].fillna(0)).min(),
)
* 0.99
* 0.999 # tight autozoom
)
upper_bound = (
max(
(df["observed"] + df["observed_sem"].fillna(0)).max(),
(df["predicted"] + df["predicted_sem"].fillna(0)).max(),
)
* 1.01
* 1.001 # tight autozoom
)

fig.add_shape(
type="line",
x0=lower_bound,
Expand All @@ -308,11 +335,14 @@ def _prepare_plot(df: pd.DataFrame) -> go.Figure:
line={"color": "gray", "dash": "dot"},
)

# Force plot to display as a square
fig.update_xaxes(range=[lower_bound, upper_bound], constrain="domain")
# Update axes with tight autozoom that remains square
fig.update_xaxes(
range=[lower_bound, upper_bound], constrain="domain", title="Actual Outcome"
)
fig.update_yaxes(
range=[lower_bound, upper_bound],
scaleanchor="x",
scaleratio=1,
title="Predicted Outcome",
)

return fig

0 comments on commit fdee121

Please sign in to comment.