Skip to content

Commit

Permalink
Improve int4 compressed comparisons performance (#13321)
Browse files Browse the repository at this point in the history
This updates the int4 dot-product comparison to use an optimized implementation when one of the vectors is compressed (the most common search case). This change actually makes compressed search on ARM faster than uncompressed search. On AVX512/256 it is still slightly slower than uncompressed, but with this optimization it is much faster than before (when the vectors were eagerly decompressed).

This optimization is tightly tied to how the vectors are actually compressed and stored; consequently, I added a new scorer that lives within the lucene99 codec.

So, this gives us 8x reduction over float32, well more than 2x faster queries than float32, and no need to rerank as the recall and accuracy are excellent.
  • Loading branch information
benwtrent committed May 1, 2024
1 parent 9287167 commit b701898
Show file tree
Hide file tree
Showing 17 changed files with 847 additions and 170 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ Optimizations
* GITHUB#13284: Per-field doc values and knn vectors readers now use a HashMap internally instead of
a TreeMap. (Adrien Grand)

* GITHUB#13321: Improve compressed int4 quantized vector search by utilizing SIMD inline with the decompression
process. (Ben Trent)

Bug Fixes
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;

Expand All @@ -35,6 +35,30 @@
*/
public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {

/**
 * Quantizes a float query vector into {@code quantizedQuery} using the given quantizer.
 *
 * <p>For {@code COSINE} the query is first copied and the copy is passed through {@link
 * VectorUtil#l2normalize}, so the caller's array is not the one handed to the normalizer. For
 * {@code EUCLIDEAN}, {@code DOT_PRODUCT} and {@code MAXIMUM_INNER_PRODUCT} the query is
 * quantized as-is.
 *
 * @param query the float query vector
 * @param quantizedQuery destination array handed to {@link ScalarQuantizer#quantize}
 * @param similarityFunction the similarity the query will be scored with
 * @param scalarQuantizer the quantizer that performs the quantization
 * @return the value returned by {@link ScalarQuantizer#quantize}
 * @throws IllegalArgumentException if {@code similarityFunction} is not one of the handled values
 */
public static float quantizeQuery(
    float[] query,
    byte[] quantizedQuery,
    VectorSimilarityFunction similarityFunction,
    ScalarQuantizer scalarQuantizer) {
  switch (similarityFunction) {
    case COSINE:
      {
        // Normalize a private copy rather than the array the caller passed in.
        final float[] normalized = ArrayUtil.copyOfSubArray(query, 0, query.length);
        VectorUtil.l2normalize(normalized);
        return scalarQuantizer.quantize(normalized, quantizedQuery, similarityFunction);
      }
    case EUCLIDEAN:
    case DOT_PRODUCT:
    case MAXIMUM_INNER_PRODUCT:
      return scalarQuantizer.quantize(query, quantizedQuery, similarityFunction);
    default:
      throw new IllegalArgumentException(
          "Unsupported similarity function: " + similarityFunction);
  }
}

private final FlatVectorsScorer nonQuantizedDelegate;

public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
Expand Down Expand Up @@ -69,18 +93,21 @@ public RandomVectorScorer getRandomVectorScorer(
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
byte[] targetBytes = new byte[target.length];
float offsetCorrection =
ScalarQuantizedRandomVectorScorer.quantizeQuery(
target, targetBytes, similarityFunction, scalarQuantizer);
quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer);
ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
return new ScalarQuantizedRandomVectorScorer(
scalarQuantizedVectorSimilarity,
quantizedByteVectorValues,
targetBytes,
offsetCorrection);
return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedByteVectorValues) {
@Override
public float score(int node) throws IOException {
byte[] nodeVector = quantizedByteVectorValues.vectorValue(node);
float nodeOffset = quantizedByteVectorValues.getScoreCorrectionConstant(node);
return scalarQuantizedVectorSimilarity.score(
targetBytes, offsetCorrection, nodeVector, nodeOffset);
}
};
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
Expand All @@ -99,4 +126,60 @@ public RandomVectorScorer getRandomVectorScorer(
public String toString() {
  // Same text as before, assembled explicitly instead of via chained '+'.
  final StringBuilder text =
      new StringBuilder("ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=");
  text.append(nonQuantizedDelegate);
  text.append(')');
  return text.toString();
}

/**
* Quantized vector scorer supplier
*
* @lucene.experimental
*/
public static class ScalarQuantizedRandomVectorScorerSupplier
implements RandomVectorScorerSupplier {

private final RandomAccessQuantizedByteVectorValues values;
private final ScalarQuantizedVectorSimilarity similarity;
private final VectorSimilarityFunction vectorSimilarityFunction;

public ScalarQuantizedRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values) {
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
this.values = values;
this.vectorSimilarityFunction = similarityFunction;
}

private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity,
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessQuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
@Override
public float score(int node) throws IOException {
byte[] nodeVector = vectorsCopy.vectorValue(node);
float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node);
return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset);
}
};
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedRandomVectorScorerSupplier(
similarity, vectorSimilarityFunction, values.copy());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public byte[] vectorValue(int targetOrd) throws IOException {
return binaryValue;
}

/** Returns the {@code slice} input held by this instance. */
@Override
public IndexInput getSlice() {
  return this.slice;
}

private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public int size() {
return size;
}

/** Returns the {@code slice} input held by this instance. */
@Override
public IndexInput getSlice() {
  return this.slice;
}

@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
Expand Down
Loading

0 comments on commit b701898

Please sign in to comment.