Skip to content

Commit

Permalink
Fix NPE in ANN search when a segment doesn't contain vector field
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Nov 17, 2024
1 parent a07bad1 commit 3d457ad
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
* Fix NPE in ANN search when a segment doesn't contain vector field (#)[]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
### Documentation
Expand Down
169 changes: 95 additions & 74 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
Expand All @@ -38,9 +37,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.*;
import java.util.function.Predicate;

@Log4j2
Expand All @@ -59,23 +56,27 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
final Optional<KNNIterator> iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
// if because of any reason if we are not able to get KNNIterator returning an empty map
if (iterator.isEmpty()) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator.get());
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
return scoreAllDocs(iterator.get());
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
return searchTopCandidates(iterator.get(), exactSearcherContext.getK(), Predicates.alwaysTrue());
}

/**
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param leafReaderContext {@link LeafReaderContext}
* @param exactSearcherContext {@link ExactSearcherContext}
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
Expand Down Expand Up @@ -145,79 +146,99 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K
return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore);
}

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
private Optional<KNNIterator> getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext)
throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
if (fieldInfo == null) {
log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return Optional.empty();
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;

if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}

if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
final KNNIterator knnIterator;
KNNVectorValues<?> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
switch (knnQuery.getVectorDataType()) {
case BINARY:
if (isNestedRequired) {
knnIterator = new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
knnIterator = new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}
return Optional.of(knnIterator);
case BYTE:
if (isNestedRequired) {
knnIterator = new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
knnIterator = new ByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType
);
}
return Optional.of(knnIterator);
case FLOAT:
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
// Build Segment Level Quantization info.
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// Quantize the Query Vector Once.
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(
knnQuery.getQueryVector(),
segmentLevelQuantizationInfo
);
} else {
segmentLevelQuantizationInfo = null;
quantizedQueryVector = null;
}
if (isNestedRequired) {
knnIterator = new NestedVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
);
} else {
knnIterator = new VectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}
return Optional.of(knnIterator);
default:
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Vector data type [%s] is not supported", knnQuery.getVectorDataType())
);
}
return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
// Build Segment Level Quantization info.
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// Quantize the Query Vector Once.
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);
} else {
segmentLevelQuantizationInfo = null;
quantizedQueryVector = null;
}

final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}
return new VectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}

/**
Expand Down
19 changes: 8 additions & 11 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.*;

/**
* Utility class used for processing results
Expand Down Expand Up @@ -58,19 +54,20 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
*
* @param resultMap Map of results
* @return BitSet of results
* @return Optional BitSet of results
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return BitSet.of(DocIdSetIterator.empty(), 0);
public static Optional<BitSet> resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap == null || resultMap.isEmpty()) {
return Optional.empty();
}

final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return Optional.of(BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.concurrent.Callable;

/**
Expand Down Expand Up @@ -112,9 +107,14 @@ private List<Map<Integer, Float>> doRescore(
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
final Optional<BitSet> convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
// if there is no docIds to re-score from a segment we should return early to ensure that we are not
// wasting any computation
if (convertedBitSet.isEmpty()) {
return Collections.emptyMap();
}
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
.matchedDocs(convertedBitSet.get())
// setting to false because in re-scoring we want to do exact search on full precision vectors
.useQuantizedVectorsForSearch(false)
.k(k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

public class ResultUtilTests extends KNNTestCase {
Expand Down Expand Up @@ -44,8 +40,8 @@ public void testReduceToTopK() {
public void testResultMapToMatchBitSet() throws IOException {
int firstPassK = 35;
Map<Integer, Float> perLeafResults = getRandomResults(firstPassK);
BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults);
assertResultMapToMatchBitSet(perLeafResults, resultBitset);
Optional<BitSet> resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults);
assertResultMapToMatchBitSet(perLeafResults, resultBitset.get());
}

public void testResultMapToDocIds() throws IOException {
Expand Down

0 comments on commit 3d457ad

Please sign in to comment.