From 3d9ff391b92f15a6189f9be30707950bedac3cdc Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sun, 17 Nov 2024 14:30:13 -0800 Subject: [PATCH] Fix NPE in ANN search when a segment doesn't contain vector field Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + build.gradle | 6 +- .../knn/index/query/ExactSearcher.java | 167 ++++++++++-------- .../knn/index/query/ResultUtil.java | 14 +- .../nativelib/NativeEngineKnnVectorQuery.java | 16 +- .../knn/index/query/ResultUtilTests.java | 5 +- .../knn/integ/ModeAndCompressionIT.java | 36 ++++ .../org/opensearch/knn/KNNRestTestCase.java | 13 ++ 8 files changed, 168 insertions(+), 90 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c57523c3..5695d48ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] ### Bug Fixes +* Fix NPE in ANN search when a segment doesn't contain vector field [#2278](https://github.com/opensearch-project/k-NN/pull/2278) ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) ### Documentation diff --git a/build.gradle b/build.gradle index 132fd7f43..2ad973676 100644 --- a/build.gradle +++ b/build.gradle @@ -295,9 +295,9 @@ dependencies { api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' api group: 'commons-lang', name: 'commons-lang', version: '2.6' testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}" - testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.4' - testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2' - testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4' + testImplementation group: 'net.bytebuddy', name: 
'byte-buddy', version: '1.15.10' + testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' + testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.10' testFixturesImplementation "org.opensearch:common-utils:${version}" implementation 'com.github.oshi:oshi-core:6.4.13' api "net.java.dev.jna:jna:5.13.0" diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 77e993297..d44b48a5d 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -21,7 +21,6 @@ import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; @@ -38,9 +37,11 @@ import org.opensearch.knn.indices.ModelDao; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Predicate; @Log4j2 @@ -59,23 +60,27 @@ public class ExactSearcher { */ public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { - KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + final Optional iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + // if for any reason we are not able to get a KNNIterator, return an empty map + if (iterator.isEmpty()) { + return Collections.emptyMap(); + } if (exactSearcherContext.getKnnQuery().getRadius() != null) { - return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + return 
doRadialSearch(leafReaderContext, exactSearcherContext, iterator.get()); } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { - return scoreAllDocs(iterator); + return scoreAllDocs(iterator.get()); } - return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); + return searchTopCandidates(iterator.get(), exactSearcherContext.getK(), Predicates.alwaysTrue()); } /** * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores * to filter out the documents that does not have given min score. - * @param leafReaderContext - * @param exactSearcherContext + * @param leafReaderContext {@link LeafReaderContext} + * @param exactSearcherContext {@link ExactSearcherContext} * @param iterator {@link KNNIterator} * @return Map of docId and score * @throws IOException exception raised by iterator during traversal @@ -145,79 +150,99 @@ private Map filterDocsByMinScore(ExactSearcherContext context, K return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); } - private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { + private Optional getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) + throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + if (fieldInfo == null) { + log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), 
reader.getSegmentName()); + return Optional.empty(); + } final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null; - - if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedBinaryVectorIdsKNNIterator( - matchedDocs, - knnQuery.getByteQueryVector(), - (KNNBinaryVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) - ); - } - return new BinaryVectorIdsKNNIterator( - matchedDocs, - knnQuery.getByteQueryVector(), - (KNNBinaryVectorValues) vectorValues, - spaceType - ); - } - - if (VectorDataType.BYTE == knnQuery.getVectorDataType()) { - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedByteVectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNByteVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) + final KNNIterator knnIterator; + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + switch (knnQuery.getVectorDataType()) { + case BINARY: + if (isNestedRequired) { + knnIterator = new NestedBinaryVectorIdsKNNIterator( + matchedDocs, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } else { + knnIterator = new BinaryVectorIdsKNNIterator( + matchedDocs, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType + ); + } + return Optional.of(knnIterator); + case BYTE: + if (isNestedRequired) { + knnIterator = new NestedByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType, 
+ knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } else { + knnIterator = new ByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType + ); + } + return Optional.of(knnIterator); + case FLOAT: + final byte[] quantizedQueryVector; + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; + if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { + // Build Segment Level Quantization info. + segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); + // Quantize the Query Vector Once. + quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector( + knnQuery.getQueryVector(), + segmentLevelQuantizationInfo + ); + } else { + segmentLevelQuantizationInfo = null; + quantizedQueryVector = null; + } + if (isNestedRequired) { + knnIterator = new NestedVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNFloatVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext), + quantizedQueryVector, + segmentLevelQuantizationInfo + ); + } else { + knnIterator = new VectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNFloatVectorValues) vectorValues, + spaceType, + quantizedQueryVector, + segmentLevelQuantizationInfo + ); + } + return Optional.of(knnIterator); + default: + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Vector data type [%s] is not supported", knnQuery.getVectorDataType()) ); - } - return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType); - } - final byte[] quantizedQueryVector; - final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; - if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { - // Build Segment Level Quantization info. 
- segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); - // Quantize the Query Vector Once. - quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); - } else { - segmentLevelQuantizationInfo = null; - quantizedQueryVector = null; - } - - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedVectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNFloatVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext), - quantizedQueryVector, - segmentLevelQuantizationInfo - ); } - return new VectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNFloatVectorValues) vectorValues, - spaceType, - quantizedQueryVector, - segmentLevelQuantizationInfo - ); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index f62c09cb0..487d747ca 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -17,6 +17,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.PriorityQueue; /** @@ -58,19 +59,20 @@ public static void reduceToTopK(List> perLeafResults, int k) } /** - * Convert map to bit set + * Convert map to bit set. If resultMap is null or empty, an empty Optional is returned. An Optional is returned + * so that the caller is aware that the BitSet may not be present * * @param resultMap Map of results - * @return BitSet of results + * @return Optional BitSet of results * @throws IOException If an error occurs during the search. 
*/ - public static BitSet resultMapToMatchBitSet(Map resultMap) throws IOException { - if (resultMap.isEmpty()) { - return BitSet.of(DocIdSetIterator.empty(), 0); + public static Optional resultMapToMatchBitSet(Map resultMap) throws IOException { + if (resultMap == null || resultMap.isEmpty()) { + return Optional.empty(); } final int maxDoc = Collections.max(resultMap.keySet()) + 1; - return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc); + return Optional.of(BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc)); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..74a8ecaa9 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -28,12 +28,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.concurrent.Callable; /** @@ -112,9 +107,14 @@ private List> doRescore( LeafReaderContext leafReaderContext = leafReaderContexts.get(i); int finalI = i; rescoreTasks.add(() -> { - BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + final Optional convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + // if there are no docIds to re-score from a segment, return early so that we do not + // waste any computation + if (convertedBitSet.isEmpty()) { + return Collections.emptyMap(); + } final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() - .matchedDocs(convertedBitSet) + .matchedDocs(convertedBitSet.get()) 
// setting to false because in re-scoring we want to do exact search on full precision vectors .useQuantizedVectorsForSearch(false) .k(k) diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 70cb86e02..3b1a8c708 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -17,6 +17,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; public class ResultUtilTests extends KNNTestCase { @@ -44,8 +45,8 @@ public void testReduceToTopK() { public void testResultMapToMatchBitSet() throws IOException { int firstPassK = 35; Map perLeafResults = getRandomResults(firstPassK); - BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults); - assertResultMapToMatchBitSet(perLeafResults, resultBitset); + Optional resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults); + assertResultMapToMatchBitSet(perLeafResults, resultBitset.get()); } public void testResultMapToDocIds() throws IOException { diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 0913d9b36..ad7f8b057 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -11,6 +11,7 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; @@ -220,6 +221,41 @@ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { validateGreenIndex(indexName); } + @SneakyThrows + 
public void testCompressionIndexWithNonVectorFieldsSegment_whenValid_ThenSucceed() { + CompressionLevel compressionLevel = CompressionLevel.x32; + String indexName = INDEX_NAME + compressionLevel; + try ( + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .endObject() + .endObject() + .endObject() + ) { + String mapping = builder.toString(); + Settings indexSettings = buildKNNIndexSettings(0); + createKnnIndex(indexName, indexSettings, mapping); + // we are going to delete a document, so it's better to index one extra doc so that we can re-use some tests + addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS + 1); + addNonKNNDoc(indexName, String.valueOf(NUM_DOCS + 2), FIELD_NAME_NON_KNN, "Hello world"); + deleteKnnDoc(indexName, "0"); + validateGreenIndex(indexName); + validateSearch( + indexName, + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + compressionLevel.getName(), + Mode.ON_DISK.getName() + ); + } + } + @SneakyThrows public void testTraining_whenInvalid_thenFail() { setupTrainingIndex(); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 8a4885cfe..2afbd9639 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -116,6 +116,7 @@ public class KNNRestTestCase extends ODFERestTestCase { public static final String INDEX_NAME = "test_index"; public static final String FIELD_NAME = "test_field"; + public static final String FIELD_NAME_NON_KNN = "test_field_non_knn"; public static final String PROPERTIES_FIELD = "properties"; public static final String STORE_FIELD = "store"; 
public static final String STORED_QUERY_FIELD = "stored_fields"; @@ -607,6 +608,18 @@ protected void addKnnDoc(String index, String docId, String fieldName, T vec assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void addNonKNNDoc(String index, String docId, String fieldName, String text) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, text).endObject(); + request.setJsonEntity(builder.toString()); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Add a single KNN Doc to an index with a nested vector field *