diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 35a63855ba7c..4bb287acf578 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -97,6 +97,9 @@ Optimizations * GITHUB#13284: Per-field doc values and knn vectors readers now use a HashMap internally instead of a TreeMap. (Adrien Grand) +* GITHUB#13321: Improve compressed int4 quantized vector search by utilizing SIMD inline with the decompression + process. (Ben Trent) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index 9abc1bcab587..a4f339dda4b8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -19,12 +19,12 @@ import java.io.IOException; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -35,6 +35,30 @@ */ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer { + public static float quantizeQuery( + float[] query, + byte[] quantizedQuery, + VectorSimilarityFunction similarityFunction, + ScalarQuantizer scalarQuantizer) { + final float[] processedQuery; + switch (similarityFunction) { + case EUCLIDEAN: + case DOT_PRODUCT: + case MAXIMUM_INNER_PRODUCT: + processedQuery = query; + break; + case COSINE: + float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length); + VectorUtil.l2normalize(queryCopy); + processedQuery = queryCopy; + break; + default: + throw new IllegalArgumentException( + "Unsupported similarity function: " + similarityFunction); + } + return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction); + } + private final FlatVectorsScorer nonQuantizedDelegate; public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { @@ -69,18 +93,21 @@ public RandomVectorScorer getRandomVectorScorer( ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = - ScalarQuantizedRandomVectorScorer.quantizeQuery( - target, targetBytes, similarityFunction, scalarQuantizer); + quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer); ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); - return new ScalarQuantizedRandomVectorScorer( - scalarQuantizedVectorSimilarity, - quantizedByteVectorValues, - targetBytes, - offsetCorrection); + return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedByteVectorValues) { + @Override + public float score(int node) throws IOException { + byte[] nodeVector = quantizedByteVectorValues.vectorValue(node); + float nodeOffset = quantizedByteVectorValues.getScoreCorrectionConstant(node); + return scalarQuantizedVectorSimilarity.score( + targetBytes, offsetCorrection, nodeVector, nodeOffset); + } + }; } // It is possible to get to this branch during initial indexing and flush return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); @@ -99,4 +126,60 @@ public RandomVectorScorer getRandomVectorScorer( public String toString() { return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; } + + /** + * Quantized vector scorer supplier + * + * @lucene.experimental + */ + public static class ScalarQuantizedRandomVectorScorerSupplier + implements RandomVectorScorerSupplier { + + private final RandomAccessQuantizedByteVectorValues values; + private final ScalarQuantizedVectorSimilarity similarity; + private final VectorSimilarityFunction vectorSimilarityFunction; + + public ScalarQuantizedRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ScalarQuantizer scalarQuantizer, + RandomAccessQuantizedByteVectorValues values) { + this.similarity = + ScalarQuantizedVectorSimilarity.fromVectorSimilarity( + similarityFunction, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); + this.values = values; + this.vectorSimilarityFunction = similarityFunction; + } + + private ScalarQuantizedRandomVectorScorerSupplier( + ScalarQuantizedVectorSimilarity similarity, + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessQuantizedByteVectorValues values) { + this.similarity = similarity; + this.values = values; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); + final byte[] queryVector = values.vectorValue(ord); + final float queryOffset = values.getScoreCorrectionConstant(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) { + @Override + public float score(int node) throws IOException { + byte[] nodeVector = vectorsCopy.vectorValue(node); + float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node); + return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset); + } + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedRandomVectorScorerSupplier( + similarity, vectorSimilarityFunction, values.copy()); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index da11df1e2518..8d98a9cd1c8a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -68,6 +68,11 @@ public byte[] vectorValue(int targetOrd) throws IOException { return binaryValue; } + @Override + public IndexInput getSlice() { + return slice; + } + private void readValue(int targetOrd) throws IOException { slice.seek((long) targetOrd * byteSize); slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index a13c9f55bb12..0aeddaf15362 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -56,6 +56,11 @@ public int size() { return size; } + @Override + public IndexInput getSlice() { + return slice; + } + @Override public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..b10a3730b6a8 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +/** + * Optimized scalar quantized implementation of {@link FlatVectorsScorer} for quantized vectors + * stored in the Lucene99 format. + * + * @lucene.experimental + */ +public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { + nonQuantizedDelegate = flatVectorsScorer; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) { + return new ScalarQuantizedRandomVectorScorerSupplier( + (RandomAccessQuantizedByteVectorValues) vectorValues, similarityFunction); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) { + RandomAccessQuantizedByteVectorValues quantizedByteVectorValues = + (RandomAccessQuantizedByteVectorValues) vectorValues; + ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[target.length]; + float offsetCorrection = + quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer); + return fromVectorSimilarity( + targetBytes, + offsetCorrection, + similarityFunction, + scalarQuantizer.getConstantMultiplier(), + quantizedByteVectorValues); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; + } + + static RandomVectorScorer fromVectorSimilarity( + byte[] targetBytes, + float offsetCorrection, + VectorSimilarityFunction sim, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values) { + switch (sim) { + case EUCLIDEAN: + return new Euclidean(values, constMultiplier, targetBytes); + case COSINE: + case DOT_PRODUCT: + return dotProductFactory( + targetBytes, offsetCorrection, sim, constMultiplier, values, f -> (1 + f) / 2); + case MAXIMUM_INNER_PRODUCT: + return dotProductFactory( + targetBytes, + offsetCorrection, + sim, + constMultiplier, + values, + VectorUtil::scaleMaxInnerProductScore); + default: + throw new IllegalArgumentException("Unsupported similarity function: " + sim); + } + } + + private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( + byte[] targetBytes, + float offsetCorrection, + VectorSimilarityFunction sim, + float constMultiplier, + RandomAccessQuantizedByteVectorValues values, + FloatToFloatFunction scoreAdjustmentFunction) { + if (values.getScalarQuantizer().getBits() <= 4) { + if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) { + return new CompressedInt4DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + return new Int4DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + return new DotProduct( + values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction); + } + + private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final byte[] targetBytes; + private final RandomAccessQuantizedByteVectorValues values; + + private Euclidean( + RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { + super(values); + this.values = values; + this.constMultiplier = constMultiplier; + this.targetBytes = targetBytes; + } + + @Override + public float score(int node) throws IOException { + byte[] nodeVector = values.vectorValue(node); + int squareDistance = VectorUtil.squareDistance(nodeVector, targetBytes); + float adjustedDistance = squareDistance * constMultiplier; + return 1 / (1f + adjustedDistance); + } + } + + /** Calculates dot product on quantized vectors, applying the appropriate corrections */ + private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final RandomAccessQuantizedByteVectorValues values; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + public DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = constMultiplier; + this.values = values; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + byte[] storedVector = values.vectorValue(vectorOrdinal); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + private static class CompressedInt4DotProduct + extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final RandomAccessQuantizedByteVectorValues values; + private final byte[] compressedVector; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + private CompressedInt4DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = constMultiplier; + this.values = values; + this.compressedVector = new byte[values.getVectorByteLength()]; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + // get compressed vector, in Lucene99, vector values are stored and have a single value for + // offset correction + values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES)); + values.getSlice().readBytes(compressedVector, 0, compressedVector.length); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { + private final float constMultiplier; + private final RandomAccessQuantizedByteVectorValues values; + private final byte[] targetBytes; + private final float offsetCorrection; + private final FloatToFloatFunction scoreAdjustmentFunction; + + public Int4DotProduct( + RandomAccessQuantizedByteVectorValues values, + float constMultiplier, + byte[] targetBytes, + float offsetCorrection, + FloatToFloatFunction scoreAdjustmentFunction) { + super(values); + this.constMultiplier = constMultiplier; + this.values = values; + this.targetBytes = targetBytes; + this.offsetCorrection = offsetCorrection; + this.scoreAdjustmentFunction = scoreAdjustmentFunction; + } + + @Override + public float score(int vectorOrdinal) throws IOException { + byte[] storedVector = values.vectorValue(vectorOrdinal); + float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); + int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes); + float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset; + return scoreAdjustmentFunction.apply(adjustedDistance); + } + } + + @FunctionalInterface + private interface FloatToFloatFunction { + float apply(float f); + } + + private static final class ScalarQuantizedRandomVectorScorerSupplier + implements RandomVectorScorerSupplier { + + private final VectorSimilarityFunction vectorSimilarityFunction; + private final RandomAccessQuantizedByteVectorValues values; + private final RandomAccessQuantizedByteVectorValues values1; + private final RandomAccessQuantizedByteVectorValues values2; + + public ScalarQuantizedRandomVectorScorerSupplier( + RandomAccessQuantizedByteVectorValues values, + VectorSimilarityFunction vectorSimilarityFunction) + throws IOException { + this.values = values; + this.values1 = values.copy(); + this.values2 = values.copy(); + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vectorValue = values1.vectorValue(ord); + float offsetCorrection = values1.getScoreCorrectionConstant(ord); + return fromVectorSimilarity( + vectorValue, + offsetCorrection, + vectorSimilarityFunction, + values.getScalarQuantizer().getConstantMultiplier(), + values2); + } + + @Override + public ScalarQuantizedRandomVectorScorerSupplier copy() throws IOException { + return new ScalarQuantizedRandomVectorScorerSupplier(values.copy(), vectorSimilarityFunction); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index a3d894e64e99..c10f87da2a65 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -22,7 +22,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -65,7 +64,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; - final ScalarQuantizedVectorScorer flatVectorScorer; + final Lucene99ScalarQuantizedVectorScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -102,7 +101,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); + this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 08e666d51ff8..9659eb131872 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -135,6 +135,26 @@ public float getScoreCorrectionConstant() { return scoreCorrectionConstant[0]; } + @Override + public float getScoreCorrectionConstant(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return scoreCorrectionConstant[0]; + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(scoreCorrectionConstant, 0, 1); + return scoreCorrectionConstant[0]; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + public static OffHeapQuantizedByteVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index e56d6b97f314..eb5160a0f0dd 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -152,7 +152,21 @@ public int dotProduct(byte[] a, byte[] b) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? b : a; + int total = 0; + for (int i = 0; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + total += (packedByte & 0x0F) * unpacked2; + total += ((packedByte & 0xFF) >> 4) * unpacked1; + } + return total; + } return dotProduct(a, b); } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index 246cbdf95bc7..22e5e96aa256 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -37,7 +37,7 @@ public interface VectorUtilSupport { int dotProduct(byte[] a, byte[] b); /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */ - int int4DotProduct(byte[] a, byte[] b); + int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked); /** Returns the cosine similarity between the two byte vectors. */ float cosine(byte[] a, byte[] b); diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 7409a1de4cff..a9acca64d69a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -179,7 +179,30 @@ public static int int4DotProduct(byte[] a, byte[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return IMPL.int4DotProduct(a, b); + return IMPL.int4DotProduct(a, false, b, false); + } + + /** + * Dot product computed over int4 (values between [0,15]) bytes. The second vector is considered + * "packed" (i.e. every byte representing two values). The following packing is assumed: + * + *
+   *   packed[0] = (raw[0] * 16) | raw[packed.length];
+   *   packed[1] = (raw[1] * 16) | raw[packed.length + 1];
+   *   ...
+   *   packed[packed.length - 1] = (raw[packed.length - 1] * 16) | raw[2 * packed.length - 1];
+   * 
+ * + * @param unpacked the unpacked vector, of even length + * @param packed the packed vector, of length {@code (unpacked.length + 1) / 2} + * @return the value of the dot product of the two vectors + */ + public static int int4DotProductPacked(byte[] unpacked, byte[] packed) { + if (packed.length != ((unpacked.length + 1) >> 1)) { + throw new IllegalArgumentException( + "vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length); + } + return IMPL.int4DotProduct(unpacked, false, packed, true); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index ecf5339cd21a..e2c7372b667a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.List; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; /** @@ -41,6 +42,17 @@ public interface RandomAccessVectorValues { */ RandomAccessVectorValues copy() throws IOException; + /** + * Returns a slice of the underlying {@link IndexInput} that contains the vector values if + * available + */ + default IndexInput getSlice() { + return null; + } + + /** Returns the byte length of the vector values. */ + int getVectorByteLength(); + /** * Translates vector ordinal to the correct document ID. By default, this is an identity function. * @@ -72,6 +84,12 @@ interface Floats extends RandomAccessVectorValues { * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. */ float[] vectorValue(int targetOrd) throws IOException; + + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ + @Override + default int getVectorByteLength() { + return dimension() * Float.BYTES; + } } /** Byte vector values. */ @@ -85,6 +103,12 @@ interface Bytes extends RandomAccessVectorValues { * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. */ byte[] vectorValue(int targetOrd) throws IOException; + + /** Returns the vector byte length, defaults to dimension multiplied by byte size */ + @Override + default int getVectorByteLength() { + return dimension() * Byte.BYTES; + } } /** @@ -107,12 +131,12 @@ public int dimension() { } @Override - public float[] vectorValue(int targetOrd) throws IOException { + public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public RandomAccessVectorValues.Floats copy() throws IOException { + public RandomAccessVectorValues.Floats copy() { return this; } }; @@ -138,12 +162,12 @@ public int dimension() { } @Override - public byte[] vectorValue(int targetOrd) throws IOException { + public byte[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public RandomAccessVectorValues.Bytes copy() throws IOException { + public RandomAccessVectorValues.Bytes copy() { return this; } }; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java index 08b0b6e5a7ae..b86009a690e1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java @@ -29,7 +29,7 @@ public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVecto ScalarQuantizer getScalarQuantizer(); - float getScoreCorrectionConstant(); + float getScoreCorrectionConstant(int vectorOrd) throws IOException; @Override RandomAccessQuantizedByteVectorValues copy() throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java deleted file mode 100644 index a88534e0b043..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.util.quantization; - -import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomVectorScorer; - -/** - * Quantized vector scorer - * - * @lucene.experimental - */ -public class ScalarQuantizedRandomVectorScorer - extends RandomVectorScorer.AbstractRandomVectorScorer { - - public static float quantizeQuery( - float[] query, - byte[] quantizedQuery, - VectorSimilarityFunction similarityFunction, - ScalarQuantizer scalarQuantizer) { - float[] processedQuery = query; - if (similarityFunction.equals(VectorSimilarityFunction.COSINE)) { - float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length); - VectorUtil.l2normalize(queryCopy); - processedQuery = queryCopy; - } - return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction); - } - - private final byte[] quantizedQuery; - private final float queryOffset; - private final RandomAccessQuantizedByteVectorValues values; - private final ScalarQuantizedVectorSimilarity similarity; - - public ScalarQuantizedRandomVectorScorer( - ScalarQuantizedVectorSimilarity similarityFunction, - RandomAccessQuantizedByteVectorValues values, - byte[] query, - float queryOffset) { - super(values); - this.quantizedQuery = query; - this.queryOffset = queryOffset; - this.similarity = similarityFunction; - this.values = values; - } - - @Override - public float score(int node) throws IOException { - byte[] storedVectorValue = values.vectorValue(node); - float storedVectorCorrection = values.getScoreCorrectionConstant(); - return similarity.score( - quantizedQuery, this.queryOffset, storedVectorValue, storedVectorCorrection); - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java deleted file mode 100644 index baf89df326db..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.util.quantization; - -import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; - -/** - * Quantized vector scorer supplier - * - * @lucene.experimental - */ -public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - - private final RandomAccessQuantizedByteVectorValues values; - private final ScalarQuantizedVectorSimilarity similarity; - private final VectorSimilarityFunction vectorSimilarityFunction; - - public ScalarQuantizedRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, - ScalarQuantizer scalarQuantizer, - RandomAccessQuantizedByteVectorValues values) { - this.similarity = - ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); - this.values = values; - this.vectorSimilarityFunction = similarityFunction; - } - - private ScalarQuantizedRandomVectorScorerSupplier( - ScalarQuantizedVectorSimilarity similarity, - VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessQuantizedByteVectorValues values) { - this.similarity = similarity; - this.values = values; - this.vectorSimilarityFunction = vectorSimilarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); - final byte[] queryVector = values.vectorValue(ord); - final float queryOffset = values.getScoreCorrectionConstant(); - return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new ScalarQuantizedRandomVectorScorerSupplier( - similarity, vectorSimilarityFunction, values.copy()); - } -} diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 96dafcf2c1af..9e4476122159 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -19,7 +19,9 @@ import static jdk.incubator.vector.VectorOperators.ADD; import static jdk.incubator.vector.VectorOperators.B2I; import static jdk.incubator.vector.VectorOperators.B2S; +import static jdk.incubator.vector.VectorOperators.LSHR; import static jdk.incubator.vector.VectorOperators.S2I; +import static jdk.incubator.vector.VectorOperators.ZERO_EXTEND_B2S; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; @@ -390,22 +392,151 @@ private int dotProductBody128(byte[] a, byte[] b, int limit) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; int i = 0; int res = 0; - if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { - return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); - res += int4DotProductBody128(a, b, i); - } - // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? b : a; + if (packed.length >= 32) { + if (VECTOR_BITSIZE >= 512) { + i += ByteVector.SPECIES_256.loopBound(packed.length); + res += dotProductBody512Int4Packed(unpacked, packed, i); + } else if (VECTOR_BITSIZE == 256) { + i += ByteVector.SPECIES_128.loopBound(packed.length); + res += dotProductBody256Int4Packed(unpacked, packed, i); + } else if (HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_64.loopBound(packed.length); + res += dotProductBody128Int4Packed(unpacked, packed, i); + } + } + // scalar tail + for (; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + res += (packedByte & 0x0F) * unpacked2; + res += ((packedByte & 0xFF) >> 4) * unpacked1; + } + } else { + if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { + return dotProduct(a, b); + } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.length); + res += int4DotProductBody128(a, b, i); + } + // scalar tail + for (; i < a.length; i++) { + res += b[i] * a[i]; + } } + return res; } + private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 2048) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); + int innerLimit = Math.min(limit - i, 2048); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + /** vectorized dot product body (128 bit vectors) */ + private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { + // packed + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + // unpacked + ByteVector va8 = + ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + ShortVector prod16 = + prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc0 = acc0.add(prod16.and((short) 0xFF)); + + // lower + va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + prod8 = vb8.lanewise(LSHR, 4).mul(va8); + prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 0xFF)); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + private int int4DotProductBody128(byte[] a, byte[] b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..dc6be04e4965 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase { + + private static Codec getCodec(int bits, boolean compress) { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene99HnswScalarQuantizedVectorsFormat( + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + 1, + bits, + compress, + null, + null); + } + }; + } + + public void testScoringCompressedInt4() throws Exception { + vectorScoringTest(4, true); + } + + public void testScoringUncompressedInt4() throws Exception { + vectorScoringTest(4, false); + } + + public void testScoringInt7() throws Exception { + vectorScoringTest(7, random().nextBoolean()); + } + + private void vectorScoringTest(int bits, boolean compress) throws IOException { + float[][] storedVectors = new float[10][]; + int numVectors = 10; + int vectorDimensions = random().nextInt(10) + 4; + if (bits == 4 && vectorDimensions % 2 == 1) { + vectorDimensions++; + } + for (int i = 0; i < numVectors; i++) { + float[] vector = new float[vectorDimensions]; + for (int j = 0; j < vectorDimensions; j++) { + vector[j] = i + j; + } + VectorUtil.l2normalize(vector); + storedVectors[i] = vector; + } + + // create lucene directory with codec + for (VectorSimilarityFunction similarityFunction : + new VectorSimilarityFunction[] { + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN + }) { + try (Directory dir = newDirectory()) { + indexVectors(dir, storedVectors, similarityFunction, bits, compress); + try (DirectoryReader reader = DirectoryReader.open(dir)) { + LeafReader leafReader = reader.leaves().get(0).reader(); + float[] vector = new float[vectorDimensions]; + for (int i = 0; i < vectorDimensions; i++) { + vector[i] = i + 1; + } + VectorUtil.l2normalize(vector); + RandomVectorScorer randomScorer = + getRandomVectorScorer(similarityFunction, leafReader, vector); + float[] rawScores = new float[10]; + for (int i = 0; i < 10; i++) { + rawScores[i] = similarityFunction.compare(vector, storedVectors[i]); + } + for (int i = 0; i < 10; i++) { + assertEquals(similarityFunction.toString(), rawScores[i], randomScorer.score(i), 0.05f); + } + } + } + } + } + + private RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction function, LeafReader leafReader, float[] vector) throws IOException { + if (leafReader instanceof CodecReader) { + KnnVectorsReader format = ((CodecReader) leafReader).getVectorReader(); + if (format instanceof PerFieldKnnVectorsFormat.FieldsReader) { + format = ((PerFieldKnnVectorsFormat.FieldsReader) format).getFieldReader("field"); + } + if (format instanceof Lucene99HnswVectorsReader) { + OffHeapQuantizedByteVectorValues quantizedByteVectorReader = + (OffHeapQuantizedByteVectorValues) + ((Lucene99HnswVectorsReader) format).getQuantizedVectorValues("field"); + return new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()) + .getRandomVectorScorer(function, quantizedByteVectorReader, vector); + } + } + throw new IllegalArgumentException("Unsupported reader"); + } + + private static void indexVectors( + Directory dir, + float[][] vectors, + VectorSimilarityFunction function, + int bits, + boolean compress) + throws IOException { + try (IndexWriter writer = + new IndexWriter(dir, new IndexWriterConfig().setCodec(getCodec(bits, compress)))) { + for (int i = 0; i < vectors.length; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // index a document without a vector + writer.addDocument(doc); + } + writer.addDocument(doc); + doc.add(new KnnFloatVectorField("field", vectors[i], function)); + writer.addDocument(doc); + } + writer.commit(); + writer.forceMerge(1); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 9fe5ddd0e2b5..7064955cb5f3 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -27,7 +27,8 @@ public class TestVectorUtilSupport extends BaseVectorizationTestCase { private static final double DELTA = 1e-3; private static final int[] VECTOR_SIZES = { - 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024 + 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024, 1536, 2046, 2048, 4096, + 4098 }; private final int size; @@ -92,6 +93,55 @@ public void testBinaryVectorsBoundaries() { assertFloatReturningProviders(p -> p.cosine(a, b)); } + public void testInt4DotProduct() { + assumeTrue("even sizes only", size % 2 == 0); + var a = new byte[size]; + var b = new byte[size]; + for (int i = 0; i < size; ++i) { + a[i] = (byte) random().nextInt(16); + b[i] = (byte) random().nextInt(16); + } + + assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + } + + public void testInt4DotProductBoundaries() { + assumeTrue("even sizes only", size % 2 == 0); + byte MAX_VALUE = 15; + var a = new byte[size]; + var b = new byte[size]; + + Arrays.fill(a, MAX_VALUE); + Arrays.fill(b, MAX_VALUE); + assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + + byte MIN_VALUE = 0; + Arrays.fill(a, MIN_VALUE); + Arrays.fill(b, MIN_VALUE); + assertIntReturningProviders(p -> p.int4DotProduct(a, false, pack(b), true)); + assertIntReturningProviders(p -> p.int4DotProduct(pack(a), true, b, false)); + assertEquals( + LUCENE_PROVIDER.getVectorUtilSupport().dotProduct(a, b), + PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); + } + + static byte[] pack(byte[] unpacked) { + int len = (unpacked.length + 1) / 2; + var packed = new byte[len]; + for (int i = 0; i < len; i++) { + packed[i] = (byte) (unpacked[i] << 4 | unpacked[packed.length + i]); + } + return packed; + } + private void assertFloatReturningProviders(ToDoubleFunction func) { assertEquals( func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()),