Skip to content

Commit

Permalink
Merge pull request #48 from flatironinstitute/ess-stats
Browse files Browse the repository at this point in the history
frontend: effective sample size and related stats
  • Loading branch information
magland authored Jun 14, 2024
2 parents 2c0e884 + e3d6340 commit a584857
Show file tree
Hide file tree
Showing 3 changed files with 585 additions and 23 deletions.
65 changes: 42 additions & 23 deletions gui/src/app/SamplerOutputView/SummaryView.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { FunctionComponent, useMemo } from "react"
import { computeMean, computePercentile, computeStdDev } from "./util"
import { compute_effective_sample_size, compute_split_potential_scale_reduction } from "./stan_stats/stan_stats"

type SummaryViewProps = {
width: number
Expand All @@ -16,11 +17,11 @@ const columns = [
label: 'Mean',
title: 'Mean value of the parameter'
},
/*future: {
{
key: 'mcse',
label: 'MCSE',
title: 'Monte Carlo Standard Error: Standard deviation of the parameter divided by the square root of the effective sample size'
},*/
},
{
key: 'stdDev',
label: 'StdDev',
Expand All @@ -41,46 +42,43 @@ const columns = [
label: '95%',
title: '95th percentile of the parameter'
},
/*future: {
{
key: 'nEff',
label: 'N_Eff',
title: 'Effective sample size: A crude measure of the effective sample size (uses ess_imse)'
},*/
/*future: {
title: 'Effective sample size: A crude measure of the effective sample size'
},
{
key: 'nEff/s',
label: 'N_Eff/s',
title: 'Effective sample size per second of compute time'
},*/
/*future: {
},
{
key: 'rHat',
label: 'R_hat',
title: 'Potential scale reduction factor on split chains (at convergence, R_hat=1)'
}*/
}
]

type TableRow = {
key: string
values: number[]
}

const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws, paramNames }) => {
// will be used in the future:
// const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]);
// note: computeTimeSec will be used in the future

const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws, paramNames, drawChainIds, computeTimeSec }) => {
const rows = useMemo(() => {
const rows: TableRow[] = [];
for (const pname of paramNames) {
const pDraws = draws[paramNames.indexOf(pname)];
const pDrawsSorted = [...pDraws].sort((a, b) => a - b);
const ess = computeEss(pDraws, drawChainIds);
const rhat = computeRhat(pDraws, drawChainIds);
const stdDev = computeStdDev(pDraws);
const values = columns.map((column) => {
if (column.key === 'mean') {
return computeMean(pDraws);
}
else if (column.key === 'mcse') {
// placeholder for mcse
throw new Error('Not implemented');
return stdDev / Math.sqrt(ess);
}
else if (column.key === 'stdDev') {
return stdDev;
Expand All @@ -95,16 +93,13 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
return computePercentile(pDrawsSorted, 0.95);
}
else if (column.key === 'nEff') {
// placeholder for nEff
throw new Error('Not implemented');
return ess;
}
else if (column.key === 'nEff/s') {
// placeholder for nEff/s
throw new Error('Not implemented');
return computeTimeSec ? ess / computeTimeSec : NaN;
}
else if (column.key === 'rHat') {
// placeholder for rHat
throw new Error('Not implemented');
return rhat;
}
else {
return NaN;
Expand All @@ -116,7 +111,7 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
})
}
return rows;
}, [paramNames, draws]);
}, [draws, paramNames, drawChainIds, computeTimeSec]);

return (
<div style={{position: 'absolute', width, height, overflowY: 'auto'}}>
Expand Down Expand Up @@ -157,6 +152,30 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
)
}

const drawsByChain = (draws: number[], chainIds: number[]): number[][] => {
// Group draws by chain for use in computing ESS and Rhat
const uniqueChainIds = Array.from(new Set(chainIds)).sort();
const drawsByChain: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []);
for (let i = 0; i < draws.length; i++) {
const chainId = chainIds[i];
const chainIndex = uniqueChainIds.indexOf(chainId);
drawsByChain[chainIndex].push(draws[i]);
}
return drawsByChain;
}

