Skip to content

Commit

Permalink
Merge pull request #861 from porink0424/feat/add-journal-storage-tests
Browse files Browse the repository at this point in the history
Improve tests for `JournalStorage`
  • Loading branch information
c-bata authored Apr 10, 2024
2 parents 7e9db28 + 2d0bcf0 commit d0536a1
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 26 deletions.
69 changes: 49 additions & 20 deletions tslib/storage/src/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,17 @@ class JournalStorage {
}
}

const loadJournalStorage = (arrayBuffer: ArrayBuffer): Optuna.Study[] => {
const loadJournalStorage = (
arrayBuffer: ArrayBuffer
): {
studies: Optuna.Study[]
errors: { log: string; message: string }[]
} => {
const decoder = new TextDecoder("utf-8")
const logs = decoder.decode(arrayBuffer).split("\n")

const journalStorage = new JournalStorage()
const errors: { log: string; message: string }[] = []

for (const log of logs) {
if (log === "") {
Expand All @@ -343,28 +349,41 @@ const loadJournalStorage = (arrayBuffer: ArrayBuffer): Optuna.Study[] => {

const parsedLog: JournalOpBase = (() => {
try {
return JSON.parse(log)
try {
return JSON.parse(log)
} catch (error) {
if (error instanceof SyntaxError) {
let escapedLog: string = log.replace(/NaN/g, '"***nan***"')
escapedLog = escapedLog.replace(/-Infinity/g, '"***-inf***"')
escapedLog = escapedLog.replace(/Infinity/g, '"***inf***"')
return JSON.parse(escapedLog, (_key, value) => {
switch (value) {
case "***nan***":
return NaN
case "***-inf***":
return -Infinity
case "***inf***":
return Infinity
default:
return value
}
})
}
throw error
}
} catch (error) {
if (error instanceof SyntaxError) {
let escapedLog: string = log.replace(/NaN/g, '"***nan***"')
escapedLog = escapedLog.replace(/-Infinity/g, '"***-inf***"')
escapedLog = escapedLog.replace(/Infinity/g, '"***inf***"')
return JSON.parse(escapedLog, (_key, value) => {
switch (value) {
case "***nan***":
return NaN
case "***-inf***":
return -Infinity
case "***inf***":
return Infinity
default:
return value
}
})
if (error instanceof Error) {
errors.push({ log: log, message: error.message })
} else {
errors.push({ log: log, message: "Unknown error" })
}
}
})()

if (parsedLog === undefined) {
continue
}

switch (parsedLog.op_code) {
case JournalOperation.CREATE_STUDY:
journalStorage.applyCreateStudy(parsedLog as JournalOpCreateStudy)
Expand Down Expand Up @@ -405,18 +424,28 @@ const loadJournalStorage = (arrayBuffer: ArrayBuffer): Optuna.Study[] => {
}
}

return journalStorage.getStudies()
return {
studies: journalStorage.getStudies(),
errors,
}
}

/**
 * Storage built from the contents of a journal log held in an `ArrayBuffer`.
 *
 * The constructor passes the buffer to `loadJournalStorage` and keeps both the
 * studies it returns and the list of log lines that could not be parsed.
 */
export class JournalFileStorage implements OptunaStorage {
// Studies returned by `loadJournalStorage`.
studies: Optuna.Study[]
// One entry per log line that failed to parse: the raw line (`log`) and the
// error text (`message`).
errors: { log: string; message: string }[]
/**
 * @param arrayBuffer Raw journal log contents, handed to `loadJournalStorage`.
 */
constructor(arrayBuffer: ArrayBuffer) {
// NOTE(review): this assignment is overwritten a few lines below and stores
// the whole `{ studies, errors }` result; it looks like the pre-change line of
// the diff — confirm it is not present in the real file.
this.studies = loadJournalStorage(arrayBuffer)
const { studies: studiesFromStorage, errors: errorsFromStorage } =
loadJournalStorage(arrayBuffer)
this.studies = studiesFromStorage
this.errors = errorsFromStorage
}
/** Resolves with the loaded studies (typed as study summaries). */
getStudies = async (): Promise<Optuna.StudySummary[]> => {
return this.studies
}
/**
 * Resolves with the study at position `idx` in the loaded list, or `null`
 * when that slot is falsy (for example an out-of-range index).
 */
getStudy = async (idx: number): Promise<Optuna.Study | null> => {
return this.studies[idx] || null
}
/** Returns the parse errors collected while loading; unlike the other accessors this is synchronous. */
getErrors = (): { log: string; message: string }[] => {
return this.errors
}
}
45 changes: 45 additions & 0 deletions tslib/storage/test/generate_assets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import math
import os.path
import random
import shutil

import optuna
Expand Down Expand Up @@ -48,6 +50,34 @@ def objective_single_dynamic(trial: optuna.Trial) -> float:

study.optimize(objective_single_dynamic, n_trials=50)

# Single objective study with 'inf' and '-inf' values
study = optuna.create_study(study_name="single-inf", storage=storage)
print(f"Generating {study.study_name} for {type(storage).__name__}...")

# Objective for the "single-inf" study. It always suggests x in [-10, 10], then
# returns +inf when trial.number % 3 == 0, -inf when trial.number % 3 == 1, and
# x**2 otherwise, so the stored values cycle through inf, -inf and a finite number.
def objective_single_inf(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", -10, 10)
if trial.number % 3 == 0:
return float("inf")
elif trial.number % 3 == 1:
return float("-inf")
else:
return x**2

study.optimize(objective_single_inf, n_trials=50)

# Single objective study with a reported nan intermediate value
study = optuna.create_study(study_name="single-nan-report", storage=storage)
print(f"Generating {study.study_name} for {type(storage).__name__}...")

# Objective for the "single-nan-report" study. It suggests x1 and x2 in [0, 10],
# reports 0.5 at step 0 and NaN at step 1 on every trial, and returns
# (x1 - 2)**2 + (x2 - 5)**2 as the objective value.
def objective_single_nan_report(trial: optuna.Trial) -> float:
x1 = trial.suggest_float("x1", 0, 10)
x2 = trial.suggest_float("x2", 0, 10)
trial.report(0.5, step=0)
# The NaN report at step 1 is what the journal loader has to round-trip.
trial.report(math.nan, step=1)
return (x1 - 2) ** 2 + (x2 - 5) ** 2

study.optimize(objective_single_nan_report, n_trials=100)


if __name__ == "__main__":
remove_assets()
Expand All @@ -56,3 +86,18 @@ def objective_single_dynamic(trial: optuna.Trial) -> float:
RDBStorage("sqlite:///" + os.path.join(BASE_DIR, "db.sqlite3")),
]:
create_optuna_storage(storage)

# Create a copy of the journal file with a broken line inserted at a random position to test error handling
shutil.copyfile(
os.path.join(BASE_DIR, "journal.log"), os.path.join(BASE_DIR, "journal-broken.log")
)
broken_line = (
'{"op_code": ..., "worker_id": "0000", "study_id": 0,'
'"datetime_start": "2024-04-01T12:00:00.000000"}\n'
)
with open(os.path.join(BASE_DIR, "journal-broken.log"), "r+") as f:
lines = f.readlines()
lines.insert(random.randint(0, len(lines)), broken_line)
f.truncate(0)
f.seek(0, os.SEEK_SET)
f.writelines(lines)
52 changes: 46 additions & 6 deletions tslib/storage/test/journal.test.mjs
Original file line number Diff line number Diff line change
@@ -1,19 +1,59 @@
import assert from "node:assert"
import { openAsBlob } from "node:fs"
import path from "node:path"
import test from "node:test"
import { describe, it } from "node:test"

import * as mut from "../pkg/journal.js"

const n_studies = 2

test("Test Journal File Storage", async () => {
describe("Test Journal File Storage", async () => {
const blob = await openAsBlob(
path.resolve(".", "test", "asset", "journal.log")
)
const buf = await blob.arrayBuffer()
const storage = new mut.JournalFileStorage(buf)
const studies = await storage.getStudies()
const studySummaries = await storage.getStudies()
const studies = await Promise.all(
studySummaries.map((_summary, index) => storage.getStudy(index))
)

// Trials of the "single-inf" study are expected to cycle through +Infinity,
// -Infinity and a finite value, following the `trial.number % 3` rule used by
// generate_assets.py. Assumes a trial's position in `trials` equals its number.
it("Check the study including Infinities", () => {
  const study = studies.find((s) => s.study_name === "single-inf")
  // Fail with a readable message instead of a TypeError on `study.trials`
  // when the asset does not contain the study.
  assert.ok(study, 'study "single-inf" was not found in the loaded journal')
  // forEach over an empty list asserts nothing, so rule out a vacuous pass.
  assert.ok(study.trials.length > 0, 'study "single-inf" has no trials')
  study.trials.forEach((trial, index) => {
    if (index % 3 === 0) {
      assert.strictEqual(trial.values[0], Infinity)
    } else if (index % 3 === 1) {
      assert.strictEqual(trial.values[0], -Infinity)
    }
  })
})

// Every trial of the "single-nan-report" study reports NaN at step 1 (see
// generate_assets.py); the loader must turn that back into a real NaN.
it("Check the study including NaNs", () => {
  const study = studies.find((s) => s.study_name === "single-nan-report")
  // Fail with a readable message instead of a TypeError on `study.trials`
  // when the asset does not contain the study.
  assert.ok(study, 'study "single-nan-report" was not found in the loaded journal')
  // A for-of over an empty list asserts nothing, so rule out a vacuous pass.
  assert.ok(study.trials.length > 0, 'study "single-nan-report" has no trials')
  for (const trial of study.trials) {
    const reported = trial.intermediate_values.find((v) => v.step === 1)
    // Report a missing step explicitly rather than crashing on `.value`.
    assert.ok(reported, "trial has no intermediate value at step 1")
    // strictEqual compares with Object.is, so NaN matches NaN here.
    assert.strictEqual(reported.value, NaN)
  }
})

// Loads the journal copy that contains one deliberately broken line and checks
// that exactly one parse error is reported, with the expected message.
it("Check the parsing errors", async () => {
  const brokenLogPath = path.resolve(".", "test", "asset", "journal-broken.log")
  const brokenBlob = await openAsBlob(brokenLogPath)
  const brokenBuffer = await brokenBlob.arrayBuffer()
  const brokenStorage = new mut.JournalFileStorage(brokenBuffer)
  const parseErrors = brokenStorage.getErrors()

  assert.strictEqual(parseErrors.length, 1)
  const firstError = parseErrors[0]
  assert.strictEqual(
    firstError.message,
    `Unexpected token '.', ..."op_code": ..., "work"... is not valid JSON`
  )
})

assert.strictEqual(studies.length, n_studies)
// The journal asset is expected to contain exactly four studies.
it("Check the number of studies", () => {
  const expectedStudyCount = 4
  const actualStudyCount = studies.length
  assert.strictEqual(actualStudyCount, expectedStudyCount)
})
})

0 comments on commit d0536a1

Please sign in to comment.