diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 32a192325..1ac6fbcf7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # This should match the owning team set up in https://github.com/orgs/opensearch-project/teams -* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @sean-zheng-amazon @model-collapse @zane-neo @ylwu-amzn @jngz-es @vibrantvarun @zhichao-aws @yuye-aws +* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @sean-zheng-amazon @model-collapse @zane-neo @vibrantvarun @zhichao-aws @yuye-aws @minalsha diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..2cd1278f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/MAINTAINERS.md b/MAINTAINERS.md index d2e3cbd82..a1415e777 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -12,8 +12,6 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Charlie Yang | [model-collapse](https://github.com/model-collapse) | Amazon | | Navneet Verma | [navneet1v](https://github.com/navneet1v) | Amazon | | Zan Niu | [zane-neo](https://github.com/zane-neo) | Amazon | -| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon | -| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon | | Heemin Kim | [heemin32](https://github.com/heemin32) | Amazon | | Junqiu Lei | [junqiu-lei](https://github.com/junqiu-lei) | Amazon | | Martin Gaievski | [martin-gaievski](https://github.com/martin-gaievski) | Amazon | @@ -22,9 +20,13 @@ This document contains a list of maintainers in this repo. 
See [opensearch-proje | Varun Jain | [vibrantvarun](https://github.com/vibrantvarun) | Amazon | | Zhichao Geng | [zhichao-aws](https://github.com/zhichao-aws) | Amazon | | Yuye Zhu | [yuye-aws](https://github.com/yuye-aws) | Amazon | +| Minal Shah | [minalsha](https://github.com/minalsha) | Amazon | + ## Emeritus | Maintainer | GitHub ID | Affiliation | |-------------------------|---------------------------------------------|-------------| | Junshen Wu | [wujunshen](https://github.com/wujunshen) | Independent | +| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon | +| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon | diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java index 0f5cbefcf..05e04e84a 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil { private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; + private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0; // Note this minimal version will act as a override private static final Map MINIMAL_VERSION_NEURAL = ImmutableMap.builder() @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); } + public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY); + } + public static boolean isClusterOnOrAfterMinReqVersion(String key) { Version version; if (MINIMAL_VERSION_NEURAL.containsKey(key)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..a30bd7f56 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -16,6 +16,7 @@ import org.opensearch.action.search.SearchPhaseName; import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; @@ -58,7 +59,16 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + // Builds data transfer object to pass into execute + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + 
.searchPhaseContext(searchPhaseContext) + .build(); + + normalizationWorkflow.execute(normalizationExecuteDto); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64f1c1f4..1e3c8fc0d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -18,10 +18,12 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -47,16 +49,17 @@ public class NormalizationProcessorWorkflow { /** * Start execution of this workflow - * @param querySearchResults input data with QuerySearchResult from multiple shards - * @param normalizationTechnique technique for score normalization - * @param combinationTechnique technique for score combination + * @param normalizationExecuteDto DTO holding the querySearchResults (QuerySearchResult from multiple shards), + * the optional fetchSearchResult, the score normalization technique, the score combination technique, and the searchPhaseContext.
*/ - public void execute( - final List querySearchResults, - final Optional fetchSearchResultOptional, - final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique - ) { + public void execute(final NormalizationExecuteDto normalizationExecuteDto) { + final List querySearchResults = normalizationExecuteDto.getQuerySearchResults(); + final Optional fetchSearchResultOptional = normalizationExecuteDto.getFetchSearchResultOptional(); + final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDto.getNormalizationTechnique(); + final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDto.getCombinationTechnique(); + final SearchPhaseContext searchPhaseContext = normalizationExecuteDto.getSearchPhaseContext(); + // save original state List unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -73,6 +76,8 @@ public void execute( .scoreCombinationTechnique(combinationTechnique) .querySearchResults(querySearchResults) .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .fromValueForSingleShard(searchPhaseContext.getRequest().source().from()) + .isFetchResultsPresent(fetchSearchResultOptional.isPresent()) .build(); // combine @@ -82,7 +87,12 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds); + updateOriginalFetchResults( + querySearchResults, + fetchSearchResultOptional, + unprocessedDocIds, + combineScoresDTO.getFromValueForSingleShard() + ); } /** @@ -113,15 +123,29 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) final List querySearchResults = combineScoresDTO.getQuerySearchResults(); final List queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults); final Sort sort = combineScoresDTO.getSort(); + int totalScoreDocsCount = 0; for (int index = 0; index < querySearchResults.size(); index++) { QuerySearchResult querySearchResult = querySearchResults.get(index); CompoundTopDocs updatedTopDocs = queryTopDocs.get(index); + totalScoreDocsCount += updatedTopDocs.getScoreDocs().size(); TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore( buildTopDocs(updatedTopDocs, sort), maxScoreForShard(updatedTopDocs, sort != null) ); + // The fetch phase ran before the normalization phase, therefore update the from value in each shard's result. + // This ensures the results are trimmed correctly.
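+ // For example (hypothetical numbers): with from = 2 and 5 combined score docs on a shard, trimming keeps only the docs at positions 2 through 4 in the returned page.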
+ if (combineScoresDTO.isFetchResultsPresent()) { + querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); + } querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); } + + final int from = querySearchResults.get(0).from(); + if (from > 0 && from > totalScoreDocsCount) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") + ); + } } private List getCompoundTopDocs(CombineScoresDto combineScoresDTO, List querySearchResults) { @@ -180,7 +204,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { private void updateOriginalFetchResults( final List querySearchResults, final Optional fetchSearchResultOptional, - final List docIds + final List docIds, + final int fromValueForSingleShard ) { if (fetchSearchResultOptional.isEmpty()) { return; } @@ -212,14 +237,21 @@ private void updateOriginalFetchResults( QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; + + // When the normalization process executes before the fetch phase, from = 0 applies. + // When the normalization process runs after the fetch phase, the search hits are already fetched. Therefore, use the from value sent in the + // search request. // iterate over the normalized/combined scores, that solves (1) and (3) - SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { + SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard]; + for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { + ScoreDoc scoreDoc = topDocs.scoreDocs[i]; // get fetched hit content by doc_id SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); // update score to normalized/combined value (3) searchHit.score(scoreDoc.score); - return searchHit; - }).toArray(SearchHit[]::new); + updatedSearchHitArray[i - fromValueForSingleShard] = searchHit; + } + SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, querySearchResult.getTotalHits(), diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index a4e39f448..c70ab0b78 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -26,6 +26,7 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; /** * Abstracts combination of scores in query search results. @@ -65,14 +66,10 @@ public class ScoreCombiner { public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard.
Every CompoundTopDocs object has results from // multiple sub queries, doc ids may repeat for each sub query results + ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique(); + Sort sort = combineScoresDTO.getSort(); combineScoresDTO.getQueryTopDocs() - .forEach( - compoundQueryTopDocs -> combineShardScores( - combineScoresDTO.getScoreCombinationTechnique(), - compoundQueryTopDocs, - combineScoresDTO.getSort() - ) - ); + .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort)); } private void combineShardScores( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java similarity index 78% rename from src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java rename to src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java index c4783969b..77444f383 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java @@ -2,9 +2,10 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.combination; +package org.opensearch.neuralsearch.processor.dto; import java.util.List; + import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; @@ -12,6 +13,7 @@ import org.apache.lucene.search.Sort; import org.opensearch.common.Nullable; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.search.query.QuerySearchResult; /** @@ -29,4 +31,6 @@ public class CombineScoresDto { private List querySearchResults; @Nullable private Sort sort; + private int fromValueForSingleShard; + private boolean isFetchResultsPresent; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java b/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java new file mode 100644 index 000000000..1ddda83d2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import java.util.List; +import java.util.Optional; + +/** + * DTO object to hold the data consumed by the execute method + * of NormalizationProcessorWorkflow.
+ */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizationExecuteDto { + @NonNull + private List querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; + @NonNull + private SearchPhaseContext searchPhaseContext; +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 60d5870da..14514df60 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -15,6 +15,7 @@ import java.util.Objects; import java.util.concurrent.Callable; +import lombok.Getter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -31,20 +32,25 @@ * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual * scores for each sub-query. */ +@Getter public final class HybridQuery extends Query implements Iterable { private final List subQueries; + private Integer paginationDepth; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is */ - public HybridQuery(final Collection subQueries, final List filterQueries) { + public HybridQuery(final Collection subQueries, final List filterQueries, Integer paginationDepth) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } + if (paginationDepth != null && paginationDepth == 0) { + throw new IllegalArgumentException("pagination depth must not be zero"); + } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { this.subQueries = new ArrayList<>(subQueries); } else { @@ -57,10 +63,11 @@ public HybridQuery(final Collection subQueries, final List filterQ } this.subQueries = modifiedSubQueries; } + this.paginationDepth = paginationDepth; } - public HybridQuery(final Collection subQueries) { - this(subQueries, List.of()); + public HybridQuery(final Collection subQueries, final Integer paginationDepth) { + this(subQueries, List.of(), paginationDepth); } /** @@ -128,7 +135,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } final List rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); - return new HybridQuery(rewrittenSubQueries); + return new HybridQuery(rewrittenSubQueries, paginationDepth); } private Void rewriteQuery(Query query, HybridQueryExecutorCollector> collector) { diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 60d9fd639..0b52b90e6 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -35,6 +35,8 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; 
+import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery; + /** * Class abstract creation of a Query type "hybrid". Hybrid query will allow execution of multiple sub-queries and * collects score for each of those sub-query. @@ -48,16 +50,23 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); private String fieldName; - + private Integer paginationDepth = null; static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + private final static int DEFAULT_PAGINATION_DEPTH = 10; + private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 1; + private static final int UPPER_BOUND_OF_PAGINATION_DEPTH = 10000; public HybridQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + paginationDepth = in.readOptionalInt(); + } } /** @@ -68,6 +77,9 @@ public HybridQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + out.writeOptionalInt(paginationDepth); + } } /** @@ -97,6 +109,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep queryBuilder.toXContent(builder, params); } builder.endArray(); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + builder.field(DEPTH_FIELD.getPreferredName(), paginationDepth == null ? DEFAULT_PAGINATION_DEPTH : paginationDepth); + } printBoostAndQueryName(builder); builder.endObject(); } @@ -113,7 +128,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio if (queryCollection.isEmpty()) { return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); } - return new HybridQuery(queryCollection); + return new HybridQuery(queryCollection, paginationDepth); } /** @@ -147,8 +162,9 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio * @throws IOException */ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { float boost = AbstractQueryBuilder.DEFAULT_BOOST; + Integer paginationDepth = null; final List queries = new ArrayList<>(); String queryName = null; @@ -196,6 +212,8 @@ } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (DEPTH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + paginationDepth = parser.intValue(); } else { log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); throw new ParsingException( @@ -216,6 +234,10 @@ HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder(); compoundQueryBuilder.queryName(queryName); compoundQueryBuilder.boost(boost); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + validatePaginationDepth(paginationDepth); + compoundQueryBuilder.paginationDepth(paginationDepth); + } for (QueryBuilder query : queries) { compoundQueryBuilder.add(query); } @@ -235,6 +257,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I if (changed) { newBuilder.queryName(queryName);
newBuilder.boost(boost); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + newBuilder.paginationDepth(paginationDepth); + } return newBuilder; } else { return this; } @@ -257,6 +282,9 @@ protected boolean doEquals(HybridQueryBuilder obj) { EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queries, obj.queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + equalsBuilder.append(paginationDepth, obj.paginationDepth); + } return equalsBuilder.isEquals(); } @@ -297,6 +325,15 @@ private Collection toQueries(Collection queryBuilders, Quer return queries; } + private static void validatePaginationDepth(Integer paginationDepth) { + if (paginationDepth != null + && (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH || paginationDepth > UPPER_BOUND_OF_PAGINATION_DEPTH)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-10000. Received: %s", paginationDepth) + ); + } + } + /** * visit method to parse the HybridQueryBuilder by a visitor */ diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index f9457f6ca..8a86117a6 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Weight; @@ -22,6 +23,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.collector.HybridSearchCollector; import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; @@ -80,14 +82,28 @@ public abstract class HybridCollectorManager implements CollectorManager 0) { + searchContext.from(0); + } + Weight filteringWeight = null; // Check for post filter to create weight for filter query and later use that weight in the search workflow if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) { @@ -461,6 +477,38 @@ private ReduceableSearchResult reduceSearchResults(final List BooleanClause.Occur.FILTER == clause.getOccur()) .map(BooleanClause::getQuery) .collect(Collectors.toList()); - HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries); + HybridQuery hybridQueryWithFilter = new HybridQuery( + hybridQuery.getSubQueries(), + filterQueries, + hybridQuery.getPaginationDepth() + ); return hybridQueryWithFilter; } return query; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..218515a97 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++
b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -134,6 +134,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -179,6 +180,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio } SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -203,6 +205,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -247,6 +250,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(1); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -272,7 +276,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -328,7 +332,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -344,6 +348,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -408,6 +413,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -417,7 
+423,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { @@ -433,6 +439,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -495,6 +502,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); IllegalStateException exception = expectThrows( IllegalStateException.class, () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..e09aea187 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -19,8 +20,11 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; @@ -35,6 +39,7 @@ import org.opensearch.test.OpenSearchTestCase; public class NormalizationProcessorWorkflowTests extends OpenSearchTestCase { + private static final String INDEX_NAME = "normalization-index"; public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationCombination() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( @@ -72,12 +77,19 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDTO = NormalizationExecuteDto.builder() + 
.querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); } @@ -114,12 +126,19 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } @@ -173,12 +192,18 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -233,12 +258,18 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + 
.normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -285,15 +316,19 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ) - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + expectThrows(IllegalStateException.class, () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto)); } public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -337,17 +372,87 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + 
createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + // requested page is out of bound for the total number of results + querySearchResult.from(17); + querySearchResults.add(querySearchResult); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(17); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto) + ); + + assertEquals( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results"), + illegalArgumentException.getMessage() + ); + } + private static SearchHits getSearchHits() { SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 918f3f45b..cbfacbc88 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.processor; import java.util.Collections; -import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 2a6fa49a3..e0bdfee41 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -52,6 +52,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; @@ -401,6 +402,167 @@ public void testFromXContent_whenIncorrectFormat_thenFail() { expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser2)); } + @SneakyThrows + public void testFromXContent_whenQueriesCountIsGreaterThanFive_thenFail() { + setUpClusterService(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startArray("queries") + .startObject() + 
.startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .startObject() + .startObject(MatchQueryBuilder.NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject() + .startObject() + .startObject(MatchAllQueryBuilder.NAME) + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(MatchAllQueryBuilder.NAME), + MatchAllQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(MatchQueryBuilder.NAME), MatchQueryBuilder::fromXContent) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + ParsingException exception = expectThrows(ParsingException.class, () -> HybridQueryBuilder.fromXContent(contentParser)); + assertThat(exception.getMessage(), containsString("Number of sub-queries exceeds maximum supported by [hybrid] query")); + } + + @SneakyThrows + public void testFromXContent_whenPaginationDepthIsInvalid_thenFail() { + setUpClusterService(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("pagination_depth", -1) + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + 
xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> HybridQueryBuilder.fromXContent(contentParser) + ); + assertThat(exception.getMessage(), containsString("Pagination depth should lie in the range of 1-10000. Received: -1")); + + XContentBuilder xContentBuilder1 = XContentFactory.jsonBuilder() + .startObject() + .field("pagination_depth", 10001) + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + XContentParser contentParser1 = createParser( + namedXContentRegistry, + xContentBuilder1.contentType().xContent(), + BytesReference.bytes(xContentBuilder1) + ); + contentParser1.nextToken(); + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> HybridQueryBuilder.fromXContent(contentParser1) + ); + assertThat(exception1.getMessage(), containsString("Pagination depth should lie in the range of 1-10000. Received: 10001")); + } + @SneakyThrows public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); @@ -524,6 +686,7 @@ public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { } public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + setUpClusterService(); String modelId = "testModelId"; String fieldName = "fieldTwo"; String queryText = "query text"; @@ -612,6 +775,7 @@ public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { @SneakyThrows public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..16728cfd7 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -320,6 +321,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); hybridQueryBuilderOnlyTerm.add(termQueryBuilder); hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + hybridQueryBuilderOnlyTerm.paginationDepth(10); Map searchResponseAsMap = search( TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME,
@@ -793,46 +795,221 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS } } - // TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. @SneakyThrows - public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { + public void testPaginationDepth_whenSubqueriesCountIsGreaterThanFive_thenFail() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + MatchAllQueryBuilder matchAllQueryBuilder = QueryBuilders.matchAllQuery(); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + MatchQueryBuilder matchQueryBuilder1 = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + MatchQueryBuilder matchQueryBuilder2 = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(termQueryBuilder); + hybridQueryBuilder.add(termQuery2Builder); + hybridQueryBuilder.add(matchAllQueryBuilder); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(matchQueryBuilder1); + hybridQueryBuilder.add(matchQueryBuilder2); + hybridQueryBuilder.paginationDepth(10); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("Number of sub-queries exceeds maximum supported by [hybrid] query")) + ); + + } + + @SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); - hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ResponseException exceptionNoNestedTypes = expectThrows( - ResponseException.class, - () -> search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyTerm, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE), - null, - null, - null, - false, - null, - 10 - )
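For context on how the scenarios above would be driven from a client, here is a minimal sketch of paging through a hybrid query with the new parameter. The index name "my-index", pipeline name "nlp-search-pipeline", field names, and the pre-configured RestHighLevelClient `client` are illustrative assumptions, not part of this change:

    import org.opensearch.action.search.SearchRequest;
    import org.opensearch.action.search.SearchResponse;
    import org.opensearch.client.RequestOptions;
    import org.opensearch.index.query.QueryBuilders;
    import org.opensearch.neuralsearch.query.HybridQueryBuilder;
    import org.opensearch.search.builder.SearchSourceBuilder;

    // Hybrid query with two sub-queries; pagination_depth bounds how many docs
    // each shard contributes to the normalized result set that from/size page over.
    HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
    hybridQuery.add(QueryBuilders.matchQuery("title", "wind"));
    hybridQuery.add(QueryBuilders.termQuery("category", "story"));
    hybridQuery.paginationDepth(50);

    // Request the second page of 10 results.
    SearchSourceBuilder source = new SearchSourceBuilder().query(hybridQuery).from(10).size(10);
    SearchRequest request = new SearchRequest("my-index").source(source);
    request.pipeline("nlp-search-pipeline"); // search pipeline with the normalization processor
    SearchResponse response = client.search(request, RequestOptions.DEFAULT);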
@SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ); + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } - org.hamcrest.MatcherAssert.assertThat( - exceptionNoNestedTypes.getMessage(), - allOf( - containsString("In the current OpenSearch version pagination is not supported with hybrid query"), - containsString("illegal_argument_exception") - ) - ); + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); } } + @SneakyThrows + public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + 
assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 5 + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("Reached end of search result, increase pagination_depth value to see more results")) + ); + } + + @SneakyThrows + public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(100001); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("Pagination depth should lie in the range of 1-1000. 
Received: 100001")) + ); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index f821e7ddf..43b609545 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -96,16 +96,19 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery query1 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + null ); HybridQuery query2 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + null ); HybridQuery query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 5 ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -120,6 +123,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); + assertEquals(5, (int) query3.getPaginationDepth()); } @SneakyThrows @@ -142,7 +146,8 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + null ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it @@ -161,11 +166,11 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K); Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext); - HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery)); + HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery), null); rewritten = hybridQueryWithKnn.rewrite(reader); assertSame(hybridQueryWithKnn, rewritten); - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), null)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -198,7 +203,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), + null ); // executing search query, getting up to 
3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -244,7 +250,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), null); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -280,7 +286,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + null ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -294,10 +301,22 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), null)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } + @SneakyThrows + public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + 0 + ) + ); + assertThat(exception.getMessage(), containsString("pagination depth must not be zero")); + } + @SneakyThrows public void testToString_whenCallQueryToString_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -311,7 +330,8 @@ public void testToString_whenCallQueryToString_thenSuccessful() { new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) - ) + ), + null ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -331,7 +351,8 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - List.of(filter) + List.of(filter), + null ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 10d480475..dcb910c55 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -60,7 +60,8 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - 
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -116,7 +117,8 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -163,7 +165,8 @@ public void testExplain_whenCallExplain_thenFail() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index f44e762f0..4e2dfbd89 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -69,7 +69,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -129,7 +129,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1d3bc29e9..f23f99714 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java 
@@ -52,12 +52,14 @@ import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import static org.mockito.ArgumentMatchers.any; @@ -91,7 +93,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -122,7 +124,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -153,7 +155,7 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); @@ -197,7 +199,7 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -242,7 +244,8 @@ public void 
testReduce_whenMatchedDocs_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), + null ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -343,7 +346,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -380,7 +383,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -411,7 +414,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -508,7 +511,8 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + null ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -630,7 +634,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + 
HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), null); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -759,7 +763,8 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + null ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -877,7 +882,8 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) - ) + ), + null ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1020,7 +1026,8 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + null ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1078,4 +1085,99 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { reader.close(); directory.close(); } + + @SneakyThrows + public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthInRange_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + // pagination_depth=10 + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + 
public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + // from > 0 + when(searchContext.from()).thenReturn(5); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + // pagination_depth is left unset (null); with from > 0 the collector manager must reject the request instead of falling back to the default depth + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "pagination_depth is missing in the search request"), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testScrollWithHybridQuery_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + ScrollContext scrollContext = new ScrollContext(); + when(searchContext.scrollContext()).thenReturn(scrollContext); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"), + illegalArgumentException.getMessage() + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index be9dbc2cc..0fc0980c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -45,7 +45,8 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + null ); SearchContext searchContext = mock(SearchContext.class); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index afc545447..a8aa0683b 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -596,7 +596,6 @@ protected Map search( if (requestParams != null && !requestParams.isEmpty()) { requestParams.forEach(request::addParameter); } - logger.info("Sorting request " + builder.toString()); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
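
Note: taken together, the new tests pin down the pagination_depth contract for hybrid queries. The sketch below restates that contract as standalone Java for readability; the names (PaginationDepthContract, resolveDepth, DEFAULT_PAGINATION_DEPTH) are illustrative assumptions, not the plugin's actual internals, and the bounds and error messages are taken verbatim from the assertions above.

import java.util.Locale;

// Illustrative sketch only: mirrors the validation behavior asserted by the tests above.
final class PaginationDepthContract {
    private static final int DEFAULT_PAGINATION_DEPTH = 10; // assumed default when depth is absent and from == 0
    private static final int MAX_PAGINATION_DEPTH = 1000;   // upper bound quoted in the range error message

    // Returns the effective depth or throws, using the error messages the tests assert on.
    static int resolveDepth(Integer paginationDepth, int from) {
        if (paginationDepth == null) {
            if (from > 0) {
                // asserted by testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail
                throw new IllegalArgumentException("pagination_depth is missing in the search request");
            }
            return DEFAULT_PAGINATION_DEPTH;
        }
        if (paginationDepth == 0) {
            // asserted by testWithRandomDocuments_whenPaginationDepthIsZero_thenFail
            throw new IllegalArgumentException("pagination depth must not be zero");
        }
        if (paginationDepth < 1 || paginationDepth > MAX_PAGINATION_DEPTH) {
            // asserted by testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-%d. Received: %d", MAX_PAGINATION_DEPTH, paginationDepth)
            );
        }
        return paginationDepth;
    }
}

Under these assumptions, a request with from > 0 and no pagination_depth fails fast, while from == 0 silently falls back to a depth of 10; scroll requests are rejected separately regardless of depth (see testScrollWithHybridQuery_thenFail).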