diff --git a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx index 3552f64a..eae2c43e 100644 --- a/optuna_dashboard/ts/components/GraphIntermediateValues.tsx +++ b/optuna_dashboard/ts/components/GraphIntermediateValues.tsx @@ -1,102 +1,22 @@ -import { Box, Card, CardContent, Typography, useTheme } from "@mui/material" -import * as plotly from "plotly.js-dist-min" -import React, { FC, useEffect } from "react" +import { Card, CardContent } from "@mui/material" +import { PlotIntermediateValues } from "@optuna/react" +import React, { FC } from "react" import { Trial } from "ts/types/optuna" -import { usePlotlyColorTheme } from "../state" - -const plotDomId = "graph-intermediate-values" export const GraphIntermediateValues: FC<{ trials: Trial[] includePruned: boolean logScale: boolean }> = ({ trials, includePruned, logScale }) => { - const theme = useTheme() - const colorTheme = usePlotlyColorTheme(theme.palette.mode) - - useEffect(() => { - plotIntermediateValue(trials, colorTheme, false, !includePruned, logScale) - }, [trials, colorTheme, includePruned, logScale]) - return ( - - Intermediate values - - + ) } - -const plotIntermediateValue = ( - trials: Trial[], - colorTheme: Partial, - filterCompleteTrial: boolean, - filterPrunedTrial: boolean, - logScale: boolean -) => { - if (document.getElementById(plotDomId) === null) { - return - } - - const layout: Partial = { - margin: { - l: 50, - t: 0, - r: 50, - b: 0, - }, - yaxis: { - title: "Objective Value", - type: logScale ? "log" : "linear", - }, - xaxis: { - title: "Step", - type: "linear", - }, - uirevision: "true", - template: colorTheme, - legend: { - x: 1.0, - y: 0.95, - }, - } - if (trials.length === 0) { - plotly.react(plotDomId, [], layout) - return - } - - const filteredTrials = trials.filter( - (t) => - (!filterCompleteTrial && t.state === "Complete") || - (!filterPrunedTrial && - t.state === "Pruned" && - t.values && - t.values.length > 0) || - t.state === "Running" - ) - const plotData: Partial[] = filteredTrials.map((trial) => { - const isFeasible = trial.constraints.every((c) => c <= 0) - return { - x: trial.intermediate_values.map((iv) => iv.step), - y: trial.intermediate_values.map((iv) => iv.value), - marker: { maxdisplayed: 10 }, - mode: "lines+markers", - type: "scatter", - name: `trial #${trial.number} ${ - trial.state === "Running" - ? "(running)" - : !isFeasible - ? "(infeasible)" - : "" - }`, - ...(!isFeasible && { line: { color: "#CCCCCC" } }), - } - }) - plotly.react(plotDomId, plotData, layout) -} diff --git a/tslib/react/src/components/PlotIntermediateValues.tsx b/tslib/react/src/components/PlotIntermediateValues.tsx index d91c4506..ea339be7 100644 --- a/tslib/react/src/components/PlotIntermediateValues.tsx +++ b/tslib/react/src/components/PlotIntermediateValues.tsx @@ -64,6 +64,10 @@ const plotIntermediateValue = ( }, uirevision: "true", template: mode === "dark" ? plotlyDarkTemplate : {}, + legend: { + x: 1.0, + y: 0.95, + }, } if (trials.length === 0) { plotly.react(plotDomId, [], layout) @@ -79,23 +83,23 @@ const plotIntermediateValue = ( t.values.length > 0) || t.state === "Running" ) + const plotData: Partial[] = filteredTrials.map((trial) => { - const values = trial.intermediate_values.filter( - (iv) => - iv.value !== Infinity && - iv.value !== -Infinity && - !Number.isNaN(iv.value) - ) + const isFeasible = trial.constraints.every((c) => c <= 0) return { - x: values.map((iv) => iv.step), - y: values.map((iv) => iv.value), + x: trial.intermediate_values.map((iv) => iv.step), + y: trial.intermediate_values.map((iv) => iv.value), marker: { maxdisplayed: 10 }, mode: "lines+markers", type: "scatter", - name: - trial.state !== "Running" - ? `trial #${trial.number}` - : `trial #${trial.number} (running)`, + name: `trial #${trial.number} ${ + trial.state === "Running" + ? "(running)" + : !isFeasible + ? "(infeasible)" + : "" + }`, + ...(!isFeasible && { line: { color: "#CCCCCC" } }), } }) plotly.react(plotDomId, plotData, layout)