diff --git a/gui/src/app/StanSampler/StanModelWorker.ts b/gui/src/app/StanSampler/StanModelWorker.ts index d1f463e7..38bb240d 100644 --- a/gui/src/app/StanSampler/StanModelWorker.ts +++ b/gui/src/app/StanSampler/StanModelWorker.ts @@ -30,12 +30,13 @@ const parseProgress = (msg: string): Progress => { if (msg.startsWith("Iteration:")) { msg = "Chain [1] " + msg; } + msg = msg.replace(/\[|\]/g, ""); const parts = msg.split(/\s+/); - const chain = parseInt(parts[1].slice(1, -1)); + const chain = parseInt(parts[1]); const iteration = parseInt(parts[3]); const totalIterations = parseInt(parts[5]); - const percent = parseInt(parts[7].slice(0, -2)); - const warmup = parts[8] === "(Warmup)"; + const percent = parseInt(parts[6].slice(0, -1)); + const warmup = parts[7] === "(Warmup)"; return { chain, iteration, totalIterations, percent, warmup }; }; @@ -64,7 +65,7 @@ self.onmessage = (e) => { m.stanVersion(), ); self.postMessage({ purpose: Replies.ModelLoaded }); - }); + }, console.error); break; } case Requests.Sample: { diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index 9f33d5a4..eedaf3af 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -33,6 +33,7 @@ class StanSampler { } { const sampler = new StanSampler(compiledUrl); const cleanup = () => { + console.log("terminating model worker"); sampler.#worker && sampler.#worker.terminate(); sampler.#worker = undefined; }; @@ -81,38 +82,23 @@ class StanSampler { sample(data: any, samplingOpts: SamplingOpts) { const refresh = calculateReasonableRefreshRate(samplingOpts); const sampleConfig: Partial = { + ...samplingOpts, data, - num_chains: samplingOpts.num_chains, - num_warmup: samplingOpts.num_warmup, - num_samples: samplingOpts.num_samples, - init_radius: samplingOpts.init_radius, seed: samplingOpts.seed !== undefined ? samplingOpts.seed : null, refresh, }; if (!this.#worker) return; - if (this.#status === "") { - console.warn("Model not loaded yet"); - return; - } - if (sampleConfig.num_chains === undefined) { - console.warn("Number of chains not specified"); - return; - } if (this.#status === "sampling") { console.warn("Already sampling"); return; } - if (this.#status === "loading") { - console.warn("Model not loaded yet"); - return; - } this.#samplingOpts = samplingOpts; this.#draws = []; this.#paramNames = []; - this.#worker.postMessage({ purpose: Requests.Sample, sampleConfig }); this.#samplingStartTimeSec = Date.now() / 1000; this.#status = "sampling"; this.#onStatusChangedCallbacks.forEach((cb) => cb()); + this.#worker.postMessage({ purpose: Requests.Sample, sampleConfig }); } onProgress(callback: (progress: Progress) => void) { this.#onProgressCallbacks.push(callback); diff --git a/gui/test/app/StanSampler/MockStanModel.ts b/gui/test/app/StanSampler/MockStanModel.ts new file mode 100644 index 00000000..81a449cf --- /dev/null +++ b/gui/test/app/StanSampler/MockStanModel.ts @@ -0,0 +1,70 @@ +import { vi } from "vitest"; + +import type StanModel from "tinystan"; +import type { PrintCallback } from "tinystan"; +import { defaultSamplingOpts } from "../../../src/app/Project/ProjectDataModel"; + +import fakeURL from "./empty.ts?url"; +import erroringURL from "./fail.ts?url"; +import failSentinel from "./fail.ts"; + +export const mockCompiledMainJsUrl = fakeURL; +export const erroringCompiledMainJsUrl = erroringURL; + +const erroring_num_chains = 999; + +export const erroringSamplingOpts = { + ...defaultSamplingOpts, + num_chains: erroring_num_chains, +}; + +export const mockedParamNames = ["a", "b"]; +export const mockedDraws = [ + [1, 2], + [3, 4], +]; + +export const mockedProgress = { + chain: 1, + iteration: 123, + totalIterations: 1000, + percent: 45, + warmup: false, +}; + +const mockedProgressString = `Chain [${mockedProgress.chain}] \ +Iteration: ${mockedProgress.iteration} \ +/ ${mockedProgress.totalIterations} \ +[${mockedProgress.percent.toString().padStart(3)}%] \ +(${mockedProgress.warmup ? "Warmup" : "Sampling"})`; + +const mockedLoad = async ( + _create: any, + printCallback: PrintCallback | null, +) => { + await new Promise((resolve) => setTimeout(resolve, 50)); + + if (_create === failSentinel) { + return Promise.reject("error for testing in load!"); + } + + const model = { + stanVersion: vi.fn(() => "1.2.3"), + sample: vi.fn(({ num_chains }) => { + if (num_chains === erroring_num_chains) { + throw new Error("error for testing in sample!"); + } + + printCallback && printCallback(mockedProgressString); + + return { + paramNames: mockedParamNames, + draws: mockedDraws, + }; + }), + } as unknown as StanModel; + + return model; +}; + +export default mockedLoad; diff --git a/gui/test/app/StanSampler/empty.ts b/gui/test/app/StanSampler/empty.ts new file mode 100644 index 00000000..f05309dd --- /dev/null +++ b/gui/test/app/StanSampler/empty.ts @@ -0,0 +1 @@ +// intentionally empty, used to create a URL vitest can resolve diff --git a/gui/test/app/StanSampler/fail.ts b/gui/test/app/StanSampler/fail.ts new file mode 100644 index 00000000..138f0f3f --- /dev/null +++ b/gui/test/app/StanSampler/fail.ts @@ -0,0 +1,3 @@ +// used to create a URL vitest can resolve +const failSentinel = "fail"; +export default failSentinel; diff --git a/gui/test/app/StanSampler/useStanSampler.test.ts b/gui/test/app/StanSampler/useStanSampler.test.ts new file mode 100644 index 00000000..75495a54 --- /dev/null +++ b/gui/test/app/StanSampler/useStanSampler.test.ts @@ -0,0 +1,248 @@ +// @vitest-environment jsdom + +import { expect, test, describe, vi, afterEach, onTestFinished } from "vitest"; +import "@vitest/web-worker"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import mockedLoad, { + mockCompiledMainJsUrl, + erroringCompiledMainJsUrl, + erroringSamplingOpts, + mockedDraws, + mockedParamNames, + mockedProgress, +} from "./MockStanModel"; + +import useStanSampler, { + useSamplerOutput, + useSamplerProgress, + useSamplerStatus, +} from "../../../src/app/StanSampler/useStanSampler"; +import { defaultSamplingOpts } from "../../../src/app/Project/ProjectDataModel"; +import type StanSampler from "../../../src/app/StanSampler/StanSampler"; + +const mockedStdout = vi + .spyOn(console, "log") + .mockImplementation(() => undefined); +const mockedStderr = vi + .spyOn(console, "error") + .mockImplementation(() => undefined); + +vi.mock("tinystan", async (importOriginal) => { + const mod = await importOriginal(); + mod.default.load = mockedLoad; + return mod; +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +const loadedSampler = async () => { + const ret = renderHook(() => useStanSampler(mockCompiledMainJsUrl)); + const status = renderHook(() => useSamplerStatus(ret.result.current.sampler)); + await waitFor(() => { + expect(status.result.current.status).toBe("loaded"); + }); + + onTestFinished(() => { + expect(ret.result.current.sampler?.status).toEqual( + status.result.current.status, + ); + + expect(mockedStdout).not.toHaveBeenCalledWith("terminating model worker"); + ret.unmount(); + expect(mockedStdout).toHaveBeenCalledWith("terminating model worker"); + }); + + return [ret, status] as const; +}; + +const rerenderableSampler = async () => { + const ret = renderHook((url: string | undefined) => useStanSampler(url), { + initialProps: undefined, + }); + const status = renderHook( + (sampler: StanSampler | undefined) => useSamplerStatus(sampler), + { + initialProps: ret.result.current.sampler, + }, + ); + + return [ret, status] as const; +}; + +describe("useStanSampler", () => { + test("empty URL should return undefined", () => { + const { result } = renderHook(() => useStanSampler(undefined)); + + expect(result.current.sampler).toBeUndefined(); + }); + + test("other URLs are nonempty", async () => { + const { result } = renderHook(() => useStanSampler(mockCompiledMainJsUrl)); + + expect(result.current.sampler).toBeDefined(); + }); + + describe("useSamplerStatus", () => { + test("loading changes status", async () => { + const [ + { result, rerender }, + { result: statusResult, rerender: rerenderStatus }, + ] = await rerenderableSampler(); + + expect(statusResult.current.status).toBe(""); + + rerender(mockCompiledMainJsUrl); + rerenderStatus(result.current.sampler); + + expect(statusResult.current.status).toBe("loading"); + + await waitFor(() => { + expect(statusResult.current.status).toBe("loaded"); + }); + expect(mockedStderr).not.toHaveBeenCalled(); + }); + + test("failing to load changes status", async () => { + const [ + { result, rerender }, + { result: statusResult, rerender: rerenderStatus }, + ] = await rerenderableSampler(); + + expect(statusResult.current.status).toBe(""); + + rerender(erroringCompiledMainJsUrl); + rerenderStatus(result.current.sampler); + + await waitFor(() => { + expect(statusResult.current.status).toBe("loading"); + }); + + await waitFor(() => { + expect(mockedStderr).toHaveBeenCalledWith("error for testing in load!"); + }); + + act(() => { + result.current.sampler?.sample({}, defaultSamplingOpts); + }); + + await waitFor(() => { + expect(statusResult.current.status).toBe("failed"); + expect(statusResult.current.errorMessage).toBe("Model not loaded yet!"); + }); + }); + + test("sampling changes status", async () => { + const [{ result }, { result: statusResult }] = await loadedSampler(); + + act(() => { + result.current.sampler?.sample({}, defaultSamplingOpts); + }); + + await waitFor(() => { + expect(statusResult.current.status).toBe("completed"); + expect(result.current.sampler?.paramNames).toEqual(["a", "b"]); + }); + expect(mockedStderr).not.toHaveBeenCalled(); + }); + + test("error during sampling changes status", async () => { + const [{ result }, { result: statusResult }] = await loadedSampler(); + + act(() => { + result.current.sampler?.sample({}, erroringSamplingOpts); + }); + + await waitFor(() => { + expect(statusResult.current.status).toBe("failed"); + expect(statusResult.current.errorMessage).toBe( + "Error: error for testing in sample!", + ); + }); + expect(mockedStderr).not.toHaveBeenCalled(); + }); + + // NOTE: Because vitest-web-worker does not actually run anything concurrently, this test will not work + // test("cancelling reloads", async () => { + // const [{ result }, { result: statusResult }] = await loadedSampler(); + // act(() => { + // result.current.sampler?.sample({}, defaultSamplingOpts); + // }); + // act(() => { + // result.current.sampler?.cancel(); + // }); + // await waitFor(() => { + // expect(statusResult.current.status).toBe("loaded"); + // }); + // expect(mockedStderr).not.toHaveBeenCalled(); + // }); + }); + + describe("useSamplerProgress", () => { + test("sampling changes status", async () => { + const [{ result }] = await loadedSampler(); + + const { result: progress } = renderHook(() => + useSamplerProgress(result.current.sampler), + ); + + expect(progress.current).toBeUndefined(); + + act(() => { + result.current.sampler?.sample({}, defaultSamplingOpts); + }); + + await waitFor(() => { + expect(progress.current).toEqual(mockedProgress); + }); + + expect(mockedStderr).not.toHaveBeenCalled(); + }); + }); + + describe("useSamplerOutput", () => { + test("undefined sampler returns undefined", () => { + const { result } = renderHook(() => useSamplerOutput(undefined)); + expect(result.current.draws).toBeUndefined(); + expect(result.current.paramNames).toBeUndefined(); + expect(result.current.numChains).toBeUndefined(); + expect(result.current.computeTimeSec).toBeUndefined(); + }); + + test("sampling changes output", async () => { + const [{ result }] = await loadedSampler(); + + const { result: output } = renderHook(() => + useSamplerOutput(result.current.sampler), + ); + + expect(output.current.draws).toBeUndefined(); + expect(output.current.paramNames).toBeUndefined(); + expect(output.current.numChains).toBeUndefined(); + expect(output.current.computeTimeSec).toBeUndefined(); + + act(() => { + result.current.sampler?.sample({}, defaultSamplingOpts); + }); + + await waitFor(() => { + expect(output.current.draws).toEqual(mockedDraws); + expect(output.current.paramNames).toEqual(mockedParamNames); + expect(output.current.numChains).toBe(defaultSamplingOpts.num_chains); + expect(output.current.computeTimeSec).toBeDefined(); + }); + + expect(result.current.sampler?.status).toBe("completed"); + + expect(result.current.sampler?.draws).toBe(output.current.draws); + expect(result.current.sampler?.paramNames).toBe( + output.current.paramNames, + ); + expect(result.current.sampler?.samplingOpts).toBe(defaultSamplingOpts); + expect(result.current.sampler?.computeTimeSec).toBe( + output.current.computeTimeSec, + ); + }); + }); +});