Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Pagination in hybrid query #963

Open
wants to merge 13 commits into
base: Pagination_in_hybridQuery
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,21 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
int fromValueForSingleShard = 0;
boolean isSingleShard = false;
if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) {
isSingleShard = true;
fromValueForSingleShard = searchPhaseContext.getRequest().source().from();
}

normalizationWorkflow.execute(
querySearchResults,
fetchSearchResult,
normalizationTechnique,
combinationTechnique,
fromValueForSingleShard,
isSingleShard
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final int fromValueForSingleShard,
final boolean isSingleShard
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
Expand All @@ -73,6 +75,8 @@ public void execute(
.scoreCombinationTechnique(combinationTechnique)
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(fromValueForSingleShard)
.isSingleShard(isSingleShard)
.build();

// combine
Expand All @@ -82,7 +86,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds, fromValueForSingleShard);
}

/**
Expand Down Expand Up @@ -113,15 +117,28 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
final int from = querySearchResults.get(0).from();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
if (combineScoresDTO.isSingleShard()) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

if ((from > 0 || combineScoresDTO.getFromValueForSingleShard() > 0)
&& (from > totalScoreDocsCount || combineScoresDTO.getFromValueForSingleShard() > totalScoreDocsCount)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
}
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -180,7 +197,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand Down Expand Up @@ -212,14 +230,26 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;

// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please pull topDocs.scoreDocs.length - fromValueForSingleShard expression to a variable and give it meaningful name

for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please change the semantic here, start from 0 and do (i + offset) when you're reading from topDocs

ScoreDoc scoreDoc = topDocs.scoreDocs[i];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit;
}

// iterate over the normalized/combined scores, that solves (1) and (3)
// SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// // get fetched hit content by doc_id
// SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// // update score to normalized/combined value (3)
// searchHit.score(scoreDoc.score);
// return searchHit;
// }).toArray(SearchHit[]::new);
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
private boolean isSingleShard;
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,11 @@ public class ScoreCombiner {
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(
compoundQueryTopDocs -> combineShardScores(
combineScoresDTO.getScoreCombinationTechnique(),
compoundQueryTopDocs,
combineScoresDTO.getSort()
)
);
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));

}

private void combineShardScores(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private int paginationDepth;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved

/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, int paginationDepth) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
Expand All @@ -57,10 +58,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
}
this.subQueries = modifiedSubQueries;
}
this.paginationDepth = paginationDepth;
}

public HybridQuery(final Collection<Query> subQueries) {
this(subQueries, List.of());
public HybridQuery(final Collection<Query> subQueries, final int paginationDepth) {
this(subQueries, List.of(), paginationDepth);
}

/**
Expand Down Expand Up @@ -128,7 +130,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
return new HybridQuery(rewrittenSubQueries);
return new HybridQuery(rewrittenSubQueries, paginationDepth);
}

private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
Expand Down Expand Up @@ -190,6 +192,10 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

public int getPaginationDepth() {
return paginationDepth;
}

/**
* Create the Weight used to score this query
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,19 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
public static final String NAME = "hybrid";

private static final ParseField QUERIES_FIELD = new ParseField("queries");
private static final ParseField DEPTH_FIELD = new ParseField("pagination_depth");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private static final ParseField DEPTH_FIELD = new ParseField("pagination_depth");
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");


private final List<QueryBuilder> queries = new ArrayList<>();

private String fieldName;
private int paginationDepth;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
super(in);
queries.addAll(readQueries(in));
paginationDepth = in.readInt();
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand All @@ -68,6 +71,7 @@ public HybridQueryBuilder(StreamInput in) throws IOException {
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
writeQueries(out, queries);
out.writeInt(paginationDepth);
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down Expand Up @@ -97,6 +101,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
queryBuilder.toXContent(builder, params);
}
builder.endArray();
builder.field(DEPTH_FIELD.getPreferredName(), paginationDepth);
printBoostAndQueryName(builder);
builder.endObject();
}
Expand All @@ -113,7 +118,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
if (queryCollection.isEmpty()) {
return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME));
}
return new HybridQuery(queryCollection);
return new HybridQuery(queryCollection, paginationDepth);
}

/**
Expand Down Expand Up @@ -149,6 +154,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
float boost = AbstractQueryBuilder.DEFAULT_BOOST;

int paginationDepth = 0;
final List<QueryBuilder> queries = new ArrayList<>();
String queryName = null;

Expand Down Expand Up @@ -196,6 +202,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
}
} else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queryName = parser.text();
} else if (DEPTH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
paginationDepth = parser.intValue();
} else {
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
throw new ParsingException(
Expand All @@ -216,6 +224,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder();
compoundQueryBuilder.queryName(queryName);
compoundQueryBuilder.boost(boost);
compoundQueryBuilder.paginationDepth(paginationDepth);
for (QueryBuilder query : queries) {
compoundQueryBuilder.add(query);
}
Expand All @@ -235,6 +244,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I
if (changed) {
newBuilder.queryName(queryName);
newBuilder.boost(boost);
newBuilder.paginationDepth(paginationDepth);
return newBuilder;
} else {
return this;
Expand All @@ -257,6 +267,7 @@ protected boolean doEquals(HybridQueryBuilder obj) {
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queries, obj.queries);
equalsBuilder.append(paginationDepth, obj.paginationDepth);
return equalsBuilder.isEquals();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Weight;
Expand All @@ -22,6 +23,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
Expand Down Expand Up @@ -72,6 +74,7 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
@Nullable
private final FieldDoc after;
private final SearchContext searchContext;
private static final int DEFAULT_PAGINATION_DEPTH = 10;
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved

/**
     * Create new instance of HybridCollectorManager depending on the concurrent search being enabled or disabled.
Expand All @@ -80,13 +83,20 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
* @throws IOException
*/
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
if (searchContext.scrollContext() != null) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"));
}
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int numDocs = Math.min(getSubqueryResultsRetrievalSize(searchContext), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
if (searchContext.sort() != null) {
validateSortCriteria(searchContext, searchContext.trackScores());
}
boolean isSingleShard = searchContext.numberOfShards() == 1;
if (isSingleShard && searchContext.from() > 0) {
searchContext.from(0);
}

Weight filteringWeight = null;
// Check for post filter to create weight for filter query and later use that weight in the search workflow
Expand Down Expand Up @@ -461,6 +471,42 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
};
}

