Skip to content

Commit

Permalink
Add importance props in PlotImportance so that importance can be calc…
Browse files Browse the repository at this point in the history
…ulated outside
  • Loading branch information
porink0424 committed Apr 10, 2024
1 parent b37fa18 commit 55fe37e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
12 changes: 11 additions & 1 deletion standalone_app/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions standalone_app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"build:vscode": "webpack"
},
"devDependencies": {
"@optuna/types": "../tslib/types",
"@types/plotly.js": "^2.29.2",
"@types/react": "^18.2.64",
"@types/react-dom": "^18.2.21",
Expand Down
68 changes: 67 additions & 1 deletion standalone_app/src/components/StudyDetail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import {
PlotIntermediateValues,
TrialTable,
} from "@optuna/react"
import * as Optuna from "@optuna/types"
import init, { wasm_fanova_calculate } from "optuna"
import React, { FC, useContext, useState, useEffect } from "react"
import { Link, useParams } from "react-router-dom"
import { StorageContext } from "./StorageProvider"
Expand All @@ -43,6 +45,68 @@ export const StudyDetail: FC<{
fetchStudy()
}, [storage, idxNumber])

const [importance, setImportance] = useState<Optuna.ParamImportance[][]>([])
const filterFunc = (trial: Optuna.Trial, objectiveId: number): boolean => {
if (trial.state !== "Complete" && trial.state !== "Pruned") {
return false
}
if (trial.values === undefined) {
return false
}
return (
trial.values.length > objectiveId &&
trial.values[objectiveId] !== Infinity &&
trial.values[objectiveId] !== -Infinity
)
}
// biome-ignore lint/correctness/useExhaustiveDependencies: <explanation>
useEffect(() => {
async function run_wasm() {
if (study === null) {
return
}

await init()

const x: Optuna.ParamImportance[][] = study.directions.map(
(_d, objectiveId) => {
const filteredTrials = study.trials.filter((t) =>
filterFunc(t, objectiveId)
)
if (filteredTrials.length === 0) {
return study.union_search_space.map((s) => {
return {
name: s.name,
importance: 0.5,
}
})
}

const features = study.intersection_search_space.map((s) =>
filteredTrials
.map(
(t) =>
t.params.find((p) => p.name === s.name) as Optuna.TrialParam
)
.map((p) => p.param_internal_value)
)
const values = filteredTrials.map(
(t) => t.values?.[objectiveId] as number
)
// TODO: handle errors thrown by wasm_fanova_calculate
const importance = wasm_fanova_calculate(features, values)
return study.intersection_search_space.map((s, i) => ({
name: s.name,
importance: importance[i],
}))
}
)
setImportance(x)
}

run_wasm()
}, [study])

return (
<>
<AppBar position="static">
Expand Down Expand Up @@ -116,7 +180,9 @@ export const StudyDetail: FC<{
<Grid2 xs={6}>
<Card sx={{ margin: theme.spacing(2) }}>
<CardContent>
{!!study && <PlotImportance study={study} />}
{!!study && (
<PlotImportance study={study} importance={importance} />
)}
</CardContent>
</Card>
</Grid2>
Expand Down

0 comments on commit 55fe37e

Please sign in to comment.