Split segment by search type #2273

Draft · wants to merge 2 commits into base: main
20 changes: 20 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -153,6 +153,26 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
return docIdsToScoreMap;
}

/**
 * For the given {@link LeafReaderContext}, returns whether {@link KNNWeight} will always
 * perform exact search on this segment. The decision is based on two properties:
 * 1) if the segment has no native engine files, exact search is always performed;
 * 2) if the number of docs remaining after applying the filter is less than 'k',
 * exact search is always performed.
 * @param context the leaf reader context for the segment
 * @return true if exact search will always be performed for this segment
 * @throws IOException if reading the filtered docs bitset fails
 */
public boolean isExactSearchPreferred(LeafReaderContext context) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
if (isFilteredExactSearchPreferred(cardinality)) {
return true;
}
if (isMissingNativeEngineFiles(context)) {
return true;
}
return false;
}
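
(A minimal usage sketch of the new API. The partition below mirrors what createWeight does later in this PR; the helper name and surrounding setup are assumed for illustration, not part of this diff.)

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import java.io.IOException;
import java.util.List;

// Hypothetical helper: partition a reader's segments by preferred search type.
static void partitionSegments(
    IndexReader reader,
    KNNWeight knnWeight,
    List<LeafReaderContext> exactSegments,
    List<LeafReaderContext> approxSegments
) throws IOException {
    for (LeafReaderContext ctx : reader.leaves()) {
        if (knnWeight.isExactSearchPreferred(ctx)) {
            exactSegments.add(ctx); // no native engine files, or filtered doc count < k
        } else {
            approxSegments.add(ctx); // eligible for approximate (ANN) search
        }
    }
}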

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
12 changes: 7 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
@@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
@@ -30,14 +31,15 @@ public final class ResultUtil {
* @param perLeafResults Results from the list
* @param k the number of results across all leaf results to return
*/
public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) {
public static void reduceToTopK(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults, int k) {
// Iterate over all scores to get min competitive score
PriorityQueue<Float> topKMinQueue = new PriorityQueue<>(k);

int count = 0;
for (Map<Integer, Float> perLeafResult : perLeafResults) {
count += perLeafResult.size();
for (Float score : perLeafResult.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> perLeafResult : perLeafResults) {
Map<Integer, Float> docIdScoreMap = perLeafResult.getValue();
count += docIdScoreMap.size();
for (Float score : docIdScoreMap.values()) {
if (topKMinQueue.size() < k) {
topKMinQueue.add(score);
} else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) {
@@ -54,7 +56,7 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)

// Reduce the results based on min competitive score
float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek();
perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore));
perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore));
}
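
(A sketch of the reduction's semantics with hypothetical scores; it assumes the usual java.util imports. The LeafReaderContext keys are null here because reduceToTopK only reads the map values; in the real code the keys carry the segment needed later for docBase resolution.)

// With k = 2, the global top-2 scores are 0.9 and 0.8, so the min competitive
// score is 0.8 and the 0.4 and 0.1 entries are removed from the maps in place.
Map<Integer, Float> leaf0 = new HashMap<>(Map.of(0, 0.9f, 1, 0.4f));
Map<Integer, Float> leaf1 = new HashMap<>(Map.of(3, 0.8f, 7, 0.1f));
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeaf = new ArrayList<>();
perLeaf.add(new AbstractMap.SimpleEntry<>(null, leaf0));
perLeaf.add(new AbstractMap.SimpleEntry<>(null, leaf1));
ResultUtil.reduceToTopK(perLeaf, 2);
// leaf0 is now {0=0.9}; leaf1 is now {3=0.8}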

/**
@@ -28,6 +28,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Map<Integer, Float>> perLeafResults;
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
final int finalK = knnQuery.getK();
if (rescoreContext == null) {
@@ -64,20 +65,33 @@
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
// split segments by whether exact search will be performed on them or not
List<LeafReaderContext> exactSearchSegments = new ArrayList<>();
List<LeafReaderContext> approxSearchSegments = new ArrayList<>();
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
if (knnWeight.isExactSearchPreferred(leafReaderContext)) {
exactSearchSegments.add(leafReaderContext);
} else {
approxSearchSegments.add(leafReaderContext);
}
}
perLeafResults = doSearch(indexSearcher, approxSearchSegments, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK);
Collaborator
Previously, we did exact search on finalK. After this change, we still do exact search on finalK. Could you tell me how this will improve the latency?

Member Author (@VijayanB, Nov 14, 2024)
doSearch can call either approximate search or exact search based on conditions such as whether engine files exist, or whether the number of docs after filtering is less than k. In those cases we quantize the query vector and every vector in the segment, then compute distances using Hamming distance for firstPassK. With this change, we only call doSearch for segments we know will always use approximate search; for the remaining segments we run exact search with finalK, without quantization. The optimization is at https://github.com/opensearch-project/k-NN/pull/2273/files#diff-9cfe412357ba56b3ef216427d491fc653535686a760e8ba19ea1aa00fc0e0338R68-R78
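
(To make the two paths concrete, here is a hedged sketch of the per-segment flow described above. The two search functions are stand-ins, with assumed names and signatures, for this PR's quantized first pass and full-precision exact search; the rescore of approximate candidates happens afterwards in doRescore.)

interface LeafSearch {
    Map<Integer, Float> search(LeafReaderContext ctx, int k) throws IOException;
}

static List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> searchAllSegments(
    List<LeafReaderContext> leaves,
    KNNWeight weight,
    LeafSearch approxQuantized,    // quantized (e.g. Hamming) first pass
    LeafSearch exactFullPrecision, // full-precision exact search
    int firstPassK,
    int finalK
) throws IOException {
    List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> results = new ArrayList<>();
    for (LeafReaderContext ctx : leaves) {
        if (weight.isExactSearchPreferred(ctx)) {
            // Exact-only segment: one full-precision pass at finalK, no quantization at all.
            results.add(new AbstractMap.SimpleEntry<>(ctx, exactFullPrecision.search(ctx, finalK)));
        } else {
            // ANN-eligible segment: cheap quantized pass at the oversampled firstPassK.
            results.add(new AbstractMap.SimpleEntry<>(ctx, approxQuantized.search(ctx, firstPassK)));
        }
    }
    return results;
}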

Collaborator
Are you assuming that an exact search on full-precision vectors will be faster than an exact search on quantized vectors because of the cost of on-the-fly quantization? It would be interesting to see the benchmark results for this.

If that is the case, an alternative could be to retrieve quantized values directly from the Faiss file instead of performing on-the-fly quantization.

Member Author (@VijayanB, Nov 14, 2024)
Yes: exact search on full precision for k costs less than exact search on quantized vectors for firstPassK plus a full-precision rescore of the matched docs. The linked GitHub issue shows how performance degraded 10x when there are segments with no Faiss engine files. In my POC I saw latency improvements, but recall was poor because list order was used as the link between results and leaf reader contexts. I am rerunning the experiments with this change to collect latency and recall metrics.
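
(The recall problem mentioned above came from pairing results to segments by list position. A before/after sketch using the lines from this diff:)

// Before: results and leafReaderContexts were matched by index. Once exact-search
// segments are pulled out of the approximate pass, the two lists no longer line up,
// so a result map could be paired with the wrong segment's docBase, corrupting
// global doc IDs and hurting recall.
topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase);

// After: each result carries its own LeafReaderContext, so the pairing cannot drift.
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
    topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase);
}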

Collaborator
There is one more case where we run exact search: when the number of returned results is less than k. Are we going to handle that case as well?

long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size());
// do exact search on the remaining segments and append to the result list
perLeafResults.addAll(doExactSearch(indexSearcher, knnWeight, exactSearchSegments));
}
ResultUtil.reduceToTopK(perLeafResults, finalK);
TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
for (int i = 0; i < perLeafResults.size(); i++) {
topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase);
int i = 0;
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase);
}

TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs);
@@ -87,32 +101,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
}

private List<Map<Integer, Float>> doSearch(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doSearch(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size());
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k));
}
return indexSearcher.getTaskExecutor().invokeAll(tasks);
}

private List<Map<Integer, Float>> doRescore(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doRescore(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
List<Map<Integer, Float>> perLeafResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> rescoreTasks = new ArrayList<>(leafReaderContexts.size());
for (int i = 0; i < perLeafResults.size(); i++) {
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> rescoreTasks = new ArrayList<>(perLeafResults.size());
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue());
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
// setting to false because in re-scoring we want to do exact search on full precision vectors
Expand All @@ -121,12 +132,35 @@ private List<Map<Integer, Float>> doRescore(
.isParentHits(false)
.knnQuery(knnQuery)
.build();
return knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
final Map<Integer, Float> docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext);
return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap);
});
}
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
}

private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doExactSearch(
final IndexSearcher indexSearcher,
KNNWeight knnWeight,
List<LeafReaderContext> leafReaderContexts
) throws IOException {
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> exactSearchTasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
exactSearchTasks.add(() -> {
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
// setting to false because we want to do exact search on full precision vectors
.useQuantizedVectorsForSearch(false)
.k(knnQuery.getK())
.knnQuery(knnQuery)
.isParentHits(true)
.build();
final Map<Integer, Float> searchResults = knnWeight.exactSearch(context, exactSearcherContext);
return new AbstractMap.SimpleEntry<>(context, searchResults);
});
}
return indexSearcher.getTaskExecutor().invokeAll(exactSearchTasks);
}

private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
@@ -158,13 +192,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) {
return starts;
}

private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
private Map.Entry<LeafReaderContext, Map<Integer, Float>> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k)
throws IOException {
final Map<Integer, Float> leafDocScores = queryWeight.searchLeaf(ctx, k);
final Bits liveDocs = ctx.reader().getLiveDocs();
if (liveDocs != null) {
leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false);
}
return leafDocScores;
return new AbstractMap.SimpleEntry<>(ctx, leafDocScores);
}

@Override
@@ -5,20 +5,24 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.mockito.Mockito.mock;

public class ResultUtilTests extends KNNTestCase {

public void testReduceToTopK() {
@@ -27,7 +31,9 @@ public void testReduceToTopK() {
int segmentCount = 5;

List<Map<Integer, Float>> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
List<Map<Integer, Float>> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, finalK);

@@ -36,7 +42,9 @@
segmentCount = 1;

initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, firstPassK);
}
@@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map<Integer, Float> perLeafResults, TopDoc
}
}

private void assertTopK(List<Map<Integer, Float>> beforeResults, List<Map<Integer, Float>> reducedResults, int expectedK) {
private void assertTopK(
List<Map<Integer, Float>> beforeResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedResults,
int expectedK
) {
assertEquals(beforeResults.size(), reducedResults.size());
assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue());
assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue());
float minScore = getMinScore(reducedResults);
int count = 0;
for (Map<Integer, Float> result : beforeResults) {
@@ -126,10 +138,10 @@ private Map<Integer, Float> getRandomResults(int k) {
return results;
}

private float getMinScore(List<Map<Integer, Float>> perLeafResults) {
private float getMinScore(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults) {
float minScore = Float.MAX_VALUE;
for (Map<Integer, Float> result : perLeafResults) {
for (float score : result.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> result : perLeafResults) {
for (float score : result.getValue().values()) {
if (score < minScore) {
minScore = score;
}
@@ -91,9 +91,9 @@ public void setUp() throws Exception {

when(searcher.getTaskExecutor()).thenReturn(taskExecutor);
when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> {
List<Callable<Map<Integer, Float>>> callables = invocationOnMock.getArgument(0);
List<Map<Integer, Float>> results = new ArrayList<>();
for (Callable<Map<Integer, Float>> callable : callables) {
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> callables = invocationOnMock.getArgument(0);
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> results = new ArrayList<>();
for (Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>> callable : callables) {
results.add(callable.call());
}
return results;