-
-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #600 from c-bata/plotly-user-defined-graph-objects
Support user-defined plotly figures
- Loading branch information
Showing
12 changed files
with
274 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from __future__ import annotations | ||
|
||
import math | ||
from typing import TYPE_CHECKING | ||
import uuid | ||
|
||
from optuna import Study | ||
|
||
|
||
if TYPE_CHECKING: | ||
from typing import Any | ||
|
||
from optuna.storages import BaseStorage | ||
import plotly.graph_objs as go | ||
|
||
|
||
SYSTEM_ATTR_PLOT_DATA = "dashboard:plot_data:" | ||
SYSTEM_ATTR_MAX_LENGTH = 2045 | ||
|
||
|
||
def save_plotly_graph_object( | ||
study: Study, figure: go.Figure, *, graph_object_id: str | None = None | ||
) -> str: | ||
"""Save the user-defined plotly's graph object to the study. | ||
Example: | ||
.. code-block:: python | ||
import optuna | ||
from optuna_dashboard import save_plotly_graph_object | ||
def objective(trial): | ||
x = trial.suggest_float("x", -100, 100) | ||
y = trial.suggest_categorical("y", [-1, 0, 1]) | ||
return x**2 + y | ||
study = optuna.create_study() | ||
study.optimize(objective, n_trials=100) | ||
figure = optuna.visualization.plot_optimization_history(study) | ||
save_plotly_graph_object(study, figure) | ||
Args: | ||
study: | ||
Target study object. | ||
plot_data: | ||
The plotly's graph object to save. | ||
graph_object_id: | ||
Unique identifier of the graph object. If specified, the graph object is overwritten. | ||
This must be a valid HTML id attribute value. | ||
Returns: | ||
The graph object ID. | ||
""" | ||
if graph_object_id is not None and not is_valid_graph_object_id(graph_object_id): | ||
raise ValueError("graph_object_id must be a valid HTML id attribute value.") | ||
|
||
storage = study._storage | ||
study_id = study._study_id | ||
|
||
graph_object_id = graph_object_id or str(uuid.uuid4()) | ||
key = SYSTEM_ATTR_PLOT_DATA + graph_object_id + ":" | ||
plot_data_json_str = figure.to_json() | ||
save_graph_object_json(storage, study_id, key, plot_data_json_str) | ||
return graph_object_id | ||
|
||
|
||
def save_graph_object_json( | ||
storage: BaseStorage, study_id: int, key_prefix: str, plot_data_json_str: str | ||
) -> None: | ||
plot_data_system_attrs = split_plot_data(plot_data_json_str, key_prefix) | ||
for k, v in plot_data_system_attrs.items(): | ||
storage.set_study_system_attr(study_id, k, v) | ||
|
||
# Clear previous graph object attributes | ||
study_system_attrs = storage.get_study_system_attrs(study_id) | ||
all_plot_data_system_attrs = [k for k in study_system_attrs if k.startswith(key_prefix)] | ||
if len(all_plot_data_system_attrs) > len(plot_data_system_attrs): | ||
for i in range(len(plot_data_system_attrs), len(all_plot_data_system_attrs)): | ||
storage.set_study_system_attr(study_id, f"{key_prefix}{i}", "") | ||
|
||
|
||
def list_graph_object_ids(system_attrs: dict[str, Any]) -> list[str]: | ||
titles = set() | ||
for key in system_attrs: | ||
if not key.startswith(SYSTEM_ATTR_PLOT_DATA): | ||
continue | ||
|
||
s = key.split(":", maxsplit=2) # e.g. ["dashboard", "plot_data", "Optimization History:1"] | ||
if len(s) != 3: | ||
continue | ||
# Please note that title may contain ":". | ||
title = s[2].rsplit(":", maxsplit=1)[0] | ||
titles.add(title) | ||
return list(titles) | ||
|
||
|
||
def get_plotly_graph_objects(system_attrs: dict[str, Any]) -> dict[str, str]: | ||
graph_objects = {} | ||
for title in list_graph_object_ids(system_attrs): | ||
key_prefix = SYSTEM_ATTR_PLOT_DATA + title + ":" | ||
plot_data_attrs = {k: v for k, v in system_attrs.items() if k.startswith(key_prefix)} | ||
graph_objects[title] = concat_plot_data(plot_data_attrs, key_prefix) | ||
return graph_objects | ||
|
||
|
||
def split_plot_data(plot_data_str: str, key_prefix: str) -> dict[str, str]: | ||
plot_data_len = len(plot_data_str) | ||
attrs = {} | ||
for i in range(math.ceil(plot_data_len / SYSTEM_ATTR_MAX_LENGTH)): | ||
start = i * SYSTEM_ATTR_MAX_LENGTH | ||
end = min((i + 1) * SYSTEM_ATTR_MAX_LENGTH, plot_data_len) | ||
attrs[f"{key_prefix}{i}"] = plot_data_str[start:end] | ||
return attrs | ||
|
||
|
||
def concat_plot_data(plot_data_attrs: dict[str, str], key_prefix: str) -> str: | ||
return "".join(plot_data_attrs[f"{key_prefix}{i}"] for i in range(len(plot_data_attrs))) | ||
|
||
|
||
def is_valid_graph_object_id(graph_object_id: str) -> bool: | ||
if len(graph_object_id) == 0: | ||
return False | ||
|
||
# Can only contain letters [A-Za-z], numbers [0-9], hyphens ("-"), underscores ("_"), | ||
# colons, and periods. | ||
if not all( | ||
"a" <= c <= "z" or "A" <= c <= "Z" or "0" <= c <= "9" or c in ("-", "_", ":", ".") | ||
for c in graph_object_id[1:] | ||
): | ||
return False | ||
# Unlike HTML id attribute, graph object id can begin with a letter [A-Za-z] | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import * as plotly from "plotly.js-dist-min" | ||
import React, { FC, useEffect } from "react" | ||
import { Box } from "@mui/material" | ||
|
||
export const UserDefinedPlot: FC<{ | ||
graphObject: PlotlyGraphObject | ||
}> = ({ graphObject }) => { | ||
const plotDomId = `user-defined-plot:${graphObject.id}` | ||
|
||
useEffect(() => { | ||
try { | ||
const parsed = JSON.parse(graphObject.graph_object) | ||
plotly.react(plotDomId, parsed.data, parsed.layout) | ||
} catch (e) { | ||
// Avoid to crash the whole page when given invalid grpah objects. | ||
console.error(e) | ||
} | ||
}, [graphObject]) | ||
|
||
return <Box id={plotDomId} sx={{ height: "450px" }} /> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ docs = [ | |
|
||
test = [ | ||
"coverage", | ||
"plotly", | ||
"pytest", | ||
"moto[s3]", | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from __future__ import annotations | ||
|
||
import optuna | ||
from optuna_dashboard import _custom_plot_data as custom_plot_data | ||
from optuna_dashboard import save_plotly_graph_object | ||
import pytest | ||
|
||
|
||
def get_dummy_study() -> optuna.Study: | ||
def objective(trial: optuna.Trial) -> float: | ||
x = trial.suggest_float("x", -100, 100) | ||
y = trial.suggest_categorical("y", [-1, 0, 1]) | ||
return x**2 + y | ||
|
||
study = optuna.create_study() | ||
optuna.logging.set_verbosity(optuna.logging.ERROR) | ||
study.optimize(objective, n_trials=100) | ||
return study | ||
|
||
|
||
def test_save_plotly_graph_object() -> None: | ||
# Save history plot | ||
dummy_study = get_dummy_study() | ||
plot_data = optuna.visualization.plot_optimization_history(dummy_study) | ||
graph_object_id = save_plotly_graph_object(dummy_study, plot_data) | ||
|
||
study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id) | ||
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs) | ||
assert len(plot_data_dict) == 1 | ||
assert plot_data_dict[graph_object_id] == plot_data.to_json() | ||
|
||
# Save parallel coordinate plot | ||
plot_data = optuna.visualization.plot_parallel_coordinate(dummy_study) | ||
graph_object_id = save_plotly_graph_object(dummy_study, plot_data) | ||
|
||
study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id) | ||
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs) | ||
assert len(plot_data_dict) == 2 | ||
assert plot_data_dict[graph_object_id] == plot_data.to_json() | ||
|
||
|
||
def test_update_plotly_graph_object() -> None: | ||
# Save history plot | ||
dummy_study = get_dummy_study() | ||
plot_data = optuna.visualization.plot_optimization_history(dummy_study) | ||
graph_object_id = save_plotly_graph_object(dummy_study, plot_data) | ||
|
||
study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id) | ||
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs) | ||
assert len(plot_data_dict) == 1 | ||
assert plot_data_dict[graph_object_id] == plot_data.to_json() | ||
|
||
# Save parallel coordinate plot | ||
plot_data = optuna.visualization.plot_parallel_coordinate(dummy_study) | ||
graph_object_id = save_plotly_graph_object( | ||
dummy_study, plot_data, graph_object_id=graph_object_id | ||
) | ||
|
||
study_system_attrs = dummy_study._storage.get_study_system_attrs(dummy_study._study_id) | ||
plot_data_dict = custom_plot_data.get_plotly_graph_objects(study_system_attrs) | ||
assert len(plot_data_dict) == 1 | ||
assert plot_data_dict[graph_object_id] == plot_data.to_json() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"name", | ||
[ | ||
"0", | ||
"a", | ||
"a1-:_.", | ||
], | ||
) | ||
def test_is_valid_graph_object_id(name: str) -> None: | ||
assert custom_plot_data.is_valid_graph_object_id(name) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"name", | ||
[ | ||
"a,", | ||
"a b", | ||
"aあいうえお", | ||
], | ||
) | ||
def test_is_invalid_graph_object_id(name: str) -> None: | ||
assert not custom_plot_data.is_valid_graph_object_id(name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters