Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

in KnnVectorsWriter reduce code duplication w.r.t. MergedVectorValues.merge(Float|Byte)VectorValues #13539

Merged
merged 5 commits into from
Jul 12, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 56 additions & 34 deletions lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.Objects;
cpoerschke marked this conversation as resolved.
Show resolved Hide resolved
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -35,6 +36,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOFunction;

/** Writes vectors to an index. */
public abstract class KnnVectorsWriter implements Accountable, Closeable {
Expand Down Expand Up @@ -111,11 +113,11 @@ public final void merge(MergeState mergeState) throws IOException {
}

/** Tracks state of one sub-reader that we are merging */
private static class VectorValuesSub extends DocIDMerger.Sub {
private static class FloatVectorValuesSub extends DocIDMerger.Sub {

final FloatVectorValues values;

VectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
Expand Down Expand Up @@ -201,61 +203,81 @@ public static void mapOldOrdToNewOrd(
public static final class MergedVectorValues {
private MergedVectorValues() {}

/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
public static FloatVectorValues mergeFloatVectorValues(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
private static void validateFieldEncoding(FieldInfo fieldInfo, VectorEncoding expected) {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
VectorEncoding fieldEncoding = fieldInfo.getVectorEncoding();
if (fieldEncoding != expected) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
"Cannot merge vectors encoded as [" + fieldEncoding + "] as " + expected);
}
List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
}

private static <V, S> List<S> mergeVectorValues(
KnnVectorsReader[] knnVectorsReaders,
MergeState.DocMap[] docMaps,
IOFunction<KnnVectorsReader, V> valuesSupplier,
BiFunction<MergeState.DocMap, V, S> newSub)
throws IOException {
List<S> subs = new ArrayList<>();
for (int i = 0; i < knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = knnVectorsReaders[i];
if (knnVectorsReader != null) {
FloatVectorValues values = knnVectorsReader.getFloatVectorValues(fieldInfo.name);
V values = valuesSupplier.apply(knnVectorsReader);
if (values != null) {
subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
subs.add(newSub.apply(docMaps[i], values));
}
}
}
return new MergedFloat32VectorValues(subs, mergeState);
return subs;
}

/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
public static FloatVectorValues mergeFloatVectorValues(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
validateFieldEncoding(fieldInfo, VectorEncoding.FLOAT32);
return new MergedFloat32VectorValues(
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getFloatVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new FloatVectorValuesSub(docMap, values);
}),
mergeState);
}

/** Returns a merged view over all the segment's {@link ByteVectorValues}. */
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
}
List<ByteVectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
if (values != null) {
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
}
}
}
return new MergedByteVectorValues(subs, mergeState);
validateFieldEncoding(fieldInfo, VectorEncoding.BYTE);
return new MergedByteVectorValues(
mergeVectorValues(
mergeState.knnVectorsReaders,
mergeState.docMaps,
knnVectorsReader -> {
return knnVectorsReader.getByteVectorValues(fieldInfo.name);
},
(docMap, values) -> {
return new ByteVectorValuesSub(docMap, values);
}),
mergeState);
}

static class MergedFloat32VectorValues extends FloatVectorValues {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And maybe this could drop the 32 for consistency with MergedByteVectorValues naming?

Suggested change
static class MergedFloat32VectorValues extends FloatVectorValues {
static class MergedFloatVectorValues extends FloatVectorValues {

private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final List<FloatVectorValuesSub> subs;
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
private final int size;
private int docId;
VectorValuesSub current;
FloatVectorValuesSub current;

private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (VectorValuesSub sub : subs) {
for (FloatVectorValuesSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
Expand Down
Loading