Skip to content

Commit

Permalink
Merge pull request #134 from flatironinstitute/refactor/useStanSample…
Browse files Browse the repository at this point in the history
…r-reducer

Simplify useStanSampler and related hooks
  • Loading branch information
WardBrian authored Jul 18, 2024
2 parents f4822e3 + 0f6ff12 commit 6c38d9d
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 296 deletions.
10 changes: 4 additions & 6 deletions gui/src/app/RunPanel/RunPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ import { FunctionComponent, useCallback } from "react";
import { SamplingOpts } from "@SpCore/ProjectDataModel";
import { Progress } from "@SpStanSampler/StanModelWorker";
import StanSampler from "@SpStanSampler/StanSampler";
import {
useSamplerProgress,
useSamplerStatus,
} from "@SpStanSampler/useStanSampler";
import { StanRun } from "@SpStanSampler/useStanSampler";

type RunPanelProps = {
width: number;
height: number;
sampler?: StanSampler;
latestRun: StanRun;
data: any | undefined;
dataIsSaved: boolean;
samplingOpts: SamplingOpts;
Expand All @@ -27,12 +25,12 @@ const RunPanel: FunctionComponent<RunPanelProps> = ({
width,
height,
sampler,
latestRun,
data,
dataIsSaved,
samplingOpts,
}) => {
const { status: runStatus, errorMessage } = useSamplerStatus(sampler);
const progress = useSamplerProgress(sampler);
const { status: runStatus, errorMessage, progress } = latestRun;

const handleRun = useCallback(async () => {
if (!sampler) return;
Expand Down
21 changes: 9 additions & 12 deletions gui/src/app/SamplerOutputView/SamplerOutputView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,33 @@ import SummaryView from "@SpComponents/SummaryView";
import TabWidget from "@SpComponents/TabWidget";
import TracePlotsView from "@SpComponents/TracePlotsView";
import { SamplingOpts } from "@SpCore/ProjectDataModel";
import StanSampler from "@SpStanSampler/StanSampler";
import { useSamplerOutput } from "@SpStanSampler/useStanSampler";
import { StanRun } from "@SpStanSampler/useStanSampler";
import { triggerDownload } from "@SpUtil/triggerDownload";
import JSZip from "jszip";
import { FunctionComponent, useCallback, useMemo, useState } from "react";

type SamplerOutputViewProps = {
width: number;
height: number;
sampler: StanSampler;
latestRun: StanRun;
};

const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({
width,
height,
sampler,
latestRun,
}) => {
const { draws, paramNames, numChains, computeTimeSec } =
useSamplerOutput(sampler);
const { draws, paramNames, computeTimeSec, samplingOpts } = latestRun;

if (!draws || !paramNames || !numChains) return <span />;
if (!draws || !paramNames || !samplingOpts) return <span />;
return (
<DrawsDisplay
width={width}
height={height}
draws={draws}
paramNames={paramNames}
numChains={numChains}
computeTimeSec={computeTimeSec}
samplingOpts={sampler.samplingOpts}
samplingOpts={samplingOpts}
/>
);
};
Expand All @@ -43,10 +40,9 @@ type DrawsDisplayProps = {
width: number;
height: number;
draws: number[][];
numChains: number;
paramNames: string[];
computeTimeSec: number | undefined;
samplingOpts: SamplingOpts; // for including in exported zip
samplingOpts: SamplingOpts;
};

const tabs = [
Expand Down Expand Up @@ -81,12 +77,13 @@ const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({
height,
draws,
paramNames,
numChains,
computeTimeSec,
samplingOpts,
}) => {
const [currentTabId, setCurrentTabId] = useState("summary");

const numChains = samplingOpts.num_chains;

const drawChainIds = useMemo(() => {
return [...new Array(draws[0].length).keys()].map(
(i) => 1 + Math.floor((i / draws[0].length) * numChains),
Expand Down
114 changes: 42 additions & 72 deletions gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { defaultSamplingOpts, SamplingOpts } from "@SpCore/ProjectDataModel";
import { Progress, Replies, Requests } from "@SpStanSampler/StanModelWorker";
import { SamplingOpts } from "@SpCore/ProjectDataModel";
import { Replies, Requests } from "@SpStanSampler/StanModelWorker";
import StanWorkerUrl from "@SpStanSampler/StanModelWorker?worker&url";
import type { SamplerParams } from "tinystan";
import { type StanRunAction } from "./useStanSampler";

export type StanSamplerStatus =
| ""
Expand All @@ -11,27 +12,27 @@ export type StanSamplerStatus =
| "completed"
| "failed";

type StanSamplerAndCleanup = {
sampler: StanSampler;
cleanup: () => void;
};

class StanSampler {
#worker: Worker | undefined;
#status: StanSamplerStatus = "";
#errorMessage: string = "";
#onProgressCallbacks: ((progress: Progress) => void)[] = [];
#onStatusChangedCallbacks: (() => void)[] = [];
#draws: number[][] = [];
#computeTimeSec: number | undefined = undefined;
#paramNames: string[] = [];
#samplingStartTimeSec: number = 0;
#samplingOpts: SamplingOpts = defaultSamplingOpts; // the sampling options used in the last sample call

private constructor(private compiledUrl: string) {
private constructor(
private compiledUrl: string,
private update: (action: StanRunAction) => void,
) {
this._initialize();
}

static __unsafe_create(compiledUrl: string): {
sampler: StanSampler;
cleanup: () => void;
} {
const sampler = new StanSampler(compiledUrl);
static __unsafe_create(
compiledUrl: string,
update: (action: StanRunAction) => void,
): StanSamplerAndCleanup {
const sampler = new StanSampler(compiledUrl, update);
const cleanup = () => {
console.log("terminating model worker");
sampler.#worker && sampler.#worker.terminate();
Expand All @@ -45,40 +46,43 @@ class StanSampler {
name: "tinystan worker",
type: "module",
});
this.#status = "loading";

this.update({ type: "clear" });

this.#worker.onmessage = (e) => {
const purpose: Replies = e.data.purpose;
switch (purpose) {
case Replies.Progress: {
this.#onProgressCallbacks.forEach((callback) =>
callback(e.data.report),
);
this.update({ type: "progressUpdate", progress: e.data.report });
break;
}
case Replies.ModelLoaded: {
this.#status = "loaded";
this.#onStatusChangedCallbacks.forEach((cb) => cb());
this.update({ type: "statusUpdate", status: "loaded" });
break;
}
case Replies.StanReturn: {
if (e.data.error) {
this.#errorMessage = e.data.error;
this.#status = "failed";
this.#onStatusChangedCallbacks.forEach((cb) => cb());
this.update({
type: "statusUpdate",
status: "failed",
errorMessage: e.data.error,
});
} else {
this.#draws = e.data.draws;
this.#paramNames = e.data.paramNames;
this.#computeTimeSec =
Date.now() / 1000 - this.#samplingStartTimeSec;
this.#status = "completed";
this.#onStatusChangedCallbacks.forEach((cb) => cb());
this.update({
type: "samplerReturn",
draws: e.data.draws,
paramNames: e.data.paramNames,
computeTimeSec: Date.now() / 1000 - this.#samplingStartTimeSec,
});
}
break;
}
}
};
this.update({ type: "statusUpdate", status: "loading" });
this.#worker.postMessage({ purpose: Requests.Load, url: this.compiledUrl });
}

sample(data: any, samplingOpts: SamplingOpts) {
const refresh = calculateReasonableRefreshRate(samplingOpts);
const sampleConfig: Partial<SamplerParams> = {
Expand All @@ -87,51 +91,17 @@ class StanSampler {
seed: samplingOpts.seed !== undefined ? samplingOpts.seed : null,
refresh,
};
if (!this.#worker) return;
if (this.#status === "sampling") {
console.warn("Already sampling");
return;
}
this.#samplingOpts = samplingOpts;
this.#draws = [];
this.#paramNames = [];
if (!this.#worker) throw new Error("model worker is undefined");

this.update({ type: "startSampling", samplingOpts });

this.#samplingStartTimeSec = Date.now() / 1000;
this.#status = "sampling";
this.#onStatusChangedCallbacks.forEach((cb) => cb());
this.#worker.postMessage({ purpose: Requests.Sample, sampleConfig });
}
onProgress(callback: (progress: Progress) => void) {
this.#onProgressCallbacks.push(callback);
}
onStatusChanged(callback: () => void) {
this.#onStatusChangedCallbacks.push(callback);
}

cancel() {
if (this.#status === "sampling") {
this.#worker && this.#worker.terminate();
this.#status = "";
this._initialize();
} else {
console.warn("Nothing to cancel");
}
}
get draws() {
return this.#draws;
}
get paramNames() {
return this.#paramNames;
}
get status() {
return this.#status;
}
get errorMessage() {
return this.#errorMessage;
}
get computeTimeSec() {
return this.#computeTimeSec;
}
get samplingOpts() {
return this.#samplingOpts;
this.#worker?.terminate();
this._initialize();
}
}

Expand Down
Loading

0 comments on commit 6c38d9d

Please sign in to comment.