Skip to content

Commit

Permalink
Fix vector type check for diversified knn search (#13235)
Browse files Browse the repository at this point in the history
I repeatedly saw test failures in `TestParentBlockJoin[Byte|Float]KnnVectorQuery#testVectorEncodingMismatch`. This commit fixes those test failures by actually checking the field's vector type before searching.
  • Loading branch information
benwtrent committed Mar 29, 2024
1 parent 55ca9f7 commit 50a4475
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ protected TopDocs approximateSearch(
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
ByteVectorValues.checkField(context.reader(), field);
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ protected TopDocs approximateSearch(
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
FloatVectorValues.checkField(context.reader(), field);
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@

package org.apache.lucene.search.join;

import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;

public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
Expand Down Expand Up @@ -54,16 +59,25 @@ Field getKnnVectorField(
}

public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenFloatKnnVectorQuery(
"field", new float[] {1, 2}, filter, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
try (Directory d = newDirectory()) {
try (IndexWriter w =
new IndexWriter(
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
List<Document> toAdd = new ArrayList<>();
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
toAdd.add(doc);
toAdd.add(makeParent(new int[] {1}));
w.addDocuments(toAdd);
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenFloatKnnVectorQuery(
"field", new float[] {1, 2}, null, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;

public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase {
Expand All @@ -50,16 +48,25 @@ Query getParentJoinKnnQuery(
}

public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenByteKnnVectorQuery(
"field", new byte[] {1, 2}, filter, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
try (Directory d = newDirectory()) {
try (IndexWriter w =
new IndexWriter(
d, new IndexWriterConfig().setMergePolicy(newMergePolicy(random(), false)))) {
List<Document> toAdd = new ArrayList<>();
Document doc = new Document();
doc.add(getKnnVectorField("field", new float[] {1, 1}, COSINE));
toAdd.add(doc);
toAdd.add(makeParent(new int[] {1}));
w.addDocuments(toAdd);
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
BitSetProducer parentFilter = parentFilter(reader);
Query kvq =
new DiversifyingChildrenByteKnnVectorQuery(
"field", new byte[] {1, 2}, null, 2, parentFilter);
assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3));
}
}
}

Expand Down

0 comments on commit 50a4475

Please sign in to comment.