Skip to content

Commit

Permalink
Merge pull request #123 from flatironinstitute/use-stan-sampler-tests
Browse files Browse the repository at this point in the history
Tests for useStanSampler hooks
  • Loading branch information
jsoules authored Jul 11, 2024
2 parents 9e543a3 + 878490d commit 7d46efa
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 21 deletions.
9 changes: 5 additions & 4 deletions gui/src/app/StanSampler/StanModelWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ const parseProgress = (msg: string): Progress => {
if (msg.startsWith("Iteration:")) {
msg = "Chain [1] " + msg;
}
msg = msg.replace(/\[|\]/g, "");
const parts = msg.split(/\s+/);
const chain = parseInt(parts[1].slice(1, -1));
const chain = parseInt(parts[1]);
const iteration = parseInt(parts[3]);
const totalIterations = parseInt(parts[5]);
const percent = parseInt(parts[7].slice(0, -2));
const warmup = parts[8] === "(Warmup)";
const percent = parseInt(parts[6].slice(0, -1));
const warmup = parts[7] === "(Warmup)";
return { chain, iteration, totalIterations, percent, warmup };
};

Expand Down Expand Up @@ -64,7 +65,7 @@ self.onmessage = (e) => {
m.stanVersion(),
);
self.postMessage({ purpose: Replies.ModelLoaded });
});
}, console.error);
break;
}
case Requests.Sample: {
Expand Down
20 changes: 3 additions & 17 deletions gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class StanSampler {
} {
const sampler = new StanSampler(compiledUrl);
const cleanup = () => {
console.log("terminating model worker");
sampler.#worker && sampler.#worker.terminate();
sampler.#worker = undefined;
};
Expand Down Expand Up @@ -81,38 +82,23 @@ class StanSampler {
sample(data: any, samplingOpts: SamplingOpts) {
const refresh = calculateReasonableRefreshRate(samplingOpts);
const sampleConfig: Partial<SamplerParams> = {
...samplingOpts,
data,
num_chains: samplingOpts.num_chains,
num_warmup: samplingOpts.num_warmup,
num_samples: samplingOpts.num_samples,
init_radius: samplingOpts.init_radius,
seed: samplingOpts.seed !== undefined ? samplingOpts.seed : null,
refresh,
};
if (!this.#worker) return;
if (this.#status === "") {
console.warn("Model not loaded yet");
return;
}
if (sampleConfig.num_chains === undefined) {
console.warn("Number of chains not specified");
return;
}
if (this.#status === "sampling") {
console.warn("Already sampling");
return;
}
if (this.#status === "loading") {
console.warn("Model not loaded yet");
return;
}
this.#samplingOpts = samplingOpts;
this.#draws = [];
this.#paramNames = [];
this.#worker.postMessage({ purpose: Requests.Sample, sampleConfig });
this.#samplingStartTimeSec = Date.now() / 1000;
this.#status = "sampling";
this.#onStatusChangedCallbacks.forEach((cb) => cb());
this.#worker.postMessage({ purpose: Requests.Sample, sampleConfig });
}
onProgress(callback: (progress: Progress) => void) {
this.#onProgressCallbacks.push(callback);
Expand Down
70 changes: 70 additions & 0 deletions gui/test/app/StanSampler/MockStanModel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { vi } from "vitest";

import type StanModel from "tinystan";
import type { PrintCallback } from "tinystan";
import { defaultSamplingOpts } from "../../../src/app/Project/ProjectDataModel";

import fakeURL from "./empty.ts?url";
import erroringURL from "./fail.ts?url";
import failSentinel from "./fail.ts";

export const mockCompiledMainJsUrl = fakeURL;
export const erroringCompiledMainJsUrl = erroringURL;

const erroring_num_chains = 999;

export const erroringSamplingOpts = {
...defaultSamplingOpts,
num_chains: erroring_num_chains,
};

export const mockedParamNames = ["a", "b"];
export const mockedDraws = [
[1, 2],
[3, 4],
];

export const mockedProgress = {
chain: 1,
iteration: 123,
totalIterations: 1000,
percent: 45,
warmup: false,
};

const mockedProgressString = `Chain [${mockedProgress.chain}] \
Iteration: ${mockedProgress.iteration} \
/ ${mockedProgress.totalIterations} \
[${mockedProgress.percent.toString().padStart(3)}%] \
(${mockedProgress.warmup ? "Warmup" : "Sampling"})`;

const mockedLoad = async (
_create: any,
printCallback: PrintCallback | null,
) => {
await new Promise((resolve) => setTimeout(resolve, 50));

if (_create === failSentinel) {
return Promise.reject("error for testing in load!");
}

const model = {
stanVersion: vi.fn(() => "1.2.3"),
sample: vi.fn(({ num_chains }) => {
if (num_chains === erroring_num_chains) {
throw new Error("error for testing in sample!");
}

printCallback && printCallback(mockedProgressString);

return {
paramNames: mockedParamNames,
draws: mockedDraws,
};
}),
} as unknown as StanModel;

return model;
};

export default mockedLoad;
1 change: 1 addition & 0 deletions gui/test/app/StanSampler/empty.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// intentionally empty, used to create a URL vitest can resolve
3 changes: 3 additions & 0 deletions gui/test/app/StanSampler/fail.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// used to create a URL vitest can resolve
const failSentinel = "fail";
export default failSentinel;
Loading

0 comments on commit 7d46efa

Please sign in to comment.