const computeEss = (x: number[], chainIds: number[]) => {
const draws = drawsByChain(x, chainIds);
const ess = compute_effective_sample_size(draws);
return ess;
}

const computeRhat = (x: number[], chainIds: number[]) => {
const draws = drawsByChain(x, chainIds);
const rhat = compute_split_potential_scale_reduction(draws);
return rhat;
}

// Example of Stan output...
// Inference for Stan model: bernoulli_model
// 1 chains: each with iter=(1000); warmup=(0); thin=(1); 1000 iterations saved.
Expand Down
225 changes: 225 additions & 0 deletions gui/src/app/SamplerOutputView/stan_stats/fft.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/* eslint-disable @typescript-eslint/no-inferrable-types */
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable prefer-const */
/*
* Free FFT and convolution (TypeScript)
*
* Copyright (c) 2022 Project Nayuki. (MIT License)
* https://www.nayuki.io/page/free-small-fft-in-multiple-languages
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
* - The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
* - The Software is provided "as is", without warranty of any kind, express or
* implied, including but not limited to the warranties of merchantability,
* fitness for a particular purpose and noninfringement. In no event shall the
* authors or copyright holders be liable for any claim, damages or other
* liability, whether in an action of contract, tort or otherwise, arising from,
* out of or in connection with the Software or the use or other dealings in the
* Software.
*/


/*
* Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector.
* The vector can have any length. This is a wrapper function.
*/
export function transform(real: Array<number>|Float64Array, imag: Array<number>|Float64Array): void {
const n: number = real.length;
if (n != imag.length)
throw new RangeError("Mismatched lengths");
if (n == 0)
return;
else if ((n & (n - 1)) == 0) // Is power of 2
transformRadix2(real, imag);
else // More complicated algorithm for arbitrary sizes
transformBluestein(real, imag);
}


/*
* Computes the inverse discrete Fourier transform (IDFT) of the given complex vector, storing the result back into the vector.
* The vector can have any length. This is a wrapper function. This transform does not perform scaling, so the inverse is not a true inverse.
*/
export function inverseTransform(real: Array<number>|Float64Array, imag: Array<number>|Float64Array): void {
transform(imag, real);
}


/*
* Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector.
* The vector's length must be a power of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm.
*/
function transformRadix2(real: Array<number>|Float64Array, imag: Array<number>|Float64Array): void {
// Length variables
const n: number = real.length;
if (n != imag.length)
throw new RangeError("Mismatched lengths");
if (n == 1) // Trivial transform
return;
let levels: number = -1;
for (let i = 0; i < 32; i++) {
if (1 << i == n)
levels = i; // Equal to log2(n)
}
if (levels == -1)
throw new RangeError("Length is not a power of 2");

// Trigonometric tables
let cosTable = new Array<number>(n / 2);
let sinTable = new Array<number>(n / 2);
for (let i = 0; i < n / 2; i++) {
cosTable[i] = Math.cos(2 * Math.PI * i / n);
sinTable[i] = Math.sin(2 * Math.PI * i / n);
}

// Bit-reversed addressing permutation
for (let i = 0; i < n; i++) {
const j: number = reverseBits(i, levels);
if (j > i) {
let temp: number = real[i];
real[i] = real[j];
real[j] = temp;
temp = imag[i];
imag[i] = imag[j];
imag[j] = temp;
}
}

// Cooley-Tukey decimation-in-time radix-2 FFT
for (let size = 2; size <= n; size *= 2) {
const halfsize: number = size / 2;
const tablestep: number = n / size;
for (let i = 0; i < n; i += size) {
for (let j = i, k = 0; j < i + halfsize; j++, k += tablestep) {
const l: number = j + halfsize;
const tpre: number = real[l] * cosTable[k] + imag[l] * sinTable[k];
const tpim: number = -real[l] * sinTable[k] + imag[l] * cosTable[k];
real[l] = real[j] - tpre;
imag[l] = imag[j] - tpim;
real[j] += tpre;
imag[j] += tpim;
}
}
}

// Returns the integer whose value is the reverse of the lowest 'width' bits of the integer 'val'.
function reverseBits(val: number, width: number): number {
let result: number = 0;
for (let i = 0; i < width; i++) {
result = (result << 1) | (val & 1);
val >>>= 1;
}
return result;
}
}


