Skip to content

Commit

Permalink
Minor code fixes which includes:
Browse files Browse the repository at this point in the history
* Passing correct score mode in NativeEngineKNNVectorQuery
* Ensuring visitor is called in KnnQuery

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Oct 13, 2024
1 parent 0749175 commit 9c415e1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ private Weight getFilterWeight(IndexSearcher searcher) throws IOException {

@Override
public void visit(QueryVisitor visitor) {

visitor.visitLeaf(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class NativeEngineKnnVectorQuery extends Query {
@Override
public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Map<Integer, Float>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase {
@InjectMocks
private NativeEngineKnnVectorQuery objectUnderTest;

private static ScoreMode scoreMode = ScoreMode.TOP_SCORES;

@Override
public void setUp() throws Exception {
super.setUp();
Expand All @@ -85,7 +87,7 @@ public void setUp() throws Exception {
when(leaf2.reader()).thenReturn(leafReader2);

when(searcher.getIndexReader()).thenReturn(reader);
when(knnQuery.createWeight(searcher, ScoreMode.COMPLETE, 1)).thenReturn(knnWeight);
when(knnQuery.createWeight(searcher, scoreMode, 1)).thenReturn(knnWeight);

when(searcher.getTaskExecutor()).thenReturn(taskExecutor);
when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> {
Expand Down Expand Up @@ -135,7 +137,7 @@ public void testMultiLeaf() {
Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1);

// When
Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1);
Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1);

// Then
assertEquals(expected, actual.getQuery());
Expand Down Expand Up @@ -176,7 +178,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenCallRealMethod();

// When
Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1);
Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1);

// Then
mockedResultUtil.verify(() -> ResultUtil.reduceToTopK(any(), anyInt()), times(2));
Expand All @@ -199,7 +201,7 @@ public void testSingleLeaf() {
Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1);

// When
Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1);
Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1);

// Then
assertEquals(expected, actual.getQuery());
Expand All @@ -214,7 +216,7 @@ public void testNoMatch() {
when(knnQuery.getK()).thenReturn(4);

// When
Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1);
Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1);

// Then
assertEquals(new MatchNoDocsQuery(), actual.getQuery());
Expand Down Expand Up @@ -260,7 +262,7 @@ public void testRescore() {
try (MockedStatic<NativeEngineKnnVectorQuery> mockedStaticNativeKnnVectorQuery = mockStatic(NativeEngineKnnVectorQuery.class)) {
mockedStaticNativeKnnVectorQuery.when(() -> NativeEngineKnnVectorQuery.findSegmentStarts(any(), any()))
.thenReturn(new int[] { 0, 4, 2 });
Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1);
Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1);
assertEquals(expected, actual.getQuery());
}
}
Expand Down

0 comments on commit 9c415e1

Please sign in to comment.