diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04c2ce587..073737476 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -153,6 +153,26 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I return docIdsToScoreMap; } + /** + * For given {@link LeafReaderContext}, this api will return will KNNWeight perform exact search or not + * always. This decision is based on two properties, 1) if there are no native engine files in segments, + * exact search will always be performed, 2) if number of docs after filter is less than 'k' + * @param context + * @return + * @throws IOException + */ + public boolean isExactSearchPreferred(LeafReaderContext context) throws IOException { + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); + if (isFilteredExactSearchPreferred(cardinality)) { + return true; + } + if (isMissingNativeEngineFiles(context)) { + return true; + } + return false; + } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { if (this.filterWeight == null) { return new FixedBitSet(0); diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 53885850a..143f74fe1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -65,7 +65,17 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); - perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); + // split segments into whether exact search will be performed or not + List exactSearchSegments = new ArrayList<>(); + List approxSearchSegments = new ArrayList<>(); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + if (knnWeight.isExactSearchPreferred(leafReaderContext)) { + exactSearchSegments.add(leafReaderContext); + } else { + approxSearchSegments.add(leafReaderContext); + } + } + perLeafResults = doSearch(indexSearcher, approxSearchSegments, knnWeight, firstPassK); if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } @@ -73,7 +83,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo StopWatch stopWatch = new StopWatch().start(); perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); - log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); + log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); + // do exact search on rest of segments and append to result lists + perLeafResults.addAll(doExactSearch(indexSearcher, knnWeight, exactSearchSegments)); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; @@ -127,6 +139,28 @@ private List>> doRescore( return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } + private List>> doExactSearch( + final IndexSearcher indexSearcher, + KNNWeight knnWeight, + List leafReaderContexts + ) throws IOException { + List>>> exactSearchTasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + exactSearchTasks.add(() -> { + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + // setting to false because we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(false) + .k(knnQuery.getK()) + .knnQuery(knnQuery) + .isParentHits(true) + .build(); + final Map searchResults = knnWeight.exactSearch(context, exactSearcherContext); + return new AbstractMap.SimpleEntry<>(context, searchResults); + }); + } + return indexSearcher.getTaskExecutor().invokeAll(exactSearchTasks); + } + private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));