Skip to content

Commit

Permalink
Add by_wallclock_time option to ProgressionPlot (#3383)
Browse files Browse the repository at this point in the history
Summary:

Allows a user to specify that they would like to plot with wallclock time on the x axis. Starts a t=0 as the earliest start time

Differential Revision: D69800761
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Feb 18, 2025
1 parent ec07e1c commit dbb9660
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 14 deletions.
130 changes: 117 additions & 13 deletions ax/analysis/plotly/progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from logging import Logger

import numpy as np
import plotly.express as px
from ax.analysis.analysis import AnalysisCardLevel

Expand All @@ -14,7 +17,11 @@
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.common.logger import get_logger
from plotly import graph_objects as go
from pyre_extensions import assert_is_instance, none_throws

logger: Logger = get_logger(__name__)


class ProgressionPlot(PlotlyAnalysis):
Expand All @@ -28,17 +35,24 @@ class ProgressionPlot(PlotlyAnalysis):
- arm_name: The name of the arm
- METRIC_NAME: The observed mean of the metric specified
- progression: The progression at which the metric was observed
- wallclock_time: The wallclock time at which the metric was observed, in
seconds and starting at 0 from the first trial's start time.
"""

def __init__(self, metric_name: str | None = None) -> None:
def __init__(
self, metric_name: str | None = None, by_wallclock_time: bool = False
) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
will be used. Note that the metric cannot be inferred for
multi-objective or scalarized-objective experiments.
wallclock_time: If True, plot the relative wallclock time instead of the
progression on the x-axis.
"""

self._metric_name = metric_name
self._by_wallclock_time = by_wallclock_time

def compute(
self,
Expand Down Expand Up @@ -74,28 +88,118 @@ def compute(
for trial in experiment.trials_by_status[TrialStatus.EARLY_STOPPED]
]
),
["mean", map_key],
["trial_index", "mean", map_key],
].rename(columns={map_key: "progression", "mean": metric_name})

# Add the wallclock time column
try:
wallclock_series = _calculate_wallclock_timeseries(experiment=experiment)

df["wallclock_time"] = df.apply(
lambda row: wallclock_series[row["trial_index"]][row["progression"]],
axis=1,
)
if len(terminal_points) > 0:
terminal_points["wallclock_time"] = terminal_points.apply(
lambda row: wallclock_series[row["trial_index"]][
row["progression"]
],
axis=1,
)

# This can happen if a trial's start time or completion time is None
except AssertionError:
df["wallclock_time"] = np.nan
if len(terminal_points) > 0:
terminal_points["wallclock_time"] = np.nan

# Plot the progression lines with one curve for each arm.
fig = px.line(df, x="progression", y=metric_name, color="arm_name")
if self._by_wallclock_time:
if df["wallclock_time"].isnull().any():
logger.warning(
"Some trials do not have a start time or completion time. Falling "
"back to plotting progression on the x-axis."
)
x_axis_name = "progression"
else:
x_axis_name = "wallclock_time"
else:
x_axis_name = "progression"

fig = px.line(df, x=x_axis_name, y=metric_name, color="arm_name")

# Add a marker for each terminal point on early stopped trials.
fig.add_trace(
go.Scatter(
x=terminal_points["progression"],
y=terminal_points[metric_name],
mode="markers",
showlegend=False,
line_color="red",
hoverinfo="none",
if len(terminal_points) > 0:
fig.add_trace(
go.Scatter(
x=terminal_points[x_axis_name],
y=terminal_points[metric_name],
mode="markers",
showlegend=False,
line_color="red",
hoverinfo="none",
)
)
)

return self._create_plotly_analysis_card(
title=f"{metric_name} by progression",
title=f"{metric_name} by {x_axis_name.replace('_', ' ')}",
subtitle="Observe how the metric changes as each trial progresses",
level=AnalysisCardLevel.MID,
df=df,
fig=fig,
)


def _calculate_wallclock_timeseries(
experiment: Experiment,
) -> dict[int, dict[float, float]]:
"""
Calculate a mapping from each trial index and progression to the time since the
first trial started, in seconds. Assume that the first trial started at t=0, and
that progressions are linearly spaced between the start and completion times of
each trial.
Returns:
trial_index => (progression => timestamp)
"""

# Collect the start and completion times of each trial.
# trial_index => (time_run_started, time_completed)
timestamps = {
idx: (
none_throws(trial.time_run_started).timestamp(),
none_throws(trial.time_completed).timestamp(),
)
for idx, trial in experiment.trials.items()
}

# Find the earliest start time.
start_time = min(
time_run_started for time_run_started, _time_completed in timestamps.values()
)
# Calculate all start and completion times relative to the earliest start time.
relative_timestamps = {
idx: (
none_throws(trial.time_run_started).timestamp() - start_time,
none_throws(trial.time_completed).timestamp() - start_time,
)
for idx, trial in experiment.trials.items()
}

data = assert_is_instance(experiment.lookup_data(), MapData)
df = data.map_df
map_key = data.map_key_infos[0].key

return {
trial_index: dict(
zip(
df[df["trial_index"] == trial_index][map_key].to_numpy(),
np.linspace(
relative_timestamps[trial_index][0],
relative_timestamps[trial_index][1],
len(df[df["trial_index"] == trial_index]),
),
)
)
for trial_index in experiment.trials.keys()
}
3 changes: 2 additions & 1 deletion ax/analysis/plotly/tests/test_progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def test_compute(self) -> None:
)
self.assertEqual(card.level, AnalysisCardLevel.MID)
self.assertEqual(
{*card.df.columns}, {"trial_index", "arm_name", "branin_map", "progression"}
{*card.df.columns},
{"trial_index", "arm_name", "branin_map", "progression", "wallclock_time"},
)

self.assertIsNotNone(card.blob)
Expand Down

0 comments on commit dbb9660

Please sign in to comment.