From b70189819aa79b9272d108855007047f4c0e27a9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 1 May 2024 10:05:51 -0400 Subject: [PATCH] Improve int4 compressed comparisons performance (#13321) This updates the int4 dot-product comparison to have an optimized path for when one of the vectors is compressed (the most common search case). This change actually makes the compressed search on ARM faster than the uncompressed. However, on AVX512/256, it is still slightly slower than uncompressed, but it is still much faster now with this optimization than before (eagerly decompressing). This optimization is tightly tied to how the vectors are actually compressed and stored; consequently, I added a new scorer that is within the lucene99 codec. So, this gives us an 8x reduction over float32, well more than 2x faster queries than float32, and no need to rerank as the recall and accuracy are excellent. --- lucene/CHANGES.txt | 3 + .../hnsw/ScalarQuantizedVectorScorer.java | 101 +++++- .../lucene95/OffHeapByteVectorValues.java | 5 + .../lucene95/OffHeapFloatVectorValues.java | 5 + .../Lucene99ScalarQuantizedVectorScorer.java | 303 ++++++++++++++++++ .../Lucene99ScalarQuantizedVectorsFormat.java | 5 +- .../OffHeapQuantizedByteVectorValues.java | 20 ++ .../DefaultVectorUtilSupport.java | 16 +- .../vectorization/VectorUtilSupport.java | 2 +- .../org/apache/lucene/util/VectorUtil.java | 25 +- .../util/hnsw/RandomAccessVectorValues.java | 32 +- ...RandomAccessQuantizedByteVectorValues.java | 2 +- .../ScalarQuantizedRandomVectorScorer.java | 71 ---- ...arQuantizedRandomVectorScorerSupplier.java | 68 ---- .../PanamaVectorUtilSupport.java | 151 ++++++++- ...stLucene99ScalarQuantizedVectorScorer.java | 156 +++++++++ .../vectorization/TestVectorUtilSupport.java | 52 ++- 17 files changed, 847 insertions(+), 170 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java delete mode 100644 
lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java delete mode 100644 lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 35a63855ba7c..4bb287acf578 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -97,6 +97,9 @@ Optimizations * GITHUB#13284: Per-field doc values and knn vectors readers now use a HashMap internally instead of a TreeMap. (Adrien Grand) +* GITHUB#13321: Improve compressed int4 quantized vector search by utilizing SIMD inline with the decompression + process. (Ben Trent) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index 9abc1bcab587..a4f339dda4b8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -19,12 +19,12 @@ import java.io.IOException; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -35,6 +35,30 @@ */ public 
class ScalarQuantizedVectorScorer implements FlatVectorsScorer { + public static float quantizeQuery( + float[] query, + byte[] quantizedQuery, + VectorSimilarityFunction similarityFunction, + ScalarQuantizer scalarQuantizer) { + final float[] processedQuery; + switch (similarityFunction) { + case EUCLIDEAN: + case DOT_PRODUCT: + case MAXIMUM_INNER_PRODUCT: + processedQuery = query; + break; + case COSINE: + float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length); + VectorUtil.l2normalize(queryCopy); + processedQuery = queryCopy; + break; + default: + throw new IllegalArgumentException( + "Unsupported similarity function: " + similarityFunction); + } + return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction); + } + private final FlatVectorsScorer nonQuantizedDelegate; public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { @@ -69,18 +93,21 @@ public RandomVectorScorer getRandomVectorScorer( ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = - ScalarQuantizedRandomVectorScorer.quantizeQuery( - target, targetBytes, similarityFunction, scalarQuantizer); + quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer); ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); - return new ScalarQuantizedRandomVectorScorer( - scalarQuantizedVectorSimilarity, - quantizedByteVectorValues, - targetBytes, - offsetCorrection); + return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedByteVectorValues) { + @Override + public float score(int node) throws IOException { + byte[] nodeVector = quantizedByteVectorValues.vectorValue(node); + float nodeOffset = quantizedByteVectorValues.getScoreCorrectionConstant(node); + return 
scalarQuantizedVectorSimilarity.score( + targetBytes, offsetCorrection, nodeVector, nodeOffset); + } + }; } // It is possible to get to this branch during initial indexing and flush return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); @@ -99,4 +126,60 @@ public RandomVectorScorer getRandomVectorScorer( public String toString() { return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; } + + /** + * Quantized vector scorer supplier + * + * @lucene.experimental + */ + public static class ScalarQuantizedRandomVectorScorerSupplier + implements RandomVectorScorerSupplier { + + private final RandomAccessQuantizedByteVectorValues values; + private final ScalarQuantizedVectorSimilarity similarity; + private final VectorSimilarityFunction vectorSimilarityFunction; + + public ScalarQuantizedRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ScalarQuantizer scalarQuantizer, + RandomAccessQuantizedByteVectorValues values) { + this.similarity = + ScalarQuantizedVectorSimilarity.fromVectorSimilarity( + similarityFunction, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); + this.values = values; + this.vectorSimilarityFunction = similarityFunction; + } + + private ScalarQuantizedRandomVectorScorerSupplier( + ScalarQuantizedVectorSimilarity similarity, + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessQuantizedByteVectorValues values) { + this.similarity = similarity; + this.values = values; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); + final byte[] queryVector = values.vectorValue(ord); + final float queryOffset = values.getScoreCorrectionConstant(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) { + @Override + public float score(int node) 
throws IOException { + byte[] nodeVector = vectorsCopy.vectorValue(node); + float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node); + return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedRandomVectorScorerSupplier( + similarity, vectorSimilarityFunction, values.copy()); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index da11df1e2518..8d98a9cd1c8a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -68,6 +68,11 @@ public byte[] vectorValue(int targetOrd) throws IOException { return binaryValue; } + @Override + public IndexInput getSlice() { + return slice; + } + private void readValue(int targetOrd) throws IOException { slice.seek((long) targetOrd * byteSize); slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index a13c9f55bb12..0aeddaf15362 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -56,6 +56,11 @@ public int size() { return size; } + @Override + public IndexInput getSlice() { + return slice; + } + @Override public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..b10a3730b6a8 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +/** + * Optimized scalar quantized implementation of {@link FlatVectorsScorer} for quantized vectors + * stored in the Lucene99 format. 
+ * + * @lucene.experimental + */ +public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { + nonQuantizedDelegate = flatVectorsScorer; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) { + return new ScalarQuantizedRandomVectorScorerSupplier( + (RandomAccessQuantizedByteVectorValues) vectorValues, similarityFunction); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) { + RandomAccessQuantizedByteVectorValues quantizedByteVectorValues = + (RandomAccessQuantizedByteVectorValues) vectorValues; + ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[target.length]; + float offsetCorrection = + quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer); + return fromVectorSimilarity( + targetBytes, + offsetCorrection, + similarityFunction, + scalarQuantizer.getConstantMultiplier(), + quantizedByteVectorValues); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + 
byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; + } + + static RandomVectorScorer fromVectorSimilarity( + byte[] targetBytes, + float offsetCorrection, + VectorSimilarityFunction sim, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values) { + switch (sim) { + case EUCLIDEAN: + return new Euclidean(values, constMultiplier, targetBytes); + case COSINE: + case DOT_PRODUCT: + return dotProductFactory( + targetBytes, offsetCorrection, sim, constMultiplier, values, f -> (1 + f) / 2); + case MAXIMUM_INNER_PRODUCT: + return dotProductFactory( + targetBytes, + offsetCorrection, + sim, + constMultiplier, + values, + VectorUtil::scaleMaxInnerProductScore); + default: + throw new IllegalArgumentException("Unsupported similarity function: " + sim); + } + } + + private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( + byte[] targetBytes, + float offsetCorrection, + VectorSimilarityFunction sim, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values, + FloatToFloatFunction scoreAdjustmentFunction) { + if (values.getScalarQuantizer().getBits() <= 4) { + if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) { + return new CompressedInt4DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + return new Int4DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + return new DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + + private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final byte[] targetBytes; + private final 
RandomAccessQuantizedByteVectorValues values; + + private Euclidean( + RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { + super(values); + this.values = values; + this.constMultiplier = constMultiplier; + this.targetBytes = targetBytes; + } + + @Override + public float score(int node) throws IOException { + byte[] nodeVector = values.vectorValue(node); + int squareDistance = VectorUtil.squareDistance(nodeVector, targetBytes); + float adjustedDistance = squareDistance * constMultiplier; + return 1 / (1f + adjustedDistance); + } + } + + /** Calculates dot product on quantized vectors, applying the appropriate corrections */ + private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final RandomAccessQuantizedByteVectorValues values; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + public DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = constMultiplier; + this.values = values; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + byte[] storedVector = values.vectorValue(vectorOrdinal); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + private static class CompressedInt4DotProduct + extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + 
private final RandomAccessQuantizedByteVectorValues values; + private final byte[] compressedVector; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + private CompressedInt4DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = constMultiplier; + this.values = values; + this.compressedVector = new byte[values.getVectorByteLength()]; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + // get compressed vector, in Lucene99, vector values are stored and have a single value for + // offset correction + values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES)); + values.getSlice().readBytes(compressedVector, 0, compressedVector.length); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final RandomAccessQuantizedByteVectorValues values; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + public Int4DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = 
constMultiplier; + this.values = values; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + byte[] storedVector = values.vectorValue(vectorOrdinal); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + @FunctionalInterface + private interface FloatToFloatFunction { + float apply(float f); + } + + private static final class ScalarQuantizedRandomVectorScorerSupplier + implements RandomVectorScorerSupplier { + + private final VectorSimilarityFunction vectorSimilarityFunction; + private final RandomAccessQuantizedByteVectorValues values; + private final RandomAccessQuantizedByteVectorValues values1; + private final RandomAccessQuantizedByteVectorValues values2; + + public ScalarQuantizedRandomVectorScorerSupplier( + RandomAccessQuantizedByteVectorValues values, + VectorSimilarityFunction vectorSimilarityFunction) + throws IOException { + this.values = values; + this.values1 = values.copy(); + this.values2 = values.copy(); + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vectorValue = values1.vectorValue(ord); + float offsetCorrection = values1.getScoreCorrectionConstant(ord); + return fromVectorSimilarity( + vectorValue, + offsetCorrection, + vectorSimilarityFunction, + values.getScalarQuantizer().getConstantMultiplier(), + values2); + } + + @Override + public ScalarQuantizedRandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedRandomVectorScorerSupplier(values.copy(), vectorSimilarityFunction); + } + } +} diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index a3d894e64e99..c10f87da2a65 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -22,7 +22,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -65,7 +64,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final ScalarQuantizedVectorScorer flatVectorScorer; + final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -102,7 +101,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); + this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 08e666d51ff8..9659eb131872 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -135,6 +135,26 @@ public float getScoreCorrectionConstant() { return scoreCorrectionConstant[0]; } + @Override + public float getScoreCorrectionConstant(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return scoreCorrectionConstant[0]; + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(scoreCorrectionConstant, 0, 1); + return scoreCorrectionConstant[0]; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + public static OffHeapQuantizedByteVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index e56d6b97f314..eb5160a0f0dd 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -152,7 +152,21 @@ public int dotProduct(byte[] a, byte[] b) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? 
b : a; + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + total += (packedByte & 0x0F) * unpacked2; + total += ((packedByte & 0xFF) >> 4) * unpacked1; + } + return total; + } return dotProduct(a, b); } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index 246cbdf95bc7..22e5e96aa256 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -37,7 +37,7 @@ public interface VectorUtilSupport { int dotProduct(byte[] a, byte[] b); /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */ - int int4DotProduct(byte[] a, byte[] b); + int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked); /** Returns the cosine similarity between the two byte vectors. */ float cosine(byte[] a, byte[] b); diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 7409a1de4cff..a9acca64d69a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -179,7 +179,30 @@ public static int int4DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.int4DotProduct(a, b); + return IMPL.int4DotProduct(a, false, b, false); + } + + /** + * Dot product computed over int4 (values between [0,15]) bytes. The second vector is considered + * "packed" (i.e. every byte representing two values). The following packing is assumed: + * + *
+   *   packed[0] = (raw[0] * 16) | raw[packed.length];
+   *   packed[1] = (raw[1] * 16) | raw[packed.length + 1];
+   *   ...
+   *   packed[packed.length - 1] = (raw[packed.length - 1] * 16) | raw[2 * packed.length - 1];
+   * 
+ * + * @param unpacked the unpacked vector, of even length + * @param packed the packed vector, of length {@code (unpacked.length + 1) / 2} + * @return the value of the dot product of the two vectors + */ + public static int int4DotProductPacked(byte[] unpacked, byte[] packed) { + if (packed.length != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); + } + return IMPL.int4DotProduct(unpacked, false, packed, true); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index ecf5339cd21a..e2c7372b667a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.List; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; /** @@ -41,6 +42,17 @@ public interface RandomAccessVectorValues { */ RandomAccessVectorValues copy() throws IOException; + /** + * Returns a slice of the underlying {@link IndexInput} that contains the vector values if + * available + */ + default IndexInput getSlice() { + return null; + } + + /** Returns the byte length of the vector values. */ + int getVectorByteLength(); + /** * Translates vector ordinal to the correct document ID. By default, this is an identity function. * @@ -72,6 +84,12 @@ interface Floats extends RandomAccessVectorValues { * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. */ float[] vectorValue(int targetOrd) throws IOException; + + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ + @Override + default int getVectorByteLength() { + return dimension() * Float.BYTES; + } } /** Byte vector values. 
*/ @@ -85,6 +103,12 @@ interface Bytes extends RandomAccessVectorValues { * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. */ byte[] vectorValue(int targetOrd) throws IOException; + + /** Returns the vector byte length, defaults to dimension multiplied by byte size */ + @Override + default int getVectorByteLength() { + return dimension() * Byte.BYTES; + } } /** @@ -107,12 +131,12 @@ public int dimension() { } @Override - public float[] vectorValue(int targetOrd) throws IOException { + public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public RandomAccessVectorValues.Floats copy() throws IOException { + public RandomAccessVectorValues.Floats copy() { return this; } }; @@ -138,12 +162,12 @@ public int dimension() { } @Override - public byte[] vectorValue(int targetOrd) throws IOException { + public byte[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public RandomAccessVectorValues.Bytes copy() throws IOException { + public RandomAccessVectorValues.Bytes copy() { return this; } }; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java index 08b0b6e5a7ae..b86009a690e1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java @@ -29,7 +29,7 @@ public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVecto ScalarQuantizer getScalarQuantizer(); - float getScoreCorrectionConstant(); + float getScoreCorrectionConstant(int vectorOrd) throws IOException; @Override RandomAccessQuantizedByteVectorValues copy() throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java 
b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java deleted file mode 100644 index a88534e0b043..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.lucene.util.quantization; - -import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomVectorScorer; - -/** - * Quantized vector scorer - * - * @lucene.experimental - */ -public class ScalarQuantizedRandomVectorScorer - extends RandomVectorScorer.AbstractRandomVectorScorer { - - public static float quantizeQuery( - float[] query, - byte[] quantizedQuery, - VectorSimilarityFunction similarityFunction, - ScalarQuantizer scalarQuantizer) { - float[] processedQuery = query; - if (similarityFunction.equals(VectorSimilarityFunction.COSINE)) { - float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length); - VectorUtil.l2normalize(queryCopy); - processedQuery = queryCopy; - } - return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction); - } - - private final byte[] quantizedQuery; - private final float queryOffset; - private final RandomAccessQuantizedByteVectorValues values; - private final ScalarQuantizedVectorSimilarity similarity; - - public ScalarQuantizedRandomVectorScorer( - ScalarQuantizedVectorSimilarity similarityFunction, - RandomAccessQuantizedByteVectorValues values, - byte[] query, - float queryOffset) { - super(values); - this.quantizedQuery = query; - this.queryOffset = queryOffset; - this.similarity = similarityFunction; - this.values = values; - } - - @Override - public float score(int node) throws IOException { - byte[] storedVectorValue = values.vectorValue(node); - float storedVectorCorrection = values.getScoreCorrectionConstant(); - return similarity.score( - quantizedQuery, this.queryOffset, storedVectorValue, storedVectorCorrection); - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java 
b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java deleted file mode 100644 index baf89df326db..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.lucene.util.quantization; - -import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; - -/** - * Quantized vector scorer supplier - * - * @lucene.experimental - */ -public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - - private final RandomAccessQuantizedByteVectorValues values; - private final ScalarQuantizedVectorSimilarity similarity; - private final VectorSimilarityFunction vectorSimilarityFunction; - - public ScalarQuantizedRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, - ScalarQuantizer scalarQuantizer, - RandomAccessQuantizedByteVectorValues values) { - this.similarity = - ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); - this.values = values; - this.vectorSimilarityFunction = similarityFunction; - } - - private ScalarQuantizedRandomVectorScorerSupplier( - ScalarQuantizedVectorSimilarity similarity, - VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessQuantizedByteVectorValues values) { - this.similarity = similarity; - this.values = values; - this.vectorSimilarityFunction = vectorSimilarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); - final byte[] queryVector = values.vectorValue(ord); - final float queryOffset = values.getScoreCorrectionConstant(); - return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new ScalarQuantizedRandomVectorScorerSupplier( - similarity, vectorSimilarityFunction, values.copy()); - } -} diff --git 
a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 96dafcf2c1af..9e4476122159 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -19,7 +19,9 @@ import static jdk.incubator.vector.VectorOperators.ADD; import static jdk.incubator.vector.VectorOperators.B2I; import static jdk.incubator.vector.VectorOperators.B2S; +import static jdk.incubator.vector.VectorOperators.LSHR; import static jdk.incubator.vector.VectorOperators.S2I; +import static jdk.incubator.vector.VectorOperators.ZERO_EXTEND_B2S; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; @@ -390,22 +392,151 @@ private int dotProductBody128(byte[] a, byte[] b, int limit) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; int i = 0; int res = 0; - if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { - return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); - res += int4DotProductBody128(a, b, i); - } - // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? 
b : a; + if (packed.length >= 32) { + if (VECTOR_BITSIZE >= 512) { + i += ByteVector.SPECIES_256.loopBound(packed.length); + res += dotProductBody512Int4Packed(unpacked, packed, i); + } else if (VECTOR_BITSIZE == 256) { + i += ByteVector.SPECIES_128.loopBound(packed.length); + res += dotProductBody256Int4Packed(unpacked, packed, i); + } else if (HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_64.loopBound(packed.length); + res += dotProductBody128Int4Packed(unpacked, packed, i); + } + } + // scalar tail + for (; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + res += (packedByte & 0x0F) * unpacked2; + res += ((packedByte & 0xFF) >> 4) * unpacked1; + } + } else { + if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { + return dotProduct(a, b); + } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.length); + res += int4DotProductBody128(a, b, i); + } + // scalar tail + for (; i < a.length; i++) { + res += b[i] * a[i]; + } } + return res; } + private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = 
ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 2048) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); + int innerLimit = Math.min(limit - i, 2048); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + 
IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + /** vectorized dot product body (128 bit vectors) */ + private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { + // packed + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + // unpacked + ByteVector va8 = + ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + ShortVector prod16 = + prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc0 = acc0.add(prod16.and((short) 0xFF)); + + // lower + va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + prod8 = vb8.lanewise(LSHR, 4).mul(va8); + prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 0xFF)); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + private int int4DotProductBody128(byte[] a, 
byte[] b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..dc6be04e4965 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase { + + private static Codec getCodec(int bits, boolean compress) { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene99HnswScalarQuantizedVectorsFormat( + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + 1, + bits, + compress, + null, + null); + } + }; + } + + public void testScoringCompressedInt4() throws Exception { + vectorScoringTest(4, true); + } + + public void testScoringUncompressedInt4() throws Exception { + vectorScoringTest(4, false); + } + + public void testScoringInt7() throws Exception { + vectorScoringTest(7, random().nextBoolean()); + } + + private void vectorScoringTest(int bits, boolean compress) throws IOException { + float[][] storedVectors = new float[10][]; + int numVectors = 10; + int vectorDimensions = random().nextInt(10) + 4; + if (bits == 4 && vectorDimensions % 2 == 1) { + vectorDimensions++; + } + for (int i = 
0; i < numVectors; i++) { + float[] vector = new float[vectorDimensions]; + for (int j = 0; j < vectorDimensions; j++) { + vector[j] = i + j; + } + VectorUtil.l2normalize(vector); + storedVectors[i] = vector; + } + + // create lucene directory with codec + for (VectorSimilarityFunction similarityFunction : + new VectorSimilarityFunction[] { + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN + }) { + try (Directory dir = newDirectory()) { + indexVectors(dir, storedVectors, similarityFunction, bits, compress); + try (DirectoryReader reader = DirectoryReader.open(dir)) { + LeafReader leafReader = reader.leaves().get(0).reader(); + float[] vector = new float[vectorDimensions]; + for (int i = 0; i < vectorDimensions; i++) { + vector[i] = i + 1; + } + VectorUtil.l2normalize(vector); + RandomVectorScorer randomScorer = + getRandomVectorScorer(similarityFunction, leafReader, vector); + float[] rawScores = new float[10]; + for (int i = 0; i < 10; i++) { + rawScores[i] = similarityFunction.compare(vector, storedVectors[i]); + } + for (int i = 0; i < 10; i++) { + assertEquals(similarityFunction.toString(), rawScores[i], randomScorer.score(i), 0.05f); + } + } + } + } + } + + private RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction function, LeafReader leafReader, float[] vector) throws IOException { + if (leafReader instanceof CodecReader) { + KnnVectorsReader format = ((CodecReader) leafReader).getVectorReader(); + if (format instanceof PerFieldKnnVectorsFormat.FieldsReader) { + format = ((PerFieldKnnVectorsFormat.FieldsReader) format).getFieldReader("field"); + } + if (format instanceof Lucene99HnswVectorsReader) { + OffHeapQuantizedByteVectorValues quantizedByteVectorReader = + (OffHeapQuantizedByteVectorValues) + ((Lucene99HnswVectorsReader) format).getQuantizedVectorValues("field"); + return new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()) + 
.getRandomVectorScorer(function, quantizedByteVectorReader, vector); + } + } + throw new IllegalArgumentException("Unsupported reader"); + } + + private static void indexVectors( + Directory dir, + float[][] vectors, + VectorSimilarityFunction function, + int bits, + boolean compress) + throws IOException { + try (IndexWriter writer = + new IndexWriter(dir, new IndexWriterConfig().setCodec(getCodec(bits, compress)))) { + for (int i = 0; i < vectors.length; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // index a document without a vector + writer.addDocument(doc); + } + writer.addDocument(doc); + doc.add(new KnnFloatVectorField("field", vectors[i], function)); + writer.addDocument(doc); + } + writer.commit(); + writer.forceMerge(1); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 9fe5ddd0e2b5..7064955cb5f3 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -27,7 +27,8 @@ public class TestVectorUtilSupport extends BaseVectorizationTestCase { private static final double DELTA = 1e-3; private static final int[] VECTOR_SIZES = { - 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024 + 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024, 1536, 2046, 2048, 4096, + 4098 }; private final int size; @@ -92,6 +93,55 @@ public void testBinaryVectorsBoundaries() { assertFloatReturningProviders(p -> p.cosine(a, b)); } + public void testInt4DotProduct() { + assumeTrue("even sizes only", size % 2 == 0); + var a = new byte[size]; + var b = new byte[size]; + for (int i = 0; i < size; ++i) { + a[i] = (byte) random().nextInt(16); + b[i] = (byte) random().nextInt(16); + } + + assertIntReturningProviders(p -> 
p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + } + + public void testInt4DotProductBoundaries() { + assumeTrue("even sizes only", size % 2 == 0); + byte MAX_VALUE = 15; + var a = new byte[size]; + var b = new byte[size]; + + Arrays.fill(a, MAX_VALUE); + Arrays.fill(b, MAX_VALUE); + assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + + byte MIN_VALUE = 0; + Arrays.fill(a, MIN_VALUE); + Arrays.fill(b, MIN_VALUE); + assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + } + + static byte[] pack(byte[] unpacked) { + int len = (unpacked.length + 1) / 2; + var packed = new byte[len]; + for (int i = 0; i < len; i++) { + packed[i] = (byte) (unpacked[i] << 4 | unpacked[packed.length + i]); + } + return packed; + } + private void assertFloatReturningProviders(ToDoubleFunction func) { assertEquals( func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()),