Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add correlation parameter for KNN performance tests #330

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
100 changes: 100 additions & 0 deletions src/main/knn/CorrelatedFilterBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package knn;

import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;

import java.util.Random;

/**
* Builds a FixedBitSet filter over the input docs to achieve:
* 1. A filter cardinality = (selectivity)%
* 2. A normalized correlation ≈ correlation
*/
public class CorrelatedFilterBuilder {

final private float selectivity;
final private float correlation;
final private int n;

public CorrelatedFilterBuilder(int n, float selectivity, float correlation) {
if (selectivity <= 0 || selectivity >= 1) {
throw new IllegalArgumentException("selectivity must be in the range (0, 1)");
}
if (correlation < -1 || correlation > 1) {
throw new IllegalArgumentException("correlation must be in the range: [-1, 1]");
}
this.selectivity = selectivity;
this.correlation = correlation;
this.n = n;
}


/**
* The filter is built such that:
* - correlation = -1 means the lowest scoring (selectivity) of docs are set
* - correlation = 1 means the highest scoring (selectivity) of docs are set
* Correlation between -1 and 0 starts like -1, but a random (1 - |correlation|) of the set docs are cleared
* and set randomly over the entire range.
* Similarly, correlation between 0 and 1 starts like 1, and follows the same process.
*/
public FixedBitSet getCorrelatedFilter(TopDocs docs, Random random) {
FixedBitSet filter = new FixedBitSet(n);
final int filterCardinality = (int) (selectivity * n);

// Start with largest/smallest possible correlation by
// setting the highest (for corr > 0) / lowest (for corr < 0) scored vectors
if (correlation > 0) {
for (int i = 0; i < filterCardinality; i++) {
filter.set(docs.scoreDocs[i].doc);
}
} else {
for (int i = n - 1; i > n - 1 - filterCardinality; i--) {
filter.set(docs.scoreDocs[i].doc);
}
}

// Randomly flip (1 - |correlation|) of the set bits
final int amountToFlip = (int) ((1 - Math.abs(correlation)) * filterCardinality);
int flipped = 0;
while (flipped < amountToFlip) {
int i;
if (correlation > 0) {
i = random.nextInt(filterCardinality);
} else {
i = random.nextInt(n - filterCardinality, n);
}
if (filter.getAndClear(docs.scoreDocs[i].doc)) {
setRandomClearBit(filter, random);
flipped++;
}
}

return filter;
}

private void setRandomClearBit(FixedBitSet bitSet, Random random) {
int randomBit;
do {
randomBit = random.nextInt(bitSet.length());
} while (bitSet.get(randomBit));
bitSet.set(randomBit);
}

}
80 changes: 75 additions & 5 deletions src/main/knn/KnnGraphTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.BinaryOperator;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
Expand Down Expand Up @@ -158,8 +157,10 @@ public class KnnGraphTester {
private ExecutorService exec;
private VectorSimilarityFunction similarityFunction;
private VectorEncoding vectorEncoding;
private Query filterQuery;
private Query sharedFilterQuery;
private Query[] filterQueries;
private float selectivity;
private float correlation;
private boolean prefilter;
private boolean randomCommits;
private boolean parentJoin;
Expand All @@ -182,6 +183,7 @@ private KnnGraphTester() {
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = VectorEncoding.FLOAT32;
selectivity = 1f;
correlation = 0f;
prefilter = false;
quantize = false;
randomCommits = false;
Expand Down Expand Up @@ -377,6 +379,15 @@ private void run(String... args) throws Exception {
throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
}
break;
case "-filterCorrelation": // Setting this != 0 will make tests slow as correlated filters are built
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-filterCorrelation requires a following float");
}
correlation = Float.parseFloat(args[++iarg]);
if (correlation < -1 || correlation > 1) {
throw new IllegalArgumentException("-filterCorrelation must be in the range [-1, 1]");
}
break;
case "-quiet":
quiet = true;
break;
Expand Down Expand Up @@ -425,6 +436,9 @@ private void run(String... args) throws Exception {
if (prefilter && selectivity == 1f) {
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
}
if (correlation != 0 && selectivity == 1f) {
throw new IllegalArgumentException("a nonzero -filterCorrelation requires filterSelectivity between 0 and 1");
}
if (indexPath == null) {
indexPath = Paths.get(formatIndexPath(docVectorsPath)); // derive index path
log("Index Path = %s", indexPath);
Expand Down Expand Up @@ -573,7 +587,11 @@ private void run(String... args) throws Exception {
if (docVectorsPath == null) {
throw new IllegalArgumentException("missing -docs arg");
}
filterQuery = selectivity == 1f ? new MatchAllDocsQuery() : generateRandomQuery(random, indexPath, numDocs, selectivity);
if (correlation != 0) {
filterQueries = generateRandomCorrelatedFilterQueries(random, queryPath, selectivity, correlation);
} else {
sharedFilterQuery = selectivity == 1f ? new MatchAllDocsQuery() : generateRandomFilterQuery(random, indexPath, numDocs, selectivity);
}
if (outputPath != null) {
testSearch(indexPath, queryPath, queryStartIndex, outputPath, null);
} else {
Expand All @@ -591,7 +609,54 @@ private void run(String... args) throws Exception {
}
}

private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
// For each query vector, generate a filter query with the given selectivity and correlation
private Query[] generateRandomCorrelatedFilterQueries(Random random, Path queryPath, float selectivity, float correlation) throws IOException {
Query[] filterQueries = new Query[numQueryVectors];
log("computing correlated filters for " + numQueryVectors + " target vectors");
long startNS = System.nanoTime();
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader docReader = DirectoryReader.open(dir);
FileChannel qIn = getVectorFileChannel(queryPath, dim, vectorEncoding)) {
VectorReader queryReader = (VectorReader) VectorReader.create(qIn, dim, VectorEncoding.FLOAT32, queryStartIndex);
int indexNumDocs = docReader.numDocs();
if (numDocs > indexNumDocs) {
throw new IllegalArgumentException("-ndocs must be <= the number of docs in the index");
}
knn.CorrelatedFilterBuilder correlatedFilterBuilder = new knn.CorrelatedFilterBuilder(docReader.numDocs(), selectivity, correlation);

for (int i = 0; i < numQueryVectors; i++) {
if ((i + 1) % 10 == 0) {
log(" " + (i + 1));
}

// Get a score for every doc by doing an exact search for topK = numDocs
float[] queryVec = queryReader.next().clone();
IndexSearcher searcher = new IndexSearcher(docReader);
var queryVector = new ConstKnnFloatValueSource(queryVec);
var docVectors = new FloatKnnVectorFieldSource(KNN_FIELD);
var query = new BooleanQuery.Builder()
.add(new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors)), BooleanClause.Occur.SHOULD)
.build();
var topDocs = searcher.search(query, numDocs);

BitSet[] segmentDocs = new BitSet[docReader.leaves().size()];
for (var leafContext : docReader.leaves()) {
// Generate a filter for this query vector's scores
FixedBitSet segmentBitSet = correlatedFilterBuilder.getCorrelatedFilter(topDocs, random);
segmentDocs[leafContext.ord] = segmentBitSet;
}
// TODO: cache filters for subsequent runs (similar to readExactNN)
filterQueries[i] = new BitSetQuery(segmentDocs);
}
}

long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
System.out.printf("took %.3f sec to compute correlated filters\n", elapsedMS / 1000.);

return filterQueries;
}

private static Query generateRandomFilterQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
FixedBitSet bitSet = new FixedBitSet(size);
for (int i = 0; i < size; i++) {
if (random.nextFloat() < selectivity) {
Expand Down Expand Up @@ -787,6 +852,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
numDocs = reader.maxDoc();
// warm up
for (int i = 0; i < numQueryVectors; i++) {
Query filterQuery = correlation == 0 ? sharedFilterQuery : filterQueries[i];
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
Expand All @@ -799,6 +865,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
startNS = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numQueryVectors; i++) {
Query filterQuery = correlation == 0 ? sharedFilterQuery : filterQueries[i];
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
Expand Down Expand Up @@ -857,7 +924,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
double reindexSec = reindexTimeMsec / 1000.0;
System.out.printf(
Locale.ROOT,
"SUMMARY: %5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\n",
"SUMMARY: %5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\n",
recall,
totalCpuTimeMS / (float) numQueryVectors,
numDocs,
Expand All @@ -873,6 +940,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
indexNumSegments,
indexSizeOnDiskMB,
selectivity,
correlation,
prefilter ? "pre-filter" : "post-filter",
vectorDiskSizeBytes / 1024. / 1024.,
vectorRAMSizeBytes / 1024. / 1024.);
Expand Down Expand Up @@ -1050,6 +1118,7 @@ public Void call() {
try {
var queryVector = new ConstKnnByteVectorValueSource(query);
var docVectors = new ByteKnnVectorFieldSource(KNN_FIELD);
Query filterQuery = correlation == 0 ? sharedFilterQuery : filterQueries[queryOrd];
var query = new BooleanQuery.Builder()
.add(new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors)), BooleanClause.Occur.SHOULD)
.add(filterQuery, BooleanClause.Occur.FILTER)
Expand Down Expand Up @@ -1119,6 +1188,7 @@ public Void call() {
try {
var queryVector = new ConstKnnFloatValueSource(query);
var docVectors = new FloatKnnVectorFieldSource(KNN_FIELD);
Query filterQuery = correlation == 0 ? sharedFilterQuery : filterQueries[queryOrd];
var query = new BooleanQuery.Builder()
.add(new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors)), BooleanClause.Occur.SHOULD)
.add(filterQuery, BooleanClause.Occur.FILTER)
Expand Down
2 changes: 1 addition & 1 deletion src/python/knnPerfTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def run_knn_benchmark(checkout, values):
print_fixed_width(all_results, skip_headers)

def print_fixed_width(all_results, columns_to_skip):
header = 'recall\tlatency (ms)\tnDoc\ttopK\tfanout\tmaxConn\tbeamWidth\tquantized\tvisited\tindex s\tindex docs/s\tforce merge s\tnum segments\tindex size (MB)\tselectivity\tfilterType\tvec disk (MB)\tvec RAM (MB)'
header = 'recall\tlatency (ms)\tnDoc\ttopK\tfanout\tmaxConn\tbeamWidth\tquantized\tvisited\tindex s\tindex docs/s\tforce merge s\tnum segments\tindex size (MB)\tselectivity\tcorrelation\tfilterType\tvec disk (MB)\tvec RAM (MB)'

# crazy logic to make everything fixed width so rendering in fixed width font "aligns":
headers = header.split('\t')
Expand Down