Skip to content

Commit

Permalink
Refactored HybridQueryPhaseSearcherTests to remove knn specific classes
Browse files Browse the repository at this point in the history
Signed-off-by: Owais <[email protected]>
  • Loading branch information
owaiskazi19 committed Nov 7, 2024
1 parent 0316dc4 commit 1f9760c
Showing 1 changed file with 10 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
import org.opensearch.index.remote.RemoteStoreEnums;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.search.SearchShardTarget;
Expand Down Expand Up @@ -94,12 +92,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -153,8 +146,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple term queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down Expand Up @@ -826,12 +822,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() {
public void testAggregations_whenMetricAggregation_thenSuccessful() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -886,8 +877,11 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down

0 comments on commit 1f9760c

Please sign in to comment.