/*
* Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector.
* The vector can have any length. This requires the convolution function, which in turn requires the radix-2 FFT function.
* Uses Bluestein's chirp z-transform algorithm.
*/
function transformBluestein(real: Array<number>|Float64Array, imag: Array<number>|Float64Array): void {
// Find a power-of-2 convolution length m such that m >= n * 2 + 1
const n: number = real.length;
if (n != imag.length)
throw new RangeError("Mismatched lengths");
let m: number = 1;
while (m < n * 2 + 1)
m *= 2;

// Trigonometric tables
let cosTable = new Array<number>(n);
let sinTable = new Array<number>(n);
for (let i = 0; i < n; i++) {
const j: number = i * i % (n * 2); // This is more accurate than j = i * i
cosTable[i] = Math.cos(Math.PI * j / n);
sinTable[i] = Math.sin(Math.PI * j / n);
}

// Temporary vectors and preprocessing
let areal: Array<number> = newArrayOfZeros(m);
let aimag: Array<number> = newArrayOfZeros(m);
for (let i = 0; i < n; i++) {
areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i];
aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i];
}
let breal: Array<number> = newArrayOfZeros(m);
let bimag: Array<number> = newArrayOfZeros(m);
breal[0] = cosTable[0];
bimag[0] = sinTable[0];
for (let i = 1; i < n; i++) {
breal[i] = breal[m - i] = cosTable[i];
bimag[i] = bimag[m - i] = sinTable[i];
}

// Convolution
let creal = new Array<number>(m);
let cimag = new Array<number>(m);
convolveComplex(areal, aimag, breal, bimag, creal, cimag);

// Postprocessing
for (let i = 0; i < n; i++) {
real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i];
imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i];
}
}


/*
* Computes the circular convolution of the given real vectors. Each vector's length must be the same.
*/
// function convolveReal(xvec: Array<number>|Float64Array, yvec: Array<number>|Float64Array, outvec: Array<number>|Float64Array): void {
// const n: number = xvec.length;
// if (n != yvec.length || n != outvec.length)
// throw new RangeError("Mismatched lengths");
// convolveComplex(xvec, newArrayOfZeros(n), yvec, newArrayOfZeros(n), outvec, newArrayOfZeros(n));
// }


/*
* Computes the circular convolution of the given complex vectors. Each vector's length must be the same.
*/
function convolveComplex(
xreal: Array<number>|Float64Array, ximag: Array<number>|Float64Array,
yreal: Array<number>|Float64Array, yimag: Array<number>|Float64Array,
outreal: Array<number>|Float64Array, outimag: Array<number>|Float64Array): void {

const n: number = xreal.length;
if (n != ximag.length || n != yreal.length || n != yimag.length
|| n != outreal.length || n != outimag.length)
throw new RangeError("Mismatched lengths");

xreal = xreal.slice();
ximag = ximag.slice();
yreal = yreal.slice();
yimag = yimag.slice();
transform(xreal, ximag);
transform(yreal, yimag);

for (let i = 0; i < n; i++) {
const temp: number = xreal[i] * yreal[i] - ximag[i] * yimag[i];
ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i];
xreal[i] = temp;
}
inverseTransform(xreal, ximag);

for (let i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it)
outreal[i] = xreal[i] / n;
outimag[i] = ximag[i] / n;
}
}


function newArrayOfZeros(n: number): Array<number> {
let result: Array<number> = [];
for (let i = 0; i < n; i++)
result.push(0);
return result;
}
Loading

0 comments on commit a584857

Please sign in to comment.