Skip to content

Commit

Permalink
in KnnVectorsWriter reduce code duplication w.r.t. MergedVectorValues…
Browse files Browse the repository at this point in the history
….merge(Float|Byte)VectorValues (#13539)

Co-authored-by: Vigya Sharma <[email protected]>
  • Loading branch information
cpoerschke and vigyasharma authored Jul 12, 2024
1 parent cc14555 commit c55d664
Showing 1 changed file with 56 additions and 34 deletions.
90 changes: 56 additions & 34 deletions lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
Expand All @@ -35,6 +36,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOFunction;

/** Writes vectors to an index. */
public abstract class KnnVectorsWriter implements Accountable, Closeable {
Expand Down Expand Up @@ -111,11 +113,11 @@ public final void merge(MergeState mergeState) throws IOException {
}

/** Tracks state of one sub-reader that we are merging */
private static class VectorValuesSub extends DocIDMerger.Sub {
private static class FloatVectorValuesSub extends DocIDMerger.Sub {

final FloatVectorValues values;

VectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
Expand Down Expand Up @@ -201,61 +203,81 @@ public static void mapOldOrdToNewOrd(
public static final class MergedVectorValues {
private MergedVectorValues() {}

/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
public static FloatVectorValues mergeFloatVectorValues(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding expected) {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
VectorEncoding fieldEncoding = fieldInfo.getVectorEncoding();
if (fieldEncoding != expected) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
"Cannot merge vectors encoded as [" + fieldEncoding + "] as " + expected);
}
List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
}

private static <V, S> List<S> mergeVectorValues(
KnnVectorsReader[] knnVectorsReaders,
MergeState.DocMap[] docMaps,
IOFunction<KnnVectorsReader, V> valuesSupplier,
BiFunction<MergeState.DocMap, V, S> newSub)
throws IOException {
List<S> subs = new ArrayList<>();
for (int i = 0; i < knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
if (knnVectorsReader != null) {
FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
V values = valuesSupplier.apply(knnVectorsReader);
if (values != null) {
subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
subs.add(newSub.apply(docMaps[i], values));
}
}
}
return new MergedFloat32VectorValues(subs, mergeState);
return subs;
}

/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
public static FloatVectorValues mergeFloatVectorValues(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32);
return new MergedFloat32VectorValues(
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getFloatVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new FloatVectorValuesSub(docMap, values);
}),
mergeState);
}

/** Returns a merged view over all the segment's {@link ByteVectorValues}. */
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
}
List<ByteVectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
if (values != null) {
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
}
}
}
return new MergedByteVectorValues(subs, mergeState);
validateFieldEncoding(fieldInfo, VectorEncoding.BYTE);
return new MergedByteVectorValues(
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new ByteVectorValuesSub(docMap, values);
}),
mergeState);
}

static class MergedFloat32VectorValues extends FloatVectorValues {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final List<FloatVectorValuesSub> subs;
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
private final int size;
private int docId;
VectorValuesSub current;
FloatVectorValuesSub current;

private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (VectorValuesSub sub : subs) {
for (FloatVectorValuesSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
Expand Down

0 comments on commit c55d664

Please sign in to comment.