Skip to content

Commit

Permalink
Merge pull request #892 from porink0424/fix/move-searchSpace-to-tslib
Browse files Browse the repository at this point in the history
Add `distribution` prop to `SearchSpaceItem` type and move `searchSpace` to `tslib/react`
  • Loading branch information
c-bata authored Jun 26, 2024
2 parents d9a62fa + 29c4e73 commit 0c733bc
Show file tree
Hide file tree
Showing 12 changed files with 8,045 additions and 7,942 deletions.
7 changes: 5 additions & 2 deletions optuna_dashboard/ts/components/GraphContour.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ import {
useTheme,
} from "@mui/material"
import blue from "@mui/material/colors/blue"
import { GraphContainer, useGraphComponentState } from "@optuna/react"
import {
GraphContainer,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect, useMemo, useState } from "react"
import { SearchSpaceItem, StudyDetail, Trial } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { getAxisInfo } from "../graphUtil"
import { usePlot } from "../hooks/usePlot"
import { useMergedUnionSearchSpace } from "../searchSpace"
import { usePlotlyColorTheme } from "../state"
import { useBackendRender } from "../state"

Expand Down
7 changes: 5 additions & 2 deletions optuna_dashboard/ts/components/GraphParallelCoordinate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import {
Typography,
useTheme,
} from "@mui/material"
import { GraphContainer, useGraphComponentState } from "@optuna/react"
import {
GraphContainer,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import {
Target,
useFilteredTrials,
Expand All @@ -19,7 +23,6 @@ import React, { FC, ReactNode, useEffect, useState } from "react"
import { SearchSpaceItem, StudyDetail } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { usePlot } from "../hooks/usePlot"
import { useMergedUnionSearchSpace } from "../searchSpace"
import { usePlotlyColorTheme } from "../state"
import { useBackendRender } from "../state"

Expand Down
7 changes: 5 additions & 2 deletions optuna_dashboard/ts/components/GraphRank.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ import {
Typography,
useTheme,
} from "@mui/material"
import { GraphContainer, useGraphComponentState } from "@optuna/react"
import {
GraphContainer,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect, useState } from "react"
import { SearchSpaceItem, StudyDetail, Trial } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { getAxisInfo, makeHovertext } from "../graphUtil"
import { usePlot } from "../hooks/usePlot"
import { useMergedUnionSearchSpace } from "../searchSpace"
import { useBackendRender, usePlotlyColorTheme } from "../state"

const plotDomId = "graph-rank"
Expand Down
7 changes: 5 additions & 2 deletions optuna_dashboard/ts/components/GraphSlice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import {
Typography,
useTheme,
} from "@mui/material"
import { GraphContainer, useGraphComponentState } from "@optuna/react"
import {
GraphContainer,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import {
Target,
useFilteredTrials,
Expand All @@ -22,7 +26,6 @@ import React, { FC, useEffect, useState } from "react"
import { SearchSpaceItem, StudyDetail } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { usePlot } from "../hooks/usePlot"
import { useMergedUnionSearchSpace } from "../searchSpace"
import { useBackendRender, usePlotlyColorTheme } from "../state"

const plotDomId = "graph-slice"
Expand Down
1 change: 1 addition & 0 deletions tslib/react/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ export {
useObjectiveAndUserAttrTargets,
useObjectiveAndUserAttrTargetsFromStudies,
} from "./utils/trialFilter"
export { useMergedUnionSearchSpace } from "./utils/searchSpace"
export type { GraphComponentState } from "./types"
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
import * as Optuna from "@optuna/types"
import { useMemo } from "react"
import { SearchSpaceItem } from "./types/optuna"

export const mergeUnionSearchSpace = (
unionSearchSpace: SearchSpaceItem[]
): SearchSpaceItem[] => {
const mergeUnionSearchSpace = (
unionSearchSpace: Optuna.SearchSpaceItem[]
): Optuna.SearchSpaceItem[] => {
const knownElements = new Map<string, Optuna.Distribution>()
unionSearchSpace.forEach((s) => {
for (const s of unionSearchSpace) {
const d = knownElements.get(s.name)
if (d === undefined) {
knownElements.set(s.name, s.distribution)
return
continue
}
if (
d.type === "CategoricalDistribution" ||
s.distribution.type === "CategoricalDistribution"
) {
// CategoricalDistribution.choices will never be changed
return
continue
}
const updated: Optuna.Distribution = {
...d,
low: Math.min(d.low, s.distribution.low),
high: Math.max(d.high, s.distribution.high),
}
knownElements.set(s.name, updated)
})
}
return Array.from(knownElements.keys())
.sort((a, b) => (a > b ? 1 : a < b ? -1 : 0))
.map((name) => ({
Expand All @@ -35,8 +34,8 @@ export const mergeUnionSearchSpace = (
}

export const useMergedUnionSearchSpace = (
unionSearchSpaces?: SearchSpaceItem[]
): SearchSpaceItem[] =>
unionSearchSpaces?: Optuna.SearchSpaceItem[]
): Optuna.SearchSpaceItem[] =>
useMemo(() => {
return mergeUnionSearchSpace(unionSearchSpaces || [])
}, [unionSearchSpaces])
76 changes: 59 additions & 17 deletions tslib/storage/src/journal.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import * as Optuna from "@optuna/types"
import { OptunaStorage } from "./storage"

// TODO(porink0424): Refactor to common function with sqlite.ts (current workaround duplicates code due to missing file extensions in tsc build output).
const isDistributionEqual = (
a: Optuna.Distribution,
b: Optuna.Distribution
) => {
if (a.type !== b.type) {
return false
}

if (a.type === "IntDistribution" || a.type === "FloatDistribution") {
if (b.type !== "IntDistribution" && b.type !== "FloatDistribution") {
throw new Error("Invalid distribution type")
}
return (
a.low === b.low &&
a.high === b.high &&
a.step === b.step &&
a.log === b.log
)
}
if (a.type === "CategoricalDistribution") {
if (b.type !== "CategoricalDistribution") {
throw new Error("Invalid distribution type")
}
return JSON.stringify(a.choices) === JSON.stringify(b.choices)
}

throw new Error("Invalid distribution type")
}

// JournalStorage
enum JournalOperation {
CREATE_STUDY = 0,
Expand Down Expand Up @@ -137,22 +167,42 @@ class JournalStorage {
public getStudies(): Optuna.Study[] {
for (const study of this.studies) {
const unionUserAttrs: Set<string> = new Set()
const unionSearchSpace: Set<string> = new Set()
let intersectionSearchSpace: string[] = []
const unionSearchSpace: Optuna.SearchSpaceItem[] = []
let intersectionSearchSpace: Optuna.SearchSpaceItem[] = []

study.trials.forEach((trial, index) => {
for (const userAttr of trial.user_attrs) {
unionUserAttrs.add(userAttr.key)
}
for (const param of trial.params) {
unionSearchSpace.add(param.name)
if (
!unionSearchSpace.some(
(item) =>
item.name === param.name &&
isDistributionEqual(item.distribution, param.distribution)
)
) {
unionSearchSpace.push({
name: param.name,
distribution: param.distribution,
})
}
}
if (index === 0) {
intersectionSearchSpace = Array.from(unionSearchSpace)
intersectionSearchSpace = [...unionSearchSpace]
} else {
intersectionSearchSpace = intersectionSearchSpace.filter((name) => {
return trial.params.some((param) => param.name === name)
})
intersectionSearchSpace = intersectionSearchSpace.filter(
(searchSpaceItem) => {
return trial.params.some(
(param) =>
param.name === searchSpaceItem.name &&
isDistributionEqual(
param.distribution,
searchSpaceItem.distribution
)
)
}
)
}
})
study.union_user_attrs = Array.from(unionUserAttrs).map((key) => {
Expand All @@ -161,16 +211,8 @@ class JournalStorage {
sortable: false,
}
})
study.union_search_space = Array.from(unionSearchSpace).map((name) => {
return {
name: name,
}
})
study.intersection_search_space = intersectionSearchSpace.map((name) => {
return {
name: name,
}
})
study.union_search_space = unionSearchSpace
study.intersection_search_space = intersectionSearchSpace
}

return this.studies
Expand Down
61 changes: 47 additions & 14 deletions tslib/storage/src/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ import * as Optuna from "@optuna/types"
import sqlite3InitModule from "@sqlite.org/sqlite-wasm"
import { OptunaStorage } from "./storage"

// TODO(porink0424): Refactor to common function with journal.ts (current workaround duplicates code due to missing file extensions in tsc build output).
const isDistributionEqual = (
a: Optuna.Distribution,
b: Optuna.Distribution
) => {
if (a.type !== b.type) {
return false
}

if (a.type === "IntDistribution" || a.type === "FloatDistribution") {
if (b.type !== "IntDistribution" && b.type !== "FloatDistribution") {
throw new Error("Invalid distribution type")
}
return (
a.low === b.low &&
a.high === b.high &&
a.step === b.step &&
a.log === b.log
)
}
if (a.type === "CategoricalDistribution") {
if (b.type !== "CategoricalDistribution") {
throw new Error("Invalid distribution type")
}
return JSON.stringify(a.choices) === JSON.stringify(b.choices)
}

throw new Error("Invalid distribution type")
}

type SQLite3DB = {
exec(options: {
sql: string
Expand Down Expand Up @@ -149,7 +179,7 @@ const getStudy = (
study.metric_names = studySystemAttrs.metric_names
}

let intersection_search_space: Set<Optuna.SearchSpaceItem> = new Set()
let intersectionSearchSpace: Optuna.SearchSpaceItem[] = []
study.trials = getTrials(db, summary.id, schemaVersion)
for (const trial of study.trials) {
const userAttrs = getTrialUserAttributes(db, trial.trial_id)
Expand All @@ -165,31 +195,34 @@ const getStudy = (
}

const params = getTrialParams(db, trial.trial_id)
const param_names = new Set<string>()
for (const param of params) {
param_names.add(param.name)
if (
study.union_search_space.findIndex((s) => s.name === param.name) === -1
) {
study.union_search_space.push({ name: param.name })
study.union_search_space.push({
name: param.name,
distribution: param.distribution,
})
}
}
if (intersection_search_space.size === 0) {
// biome-ignore lint/complexity/noForEach: <explanation>
param_names.forEach((s) => {
intersection_search_space.add({ name: s })
})
if (intersectionSearchSpace.length === 0) {
intersectionSearchSpace = params.map((param) => ({
name: param.name,
distribution: param.distribution,
}))
} else {
intersection_search_space = new Set(
Array.from(intersection_search_space).filter((s) =>
param_names.has(s.name)
intersectionSearchSpace = intersectionSearchSpace.filter((item) => {
return params.some(
(param) =>
item.name === param.name &&
isDistributionEqual(item.distribution, param.distribution)
)
)
})
}
trial.params = params
trial.user_attrs = userAttrs
}
study.intersection_search_space = Array.from(intersection_search_space)
study.intersection_search_space = intersectionSearchSpace
return study
}

Expand Down
4 changes: 2 additions & 2 deletions tslib/storage/test/generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def objective_single(trial: optuna.Trial) -> float:
def objective_single_dynamic(trial: optuna.Trial) -> float:
category = trial.suggest_categorical("category", ["foo", "bar"])
if category == "foo":
return (trial.suggest_float("x1", 0, 10) - 2) ** 2
return (trial.suggest_float("x", 0, 10) - 2) ** 2
else:
return -((trial.suggest_float("x2", -10, 0) + 5) ** 2)
return -((trial.suggest_float("x", -10, 0) + 5) ** 2)

study.optimize(objective_single_dynamic, n_trials=50)

Expand Down
Loading

0 comments on commit 0c733bc

Please sign in to comment.