Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate merging digest percentiles by centroid weight #110

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions core/src/main/java/com/tdunning/math/stats/MergingDigest.java
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,9 @@ public double quantile(double q) {
// at the boundaries, we return min or max
if (index < weight[0] / 2) {
assert weight[0] > 0;
return min + 2 * index / weight[0] * (mean[0] - min);
double z1 = (index / totalWeight) * weight[0];
double z2 = ((weight[0] / 2 - index) / totalWeight) * getWeightForMinOrMax(weight[0]);
return weightedAverage(min, z2, mean[0], z1);
}

// in between we interpolate between centroids
Expand All @@ -687,8 +689,8 @@ public double quantile(double q) {
double dw = (weight[i] + weight[i + 1]) / 2;
if (weightSoFar + dw > index) {
// centroids i and i+1 bracket our current point
double z1 = index - weightSoFar;
double z2 = weightSoFar + dw - index;
double z1 = ((index - weightSoFar) / totalWeight) * weight[i + 1];
double z2 = ((weightSoFar + dw - index) / totalWeight) * weight[i];
return weightedAverage(mean[i], z2, mean[i + 1], z1);
}
weightSoFar += dw;
Expand All @@ -698,11 +700,28 @@ public double quantile(double q) {

// weightSoFar = totalWeight - weight[n-1]/2 (very nearly)
// so we interpolate out to max value ever seen
double z1 = index - totalWeight - weight[n - 1] / 2.0;
double z2 = weight[n - 1] / 2 - z1;
double z1 = ((totalWeight - index) / totalWeight) * weight[n - 1];
double z2 = (weight[n - 1] / 2 - z1) * getWeightForMinOrMax(weight[n - 1]);
return weightedAverage(mean[n - 1], z1, max, z2);
}

/**
 * Chooses the weight given to the {@code min} / {@code max} endpoint when
 * {@code quantile} interpolates beyond the first or last centroid.
 *
 * @param maxWeight upper bound for the result; the visible callers pass the
 *                  weight of the first or last centroid
 * @return the smaller of {@code maxWeight} and
 *         {@code centroidScaleToQuantile(1 / compression)}
 */
private double getWeightForMinOrMax(double maxWeight) {
    // NOTE(review): centroidScaleToQuantile divides its argument by compression
    // again, so the effective scale here is 1 / compression^2 -- confirm intended.
    final double idealEdgeWeight = centroidScaleToQuantile(1 / compression);
    return Math.min(maxWeight, idealEdgeWeight);
}

/**
 * Converts a centroid scale value (the centroid number) into the ideal
 * starting percentile for that centroid.
 * <p>
 * The mapping is {@code sin^2((centroidNumber / compression) * PI / 2)}, so the
 * result always lies in {@code [0, 1]}.
 *
 * @param centroidNumber the number of the centroid
 * @return the ideal starting percentile for this centroid
 */
private double centroidScaleToQuantile(double centroidNumber) {
    final double scaled = centroidNumber / compression;
    final double sine = Math.sin((scaled * Math.PI) / 2);
    // Squaring by plain multiplication keeps the result identical bit for bit.
    return sine * sine;
}

@Override
public int centroidCount() {
return lastUsedCell;
Expand Down
26 changes: 23 additions & 3 deletions core/src/test/java/com/tdunning/math/stats/MergingDigestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void printQuantiles() throws FileNotFoundException {
td.setMinMax(0, 10);
td.add(1);
td.add(2);
td.add(5, 2);
td.add(5);
td.add(6);
td.add(9);
td.add(10);
Expand All @@ -132,7 +132,7 @@ public void printQuantiles() throws FileNotFoundException {

quantiles.printf("x,q\n");
cdfs.printf("x,q\n");
for (double q = 0; q < 1; q += 1e-3) {
for (double q = 1.0 / 12.0; q < 1; q += 1e-3) {
double x = td.quantile(q);
quantiles.printf("%.3f,%.3f\n", x, q);

Expand All @@ -145,7 +145,27 @@ public void printQuantiles() throws FileNotFoundException {
}
}

assertEquals(2.0 / 7, td.cdf(3), 1e-9);
assertEquals(0.25 + 1.0/18.0, td.cdf(3), 1e-9);
}

/**
 * With a light centroid at 1, a heavy centroid at 2 and a far-away max of
 * 1000, the 0.999 quantile must stay below 101.
 */
@Test
public void testCloseTo1QuantileBigWeights() {
    MergingDigest digest = new MergingDigest(200);
    digest.add(1, 1);
    digest.add(2, 100);
    digest.setMinMax(0, 1000);

    final double estimate = digest.quantile(.999);
    final String failureMessage = "Quantile val incorrect: " + estimate;
    assertTrue(failureMessage, estimate < 101);
}

/**
 * With a heavy centroid at 1000, a light centroid at 1001 and a far-away min
 * of 0, the 0.001 quantile must stay above 999.
 */
@Test
public void testCloseToZeroQuantileBigWeights() {
    MergingDigest digest = new MergingDigest(200);
    digest.add(1000, 100);
    digest.add(1001, 1);
    digest.setMinMax(0, 1002);

    final double estimate = digest.quantile(.001);
    final String failureMessage = "Quantile val incorrect: " + estimate;
    assertTrue(failureMessage, estimate > 999);
}

@Override
Expand Down
Binary file not shown.
153 changes: 153 additions & 0 deletions docs/t-digest-paper/weighted-vs-linear-interpolation.r
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
### Compare linear interpolation of centroids to weighted interpolation ###
# Shared plotting constants; makeChart reads fade and dot.size as globals.
fade = rgb(0,0,0,alpha=0.5)
dot.size = 0.7
# Fix the RNG seed so the generated samples are reproducible.
set.seed(5)

# All charts in this script are written to this PDF until dev.off() is called.
pdf("weighted-vs-linear-interpolation.pdf", width=6, height=2.7, pointsize=10)
# Two panels side by side on each page, the left one slightly wider.
layout(matrix(c(1,2),byrow=T, ncol=2), widths=c(1.1,1))
x = sort(log(1-runif(10000))) # sorted sample of log(1-U): a negated exponential, all values <= 0
F = ((0:(length(x)-1))+0.5)/length(x) # mid-rank quantile (empirical CDF height) of each sorted x value
par(mar=c(2.5,3,1,1))
# Zoom on the left tail: the first 110 points, i.e. quantiles 0 to 0.01.
plot(x, F, cex=dot.size, pch=21, bg=fade, col=NA, type='b', xlim=c(x[1], x[110]), ylim=c(0,0.01), xaxt='n', ylab=NA, mgp=c(1,0.5,0), xlab=NA)

axis(side=1, at=-10:-1, labels=NA)
title(xlab='x', line=0.8, cex.lab=1.5)
title(ylab='q', line=1.5, cex.lab=1.5)

left.end = min(x)

# One straight segment from (min, 0) to (x[100], 0.01) spanning the first 100 samples.
lines(c(left.end, x[100]), c(0, 0.01), lwd=2)
# Thin vertical markers at both ends of that segment.
# NOTE(review): `lt` is not a full graphical parameter name; it presumably reaches
# `lty` through partial argument matching -- consider spelling it out.
lines(c(left.end, left.end), c(-0.0005, 0.0005), lt=1, col='black', lwd=0.5)
lines(c(x[100], x[100]), c(0.0085, 0.015), lt=1, col='black', lwd=0.5)
# Label the segment with the number of samples it covers.
text(-7, 0.006, "100")

# Map a quantile q in [0, 1] to a normalized scale value in [0, 1] using the
# arcsine curve asin(2q - 1) / pi + 1/2. Vectorized over q.
q.to.k = function(q) {
  centered = 2 * q - 1
  0.5 + asin(centered) / pi
}

# Map a centroid scale value k (0 .. compression) to the quantile at which that
# centroid ideally starts: sin(k / compression * pi - pi / 2) / 2 + 0.5.
# Vectorized over k.
k.to.q = function(k,compression) {
  angle = k / compression * pi - pi / 2
  0.5 + sin(angle) / 2
}

# Plot the empirical CDF of the sorted sample held in the global variable "x" and
# overlay the centroids implied by a vector of centroid weights.
#
#   weights             number of points in each centroid, in order
#   titleToDisp         main title of the chart
#   drawFunc            function(xCoordinates, yCoordinates, weights) that draws
#                       the interpolation between the centroid positions
#   rangeMin, rangeMax  indexes into x (values in [1, n]) bounding the visible
#                       window; by default roughly the first 1.1% of the points
makeChart = function(weights, titleToDisp, drawFunc, rangeMin=1, rangeMax=round(1.1*length(x)/100)) {
  sampleSize = length(x)
  windowX = c(x[rangeMin], x[rangeMax])
  windowY = c((rangeMin-1)/sampleSize, rangeMax/sampleSize)
  # mid-rank quantile of every sorted sample point
  sampleQ = ((0:(sampleSize-1))+0.5)/sampleSize

  # the raw distribution
  plot(x, sampleQ, cex=dot.size, pch=21, bg=fade, col=NA, type='b', xlim=windowX, ylim=windowY, xaxt='n')
  title(main=titleToDisp)
  axis(side=1, at=-10:-1, labels=NA)
  axis(side=2, at=(0:6)/10, labels=NA)
  title(xlab='x', line=0.8, cex.lab=1.5)
  title(ylab='q', line=2, cex.lab=1.5)

  cumWeights = cumsum(weights)

  # x[boundaries[i]] is the last point of segment i; the first and last entries
  # stand for the min and the max of the sample
  boundaries = c(1, cumWeights, sampleSize)

  # the first position is always the min; every later position is the mean of
  # the points between two consecutive boundaries
  centroidX = c(
    x[1],
    vapply(
      2:length(boundaries),
      function(i) mean(x[boundaries[i-1]:boundaries[i]]),
      numeric(1)
    )
  )

  # a centroid sits at the cumulative weight up to it minus half its own weight
  centroidQ = c(0, cumWeights - weights / 2, sampleSize) / sum(weights)

  # note - weight of min and max is specified as 1. It could be k.to.q(1/compression) for better results.
  paddedWeights = c(1, weights, 1)

  # interpolation between the centroids
  drawFunc(centroidX, centroidQ, paddedWeights)
  # print each centroid's (rounded) weight just above it
  labelOffset = (rangeMax-rangeMin)*.06/sampleSize
  text(centroidX, centroidQ + labelOffset, round(paddedWeights))
}

# Join consecutive centroids with straight blue segments, marking each centroid
# with a point. The weights argument is not used; it is accepted so the
# signature matches the other draw functions passed to makeChart.
drawLinear = function(xCoordinates, yCoordinates, weights) {
  lines(
    xCoordinates,
    yCoordinates,
    type = "o",
    lwd = 1,
    col = "blue"
  )
}
# Draw the centroids as blue points and join each consecutive pair with a
# curve interpolated by centroid weight.
#
#   xCoordinates  x position (mean value) of each centroid
#   yCoordinates  quantile (height on the CDF) of each centroid
#   weights       weight of each centroid, aligned with the coordinates
#
# Between centroids i and i+1 the curve at position p is the average of the
# two heights, each weighted by its centroid's weight times the distance from
# p to the OTHER centroid. The curve therefore passes through both centroids
# and bends toward the heavier one.
drawWeighted = function(xCoordinates, yCoordinates, weights) {
  points(xCoordinates, yCoordinates, col='blue')

  # seq_len() gives an empty loop when fewer than two centroids are supplied;
  # 1:(n-1) would count down through c(1, 0) and fail inside curve().
  pairCount = max(length(xCoordinates) - 1, 0)
  for (i in seq_len(pairCount)) {
    w1 = weights[i]
    w2 = weights[i+1]
    x1 = xCoordinates[i]
    x2 = xCoordinates[i+1]
    y1 = yCoordinates[i]
    y2 = yCoordinates[i+1]

    # weighted average of the two centroid heights at x position p
    weightedHeight = function(p) {
      (y1*w1*(x2-p)+y2*w2*(p-x1)) / (w1*(x2-p)+w2*(p-x1))
    }
    curve(weightedHeight, x1, x2, add=TRUE, col="blue")
  }
}

# Calculate the ideal centroid weights for a data set of n points at the given
# compression.
#
#   n            number of points in the data set
#   compression  number of centroids to produce (a positive whole number)
#
# Returns a vector of `compression` weights summing to n. Weight i is n times
# the quantile width between k.to.q(i - 1) and k.to.q(i); the first left
# bound is taken as exactly 0.
#
# The size comes from the n argument rather than from the global x, so the
# result depends only on the arguments.
getIdealBounds = function(n, compression) {
  # seq_len() keeps compression = 1 correct; 1:(compression-1) would yield c(1, 0)
  rightBounds = k.to.q(seq_len(compression), compression)
  leftBounds = c(0, head(rightBounds, -1))
  (rightBounds - leftBounds) * n
}

# these weights were taken from linear-interpolation.r
weights = c(2, 8, 19, 35, 56, 81, 111)

# Charts for the preset weights. The extra final weight, n-sum(weights), lumps
# every remaining point into one last centroid so that the weights sum to n.
n=length(x)
# This chart fills the second panel of the first page, next to the plot drawn above.
makeChart(c(weights, n-sum(weights)), "Preset Weights", drawLinear)

# Show comparison for these preset weights
# NOTE(review): the linear chart below repeats the one drawn just above -- confirm
# the repetition is wanted for the page layout.
makeChart(c(weights, n-sum(weights)), "Preset Weights", drawLinear)
makeChart(c(weights, n-sum(weights)), "Preset Weights", drawWeighted)

# Draw the full set of linear-vs-weighted comparison charts for the sorted
# sample currently held in the global variable "x".
#
#   distName  label of the distribution, used as the prefix of every chart title
#
# Every chart is produced twice, first with drawLinear and then with
# drawWeighted, so the two versions land next to each other.
drawExampleCharts = function(distName) {
  n = length(x)

  # draw one linear/weighted pair; extra arguments (rangeMin, rangeMax) are
  # handed through to makeChart unchanged
  drawPair = function(compression, chartTitle, ...) {
    for (drawFunc in list(drawLinear, drawWeighted)) {
      makeChart(getIdealBounds(n, compression), chartTitle, drawFunc, ...)
    }
  }

  # default window (left tail) for ideal centroid placements at three compressions
  for (compression in c(50,100,200)) {
    drawPair(compression, paste(distName," c=", compression))
  }

  midLen = n/20 # 5%
  titleToDisp = paste(distName," c=", 100)

  # quantiles 0.025 - 0.05
  drawPair(100, titleToDisp, rangeMin=round(midLen/2), rangeMax=midLen)

  # quantiles 0.45 - 0.55
  drawPair(100, titleToDisp, rangeMin=round(n/2-midLen), rangeMax=round(n/2+midLen))

  # right tail: the last 1.1% of the elements
  bottomPercent = (n*1.1)/100
  drawPair(100, titleToDisp, rangeMin=n-bottomPercent, rangeMax=n)
}

# Charts for the sample currently in x (the 10000-point sample created at the top).
drawExampleCharts(paste("Exp n=", length(x)))

# Each section below replaces the global x, which makeChart reads, and redraws
# the same set of charts for the new sample.
x = sort(log(1-runif(50000))) # sorted sample of log(1-U): a negated exponential, all values <= 0
drawExampleCharts(paste("Exp n=", length(x)))

x = sort(rnorm(50000)) # sorted normal distribution
drawExampleCharts(paste("Normal n=", length(x)))

# NOTE(review): a second, independently drawn normal sample with the same title as
# the previous section -- confirm this repetition is intentional.
x = sort(rnorm(50000)) # sorted normal distribution
drawExampleCharts(paste("Normal n=", length(x)))

x = sort(runif(50000)) # sorted uniform distribution
drawExampleCharts(paste("Uniform n=", length(x)))

# Close the PDF device opened at the top of the script.
dev.off()