From 235460f46da13827d39e12b9eea8d6e8b914ce8f Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 14 Nov 2024 12:05:17 -0800 Subject: [PATCH] Segregate segments based on search type For exact search, it is not required to perform qunatization during rescore with oversamples. However, to avoid normalization between segments from approx search and exact search, we will first identify segments that needs approxsearch and will perform oversamples and, at end, after rescore, we will add scores from segments that will perform exact search. Signed-off-by: Vijayan Balasubramanian --- .../opensearch/knn/index/query/KNNWeight.java | 20 ++++++++++ .../nativelib/NativeEngineKnnVectorQuery.java | 38 ++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04c2ce587..073737476 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -153,6 +153,26 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I return docIdsToScoreMap; } + /** + * For given {@link LeafReaderContext}, this api will return will KNNWeight perform exact search or not + * always. This decision is based on two properties, 1) if there are no native engine files in segments, + * exact search will always be performed, 2) if number of docs after filter is less than 'k' + * @param context + * @return + * @throws IOException + */ + public boolean isExactSearchPreferred(LeafReaderContext context) throws IOException { + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); + if (isFilteredExactSearchPreferred(cardinality)) { + return true; + } + if (isMissingNativeEngineFiles(context)) { + return true; + } + return false; + } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { if (this.filterWeight == null) { return new FixedBitSet(0); diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 53885850a..143f74fe1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -65,7 +65,17 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); - perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); + // split segments into whether exact search will be performed or not + List exactSearchSegments = new ArrayList<>(); + List approxSearchSegments = new ArrayList<>(); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + if (knnWeight.isExactSearchPreferred(leafReaderContext)) { + exactSearchSegments.add(leafReaderContext); + } else { + approxSearchSegments.add(leafReaderContext); + } + } + perLeafResults = doSearch(indexSearcher, approxSearchSegments, knnWeight, firstPassK); if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } @@ -73,7 +83,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo StopWatch stopWatch = new StopWatch().start(); perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); - log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); + log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); + // do exact search on rest of segments and append to result lists + perLeafResults.addAll(doExactSearch(indexSearcher, knnWeight, exactSearchSegments)); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; @@ -127,6 +139,28 @@ private List>> doRescore( return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } + private List>> doExactSearch( + final IndexSearcher indexSearcher, + KNNWeight knnWeight, + List leafReaderContexts + ) throws IOException { + List>>> exactSearchTasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + exactSearchTasks.add(() -> { + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + // setting to false because we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(false) + .k(knnQuery.getK()) + .knnQuery(knnQuery) + .isParentHits(true) + .build(); + final Map searchResults = knnWeight.exactSearch(context, exactSearcherContext); + return new AbstractMap.SimpleEntry<>(context, searchResults); + }); + } + return indexSearcher.getTaskExecutor().invokeAll(exactSearchTasks); + } + private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));