From 11afca546956ef4034ca7dd1d0ef8627dd978021 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 06:59:45 -0400 Subject: [PATCH 1/4] export chain csvs --- gui/package.json | 1 + .../SamplerOutputView/SamplerOutputView.tsx | 54 +++++++++++++++++-- gui/src/app/SamplerOutputView/SummaryView.tsx | 3 +- gui/yarn.lock | 32 +++++++++++ 4 files changed, 84 insertions(+), 6 deletions(-) diff --git a/gui/package.json b/gui/package.json index 8b1ae602..4b581e0f 100644 --- a/gui/package.json +++ b/gui/package.json @@ -20,6 +20,7 @@ "@monaco-editor/react": "^4.6.0", "@mui/icons-material": "^5.15.17", "@mui/material": "^5.15.17", + "jszip": "^3.10.1", "monaco-editor": "^0.48.0", "plotly.js": "^2.33.0", "react": "^18.2.0", diff --git a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx index e2ec6058..54e367cf 100644 --- a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx +++ b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx @@ -7,6 +7,7 @@ import TabWidget from "../TabWidget/TabWidget" import TracePlotsView from "./TracePlotsView" import SummaryView from "./SummaryView" import HistsView from "./HistsView" +import JSZip from 'jszip' type SamplerOutputViewProps = { width: number @@ -142,13 +143,33 @@ const DrawsView: FunctionComponent = ({ width, height, draws, pa const csvText = prepareCsvText(draws, paramNames, drawChainIds, drawNumbers); downloadTextFile(csvText, 'draws.csv'); }, [draws, paramNames, drawChainIds, drawNumbers]); + const handleExportToMultipleCsvs = useCallback(async () => { + const uniqueChainIds = Array.from(new Set(drawChainIds)); + const csvTexts = prepareMultipleCsvsText(draws, paramNames, drawChainIds, uniqueChainIds); + const blob = await createZipBlobForMultipleCsvs(csvTexts, uniqueChainIds); + const fileName = 'SP-draws.zip'; + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = fileName; + a.click(); + URL.revokeObjectURL(url); + }, [draws, paramNames, drawChainIds]); return (
- } - label="Export to .csv" - onClick={handleExportToCsv} - /> +
+ } + label="Export to single .csv" + onClick={handleExportToCsv} + /> +   + } + label="Export to multiple .csv" + onClick={handleExportToMultipleCsvs} + /> +
@@ -197,6 +218,29 @@ const prepareCsvText = (draws: number[][], paramNames: string[], drawChainIds: n return [['Chain', 'Draw', ...paramNames].join(','), ...lines].join('\n') } +const prepareMultipleCsvsText = (draws: number[][], paramNames: string[], drawChainIds: number[], uniqueChainIds: number[]) => { + return uniqueChainIds.map(chainId => { + const drawIndicesForChain = drawChainIds.map((id, i) => id === chainId ? i : -1).filter(i => i >= 0); + const lines = drawIndicesForChain.map(i => { + return paramNames.map((_, j) => draws[j][i]).join(',') + }) + + return [paramNames.join(','), ...lines].join('\n') + }) +} + +const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: number[]) => { + const zip = new JSZip(); + // put them all in a folder called 'draws' + const folder = zip.folder('draws'); + if (!folder) throw new Error('Failed to create folder'); + csvTexts.forEach((text, i) => { + folder.file(`chain-${uniqueChainIds[i]}.csv`, text); + }); + const blob = await zip.generateAsync({type: 'blob'}); + return blob; +} + const downloadTextFile = (text: string, filename: string) => { const blob = new Blob([text], {type: 'text/plain'}); const url = URL.createObjectURL(blob); diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index ee2e8f41..bf49ea58 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -1,8 +1,8 @@ import { FunctionComponent, useMemo } from "react" import { ess } from "./advanced/ess" -import { computeMean, computePercentile, computeStdDev } from "./util" import rhat from "./advanced/rhat" import compute_effective_sample_size from "./ess_computation_from_stan/compute_effective_sample_size" +import { computeMean, computePercentile, computeStdDev } from "./util" type SummaryViewProps = { width: number @@ -187,6 +187,7 @@ const computeEss2 = (x: number[], chainIds: number[]) => { draws[chainIndex].push(x[i]); } const ess = compute_effective_sample_size(draws); + // const ess = compute_split_effective_sample_size(draws); return ess; } diff --git a/gui/yarn.lock b/gui/yarn.lock index 6ef253f1..fd667054 100644 --- a/gui/yarn.lock +++ b/gui/yarn.lock @@ -3483,6 +3483,11 @@ ignore@^5.2.0, ignore@^5.2.4: resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.3.1.tgz#5073e554cd42c5b33b394375f538b8593e34d4ef" integrity sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw== +immediate@~3.0.5: + version "3.0.6" + resolved "https://registry.yarnpkg.com/immediate/-/immediate-3.0.6.tgz#9db1dbd0faf8de6fbe0f5dd5e56bb606280de69b" + integrity sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ== + import-fresh@^3.2.1: version "3.3.0" resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-3.3.0.tgz#37162c25fcb9ebaa2e6e53d5b4d88ce17d9e0c2b" @@ -3932,6 +3937,16 @@ json5@^2.2.3: object.assign "^4.1.4" object.values "^1.1.6" +jszip@^3.10.1: + version "3.10.1" + resolved "https://registry.yarnpkg.com/jszip/-/jszip-3.10.1.tgz#34aee70eb18ea1faec2f589208a157d1feb091c2" + integrity sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g== + dependencies: + lie "~3.3.0" + pako "~1.0.2" + readable-stream "~2.3.6" + setimmediate "^1.0.5" + katex@^0.16.0: version "0.16.10" resolved "https://registry.yarnpkg.com/katex/-/katex-0.16.10.tgz#6f81b71ac37ff4ec7556861160f53bc5f058b185" @@ -3964,6 +3979,13 @@ levn@^0.4.1: prelude-ls "^1.2.1" type-check "~0.4.0" +lie@~3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/lie/-/lie-3.3.0.tgz#dcf82dee545f46074daf200c7c1c5a08e0f40f6a" + integrity sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ== + dependencies: + immediate "~3.0.5" + lines-and-columns@^1.1.6: version "1.2.4" resolved "https://registry.yarnpkg.com/lines-and-columns/-/lines-and-columns-1.2.4.tgz#eca284f75d2965079309dc0ad9255abb2ebc1632" @@ -5054,6 +5076,11 @@ p-locate@^5.0.0: dependencies: p-limit "^3.0.2" +pako@~1.0.2: + version "1.0.11" + resolved "https://registry.yarnpkg.com/pako/-/pako-1.0.11.tgz#6c9599d340d54dfd3946380252a35705a6b992bf" + integrity sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw== + parent-module@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/parent-module/-/parent-module-1.0.1.tgz#691d2709e78c79fae3a156622452d00762caaaa2" @@ -5897,6 +5924,11 @@ set-function-name@^2.0.1, set-function-name@^2.0.2: functions-have-names "^1.2.3" has-property-descriptors "^1.0.2" +setimmediate@^1.0.5: + version "1.0.5" + resolved "https://registry.yarnpkg.com/setimmediate/-/setimmediate-1.0.5.tgz#290cbb232e306942d7d7ea9b83732ab7856f8285" + integrity sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA== + shallow-copy@0.0.1: version "0.0.1" resolved "https://registry.yarnpkg.com/shallow-copy/-/shallow-copy-0.0.1.tgz#415f42702d73d810330292cc5ee86eae1a11a170" From 0264e5e463afb6ff780f9209e40dd729b3c364dc Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Jun 2024 14:14:15 +0000 Subject: [PATCH 2/4] Fix ESS calculation (sample vs population variance issue) --- .../compute_effective_sample_size.ts | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts index 9adfc673..155aa12c 100644 --- a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts +++ b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts @@ -85,7 +85,7 @@ function compute_effective_sample_size(draws: number[][]): number { let var_plus = mean_var * (num_draws - 1) / num_draws; if (num_chains > 1) { - var_plus += compute_variance(chain_mean); + var_plus += compute_sample_variance(chain_mean); } const rho_hat_s = new Array(num_draws).fill(0); @@ -150,12 +150,16 @@ function compute_mean(arr: number[]): number { return compute_sum(arr) / arr.length; } -function compute_variance(arr: number[]): number { - // QUESTION: is this the correct formula for variance? +function compute_population_variance(arr: number[]): number { const mean = compute_mean(arr); return compute_mean(arr.map(d => (d - mean) ** 2)); } +function compute_sample_variance(arr: number[]): number { + const mean = compute_mean(arr); + return compute_sum(arr.map(d => (d - mean) ** 2)) / (arr.length - 1); +} + function autocorrelation(y: number[]): number[] { const N = y.length; const M = fftNextGoodSize(N); @@ -190,7 +194,7 @@ function autocorrelation(y: number[]): number[] { function autocovariance(y: number[]): number[] { const acov = autocorrelation(y); - const variance = compute_variance(y); + const variance = compute_population_variance(y); return acov.map(v => v * variance); } @@ -257,4 +261,4 @@ export const compute_split_effective_sample_size = (draws: number[][]) => { return compute_effective_sample_size(split_draws); } -export default compute_effective_sample_size; \ No newline at end of file +export default compute_effective_sample_size; From 6ee44b44c75c86de89c065c6893b627ee241d8e5 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 10:59:46 -0400 Subject: [PATCH 3/4] change "chain-*" to "chain_*" --- gui/src/app/SamplerOutputView/SamplerOutputView.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx index 54e367cf..423c3df3 100644 --- a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx +++ b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx @@ -235,7 +235,7 @@ const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: const folder = zip.folder('draws'); if (!folder) throw new Error('Failed to create folder'); csvTexts.forEach((text, i) => { - folder.file(`chain-${uniqueChainIds[i]}.csv`, text); + folder.file(`chain_${uniqueChainIds[i]}.csv`, text); }); const blob = await zip.generateAsync({type: 'blob'}); return blob; From 5166803eaa8bddf34c99150f3f75eaf5cce2a563 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 17:33:21 -0400 Subject: [PATCH 4/4] comments in SamplerOutputView --- gui/src/app/SamplerOutputView/SamplerOutputView.tsx | 7 +++++++ gui/src/app/SamplerOutputView/SummaryView.tsx | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx index 423c3df3..192d9b15 100644 --- a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx +++ b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx @@ -212,6 +212,10 @@ const DrawsView: FunctionComponent = ({ width, height, draws, pa } const prepareCsvText = (draws: number[][], paramNames: string[], drawChainIds: number[], drawNumbers: number[]) => { + // draws: Each element of draws is a column corresponding to a parameter, across all chains + // paramNames: The paramNames array contains the names of the parameters in the same order that they appear in the draws array + // drawChainIds: The drawChainIds array contains the chain id for each row in the draws array + // uniqueChainIds: The uniqueChainIds array contains the unique chain ids const lines = draws[0].map((_, i) => { return [`${drawChainIds[i]}`, `${drawNumbers[i]}`, ...paramNames.map((_, j) => draws[j][i])].join(',') }) @@ -219,6 +223,9 @@ const prepareCsvText = (draws: number[][], paramNames: string[], drawChainIds: n } const prepareMultipleCsvsText = (draws: number[][], paramNames: string[], drawChainIds: number[], uniqueChainIds: number[]) => { + // See the comments in prepareCsvText for the meaning of the arguments. + // Whereas prepareCsvText returns a CSV that represents a long-form table, + // this function returns multiple CSVs, one for each chain. return uniqueChainIds.map(chainId => { const drawIndicesForChain = drawChainIds.map((id, i) => id === chainId ? i : -1).filter(i => i >= 0); const lines = drawIndicesForChain.map(i => { diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index 201b6aea..832fa128 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -161,7 +161,6 @@ const computeEss = (x: number[], chainIds: number[]) => { draws[chainIndex].push(x[i]); } const ess = compute_effective_sample_size(draws); - // const ess = compute_split_effective_sample_size(draws); return ess; }