Skip to content

Commit

Permalink
Refactored HybridQueryPhaseSearcherTests to remove knn specific class…
Browse files Browse the repository at this point in the history
…es (#977)

* Refactored HybridQueryPhaseSearcherTests to remove knn specific classes

Signed-off-by: Owais <[email protected]>

* Refactored HybridQueryTests

Signed-off-by: Owais <[email protected]>

---------

Signed-off-by: Owais <[email protected]>
(cherry picked from commit 71017e6)
  • Loading branch information
owaiskazi19 authored and github-actions[bot] committed Nov 8, 2024
1 parent 1b16682 commit d9e322e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 72 deletions.
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ coverage:
changes: yes

# disable comments in PRs
comment: yes
comment: yes
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.lucene.document.FieldType;
Expand All @@ -39,35 +36,19 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.index.Index;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.KNNQueryBuilder;

import com.carrotsearch.randomizedtesting.RandomizedTest;

import lombok.SneakyThrows;

public class HybridQueryTests extends OpenSearchQueryTestCase {

static final String VECTOR_FIELD_NAME = "vectorField";
static final String TERM_QUERY_TEXT = "keyword";
static final String TERM_ANOTHER_QUERY_TEXT = "anotherkeyword";
static final float[] VECTOR_QUERY = new float[] { 1.0f, 2.0f, 2.1f, 0.6f };
static final int K = 2;
@Mock
protected ClusterService clusterService;
private AutoCloseable openMocks;
Expand All @@ -76,11 +57,6 @@ public class HybridQueryTests extends OpenSearchQueryTestCase {
public void setUp() throws Exception {
super.setUp();
openMocks = MockitoAnnotations.openMocks(this);
// This is required to make sure that before every test we are initializing the KNNSettings. Not doing this
// leads to failures of unit tests cases when a unit test is run separately. Try running this test:
// ./gradlew ':test' --tests "org.opensearch.knn.training.TrainingJobTests.testRun_success" and see it fails
// but if run along with other tests this test passes.
initKNNSettings();
}

@Override
Expand Down Expand Up @@ -141,30 +117,16 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() {
w.commit();

IndexReader reader = DirectoryReader.open(w);

// Test with TermQuery
HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
Query rewritten = hybridQueryWithTerm.rewrite(reader);
// term query is the same after we rewrite it
assertSame(hybridQueryWithTerm, rewritten);

Index dummyIndex = new Index("dummy", "dummy");
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
KNNMethodContext mockKNNMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.ofNullable(mockKNNMethodContext));
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K);
Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext);

HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery));
rewritten = hybridQueryWithKnn.rewrite(reader);
assertSame(hybridQueryWithKnn, rewritten);

// Test empty query list
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of()));
assertThat(exception.getMessage(), containsString("collection of queries must not be empty"));

Expand Down Expand Up @@ -353,17 +315,4 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() {
}
assertEquals(2, countOfQueries);
}

private void initKNNSettings() {
Set<Setting<?>> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
defaultClusterSettings.addAll(
KNNSettings.state()
.getSettings()
.stream()
.filter(s -> s.getProperties().contains(Setting.Property.NodeScope))
.collect(Collectors.toList())
);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings));
KNNSettings.state().setClusterService(clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
import org.opensearch.index.remote.RemoteStoreEnums;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -80,7 +78,6 @@
import org.opensearch.search.query.ReduceableSearchResult;

public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String VECTOR_FIELD_NAME = "vectorField";
private static final String TEXT_FIELD_NAME = "field";
private static final String TEST_DOC_TEXT1 = "Hello world";
private static final String TEST_DOC_TEXT2 = "Hi to this place";
Expand All @@ -94,12 +91,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -153,8 +145,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple term queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down Expand Up @@ -826,12 +821,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() {
public void testAggregations_whenMetricAggregation_thenSuccessful() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -886,8 +876,11 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down

0 comments on commit d9e322e

Please sign in to comment.