Skip to content

Commit

Permalink
Support constraints for tslib/storage
Browse files Browse the repository at this point in the history
  • Loading branch information
porink0424 committed May 10, 2024
1 parent 774efb7 commit 8c1d9f8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
24 changes: 21 additions & 3 deletions tslib/storage/src/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ interface JournalOpSetTrialUserAttr extends JournalOpBase {
user_attr: { [key: string]: any } // eslint-disable-line @typescript-eslint/no-explicit-any
}

interface JournalOpSetTrialSystemAttr extends JournalOpBase {
trial_id: number
system_attr: {
constraints: number[]
}
}

const trialStateNumToTrialState = (state: number): Optuna.TrialState => {
switch (state) {
case 0:
Expand Down Expand Up @@ -224,7 +231,7 @@ class JournalStorage {
}
})

const userAtter = log.user_attrs
const userAttrs = log.user_attrs
? Object.entries(log.user_attrs).map(([key, value]) => {
return {
key: key,
Expand All @@ -249,7 +256,8 @@ class JournalStorage {
})(),
params: params,
intermediate_values: [],
user_attrs: userAtter,
user_attrs: userAttrs,
constraints: [],
datetime_start: log.datetime_start
? new Date(log.datetime_start)
: undefined,
Expand Down Expand Up @@ -341,6 +349,14 @@ class JournalStorage {
}
}
}

public applySetTrialSystemAttr(log: JournalOpSetTrialSystemAttr) {
const [thisStudy, thisTrial] = this.getStudyAndTrial(log.trial_id)
if (thisStudy === undefined || thisTrial === undefined) {
return
}
thisTrial.constraints = log.system_attr.constraints
}
}

const loadJournalStorage = (
Expand Down Expand Up @@ -434,7 +450,9 @@ const loadJournalStorage = (
)
break
case JournalOperation.SET_TRIAL_SYSTEM_ATTR:
// Unsupported
journalStorage.applySetTrialSystemAttr(
parsedLog as JournalOpSetTrialSystemAttr
)
break
}
}
Expand Down
20 changes: 20 additions & 0 deletions tslib/storage/src/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ const getStudy = (
}
}

const systemAttrs = getTrialSystemAttributes(db, trial.trial_id)
if (systemAttrs !== undefined) {
trial.constraints = systemAttrs.constraints
}

const params = getTrialParams(db, trial.trial_id)
const param_names = new Set<string>()
for (const param of params) {
Expand Down Expand Up @@ -222,6 +227,7 @@ const getTrials = (
),
params: [], // Set this column later
user_attrs: [], // Set this column later
constraints: [],
datetime_start: vals[3],
datetime_complete: vals[4],
}
Expand Down Expand Up @@ -404,6 +410,20 @@ const getTrialUserAttributes = (
return attrs
}

const getTrialSystemAttributes = (db: SQLite3DB, trialId: number) => {
let attrs: { constraints: number[] } | undefined
db.exec({
sql: `SELECT key, value_json FROM trial_system_attributes WHERE trial_id = ${trialId} AND key = 'constraints'`,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
callback: (vals: any[]) => {
attrs = {
constraints: JSON.parse(vals[1]),
}
},
})
return attrs
}

const getTrialIntermediateValues = (
db: SQLite3DB,
trialId: number,
Expand Down

0 comments on commit 8c1d9f8

Please sign in to comment.