/**
     * Get maximum subquery results count to be collected from each shard.
     * @param searchContext search context that contains pagination depth
     * @return results size to be collected
*/
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
int paginationDepth;
HybridQuery hybridQuery;
Query query = searchContext.query();
if (query instanceof BooleanQuery) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when such scenario can happen? we should not allow hybrid if it's not the top clause

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the comment on top why it is written like that. Basically in case of nested fields and alias filter, hybrid query gets wrapped under bool query.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simple case with nested fields is taken care of in QueryPhaseSearcher class. This logic doesn't belong here, unless this is a special scenario that we missed in that extract query method I mentioned above.

BooleanQuery booleanQuery = (BooleanQuery) query;
hybridQuery = (HybridQuery) booleanQuery.clauses().get(0).getQuery();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks hacky, can we avoid this logic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the comment on top why it is written like that. Basically in case of nested fields and alias filter, hybrid query gets wrapped under bool query.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my previous comment, this should be handled in here

paginationDepth = hybridQuery.getPaginationDepth();
} else {
hybridQuery = (HybridQuery) query;
paginationDepth = hybridQuery.getPaginationDepth();
}

if (paginationDepth != 0) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
validatePaginationDepth(paginationDepth);
return paginationDepth;
} else if (searchContext.from() > 0 && paginationDepth == 0) {
return DEFAULT_PAGINATION_DEPTH;
} else {
return searchContext.from() + searchContext.size();
}
}

private static void validatePaginationDepth(int depth) {
if (depth < 0 || depth > 10000) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-1000. Received: %s", depth)
);
}
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
Expand Down
Loading
Loading