From 1f9760cd406bb99b95bbe1701ac6de42e2f4e2e7 Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 7 Nov 2024 13:49:16 -0800 Subject: [PATCH] Refactored HybridQueryPhaseSearcherTests to remove knn specific classes Signed-off-by: Owais --- .../query/HybridQueryPhaseSearcherTests.java | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 0a88324fb..ab49cd3b6 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -63,8 +63,6 @@ import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.SearchOperationListener; -import org.opensearch.knn.index.mapper.KNNMappingConfig; -import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; @@ -94,12 +92,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); - KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class); - when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); - when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); @@ -153,8 +146,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); - queryBuilder.add(termSubQuery); + // Add multiple term queries to simulate a complex hybrid query + TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); + queryBuilder.add(termSubQuery1); + queryBuilder.add(termSubQuery2); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -826,12 +822,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { public void testAggregations_whenMetricAggregation_thenSuccessful() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); - KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class); - when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); - when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); @@ -886,8 +877,11 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); - queryBuilder.add(termSubQuery); + // Add multiple queries to simulate a complex hybrid query + TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); + queryBuilder.add(termSubQuery1); + queryBuilder.add(termSubQuery2); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query);