Skip to content

Commit

Permalink
Add by_wallclock_time option to ProgressionPlot (#3383)
Browse files Browse the repository at this point in the history
Summary:

Allows a user to specify that they would like to plot with wallclock time on the x axis. Starts a t=0 as the earliest start time

Reviewed By: saitcakmak

Differential Revision: D69800761
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Feb 19, 2025
1 parent 2a0d7a5 commit 1aca64d
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 15 deletions.
123 changes: 110 additions & 13 deletions ax/analysis/plotly/progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from logging import Logger

import numpy as np
import plotly.express as px
from ax.analysis.analysis import AnalysisCardLevel

Expand All @@ -14,7 +17,11 @@
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.common.logger import get_logger
from plotly import graph_objects as go
from pyre_extensions import assert_is_instance

logger: Logger = get_logger(__name__)


class ProgressionPlot(PlotlyAnalysis):
Expand All @@ -28,17 +35,24 @@ class ProgressionPlot(PlotlyAnalysis):
- arm_name: The name of the arm
- METRIC_NAME: The observed mean of the metric specified
- progression: The progression at which the metric was observed
- wallclock_time: The wallclock time at which the metric was observed, in
seconds and starting at 0 from the first trial's start time.
"""

def __init__(self, metric_name: str | None = None) -> None:
def __init__(
self, metric_name: str | None = None, by_wallclock_time: bool = False
) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
will be used. Note that the metric cannot be inferred for
multi-objective or scalarized-objective experiments.
wallclock_time: If True, plot the relative wallclock time instead of the
progression on the x-axis.
"""

self._metric_name = metric_name
self._by_wallclock_time = by_wallclock_time

def compute(
self,
Expand Down Expand Up @@ -74,28 +88,111 @@ def compute(
for trial in experiment.trials_by_status[TrialStatus.EARLY_STOPPED]
]
),
["mean", map_key],
["trial_index", "mean", map_key],
].rename(columns={map_key: "progression", "mean": metric_name})

# Add the wallclock time column
wallclock_series = _calculate_wallclock_timeseries(
experiment=experiment, metric_name=metric_name
)

df["wallclock_time"] = df.apply(
lambda row: wallclock_series[row["trial_index"]][row["progression"]],
axis=1,
)
if len(terminal_points) > 0:
terminal_points["wallclock_time"] = terminal_points.apply(
lambda row: wallclock_series[row["trial_index"]][row["progression"]],
axis=1,
)

# Plot the progression lines with one curve for each arm.
fig = px.line(df, x="progression", y=metric_name, color="arm_name")
if self._by_wallclock_time:
x_axis_name = "wallclock_time"
else:
x_axis_name = "progression"

fig = px.line(df, x=x_axis_name, y=metric_name, color="arm_name")

# Add a marker for each terminal point on early stopped trials.
fig.add_trace(
go.Scatter(
x=terminal_points["progression"],
y=terminal_points[metric_name],
mode="markers",
showlegend=False,
line_color="red",
hoverinfo="none",
if len(terminal_points) > 0:
fig.add_trace(
go.Scatter(
x=terminal_points[x_axis_name],
y=terminal_points[metric_name],
mode="markers",
showlegend=False,
line_color="red",
hoverinfo="none",
)
)
)

return self._create_plotly_analysis_card(
title=f"{metric_name} by progression",
title=f"{metric_name} by {x_axis_name.replace('_', ' ')}",
subtitle="Observe how the metric changes as each trial progresses",
level=AnalysisCardLevel.MID,
df=df,
fig=fig,
)


def _calculate_wallclock_timeseries(
experiment: Experiment,
metric_name: str,
) -> dict[int, dict[float, float]]:
"""
Calculate a mapping from each trial index and progression to the time since the
first trial started, in seconds. Assume that the first trial started at t=0, and
that progressions are linearly spaced between the start and completion times of
each trial.
If a trial does not have either a start or completion time the wallclock time
cannot be calculated and the value will be nan (which will not be plotted).
Returns:
trial_index => (progression => timestamp)
"""
# Find the earliest start time.
start_time = min(
trial.time_run_started.timestamp()
for trial in experiment.trials.values()
if trial.time_run_started is not None
)
# Calculate all start and completion times relative to the earliest start time.
# Give nan for trials that don't have a start or completion time.
relative_timestamps = {
idx: (
trial.time_run_started.timestamp() - start_time
if trial.time_run_started is not None
else np.nan,
trial.time_completed.timestamp() - start_time
if trial.time_completed is not None
else np.nan,
)
for idx, trial in experiment.trials.items()
}

data = assert_is_instance(experiment.lookup_data(), MapData)
df = data.map_df[data.map_df["metric_name"] == metric_name]
map_key = data.map_key_infos[0].key

return {
trial_index: dict(
zip(
df[df["trial_index"] == trial_index][map_key].to_numpy(),
# Map the progressions to linspace if the start and completion times
# are both available, otherwise map to nans
np.linspace(
relative_timestamps[trial_index][0],
relative_timestamps[trial_index][1],
len(df[df["trial_index"] == trial_index]),
)
if (
relative_timestamps[trial_index][0] is not None
and relative_timestamps[trial_index][1] is not None
)
else np.full(len(df[df["trial_index"] == trial_index]), np.nan),
)
)
for trial_index in experiment.trials.keys()
}
25 changes: 23 additions & 2 deletions ax/analysis/plotly/tests/test_progression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

# pyre-strict

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.progression import ProgressionPlot
from ax.analysis.plotly.progression import (
_calculate_wallclock_timeseries,
ProgressionPlot,
)
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_test_map_data_experiment
Expand All @@ -33,8 +37,25 @@ def test_compute(self) -> None:
)
self.assertEqual(card.level, AnalysisCardLevel.MID)
self.assertEqual(
{*card.df.columns}, {"trial_index", "arm_name", "branin_map", "progression"}
{*card.df.columns},
{"trial_index", "arm_name", "branin_map", "progression", "wallclock_time"},
)

self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")

def test_calculate_wallclock_timeseries(self) -> None:
experiment = get_test_map_data_experiment(
num_trials=2, num_fetches=5, num_complete=2
)
wallclock_timeseries = _calculate_wallclock_timeseries(
experiment=experiment, metric_name="branin_map"
)

self.assertEqual(len(wallclock_timeseries), 2)
self.assertTrue(
all(len(timeseries) == 5 for timeseries in wallclock_timeseries.values())
)

for timeseries in wallclock_timeseries.values():
self.assertTrue(pd.Series(timeseries).is_monotonic_increasing)

0 comments on commit 1aca64d

Please sign in to comment.