Skip to content

Commit

Permalink
refactoring code
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Nov 4, 2024
1 parent 085710a commit e0ed9cb
Show file tree
Hide file tree
Showing 19 changed files with 427 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
Expand All @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -58,21 +59,16 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
int fromValueForSingleShard = 0;
boolean isSingleShard = false;
if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) {
isSingleShard = true;
fromValueForSingleShard = searchPhaseContext.getRequest().source().from();
}

normalizationWorkflow.execute(
querySearchResults,
fetchSearchResult,
normalizationTechnique,
combinationTechnique,
fromValueForSingleShard,
isSingleShard
);
// Builds data transfer object to pass into execute
NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.searchPhaseContext(searchPhaseContext)
.build();

normalizationWorkflow.execute(normalizationExecuteDto);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand All @@ -47,18 +49,17 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param normalizationExecuteDto contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, searchPhaseContext.
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique,
final int fromValueForSingleShard,
final boolean isSingleShard
) {
public void execute(final NormalizationExecuteDto normalizationExecuteDto) {
final List<QuerySearchResult> querySearchResults = normalizationExecuteDto.getQuerySearchResults();
final Optional<FetchSearchResult> fetchSearchResultOptional = normalizationExecuteDto.getFetchSearchResultOptional();
final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDto.getNormalizationTechnique();
final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDto.getCombinationTechnique();
final SearchPhaseContext searchPhaseContext = normalizationExecuteDto.getSearchPhaseContext();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

Expand All @@ -75,8 +76,8 @@ public void execute(
.scoreCombinationTechnique(combinationTechnique)
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(fromValueForSingleShard)
.isSingleShard(isSingleShard)
.fromValueForSingleShard(searchPhaseContext.getRequest().source().from())
.isFetchResultsPresent(fetchSearchResultOptional.isPresent())
.build();

// combine
Expand All @@ -86,7 +87,12 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds, fromValueForSingleShard);
updateOriginalFetchResults(
querySearchResults,
fetchSearchResultOptional,
unprocessedDocIds,
combineScoresDTO.getFromValueForSingleShard()
);
}

/**
Expand Down Expand Up @@ -117,7 +123,6 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
final int from = querySearchResults.get(0).from();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
Expand All @@ -127,14 +132,16 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
if (combineScoresDTO.isSingleShard()) {
// Fetch Phase had ran before the normalization phase, therefore update the from value in result of each shard.
// This will ensure the trimming of the results.
if (combineScoresDTO.isFetchResultsPresent()) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

if ((from > 0 || combineScoresDTO.getFromValueForSingleShard() > 0)
&& (from > totalScoreDocsCount || combineScoresDTO.getFromValueForSingleShard() > totalScoreDocsCount)) {
final int from = querySearchResults.get(0).from();
if (from > 0 && from > totalScoreDocsCount) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
Expand Down Expand Up @@ -231,6 +238,9 @@ private void updateOriginalFetchResults(
QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;

// When normalization process will execute before the fetch phase, then from =0 is applicable.
// When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the
// search request.
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard];
for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) {
Expand All @@ -242,14 +252,6 @@ private void updateOriginalFetchResults(
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit;
}

// iterate over the normalized/combined scores, that solves (1) and (3)
// SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
// // get fetched hit content by doc_id
// SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// // update score to normalized/combined value (3)
// searchHit.score(scoreDoc.score);
// return searchHit;
// }).toArray(SearchHit[]::new);
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.dto.CombineScoresDto;

/**
* Abstracts combination of scores in query search results.
Expand Down Expand Up @@ -69,7 +70,6 @@ public void combineScores(final CombineScoresDto combineScoresDTO) {
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));

}

private void combineShardScores(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.combination;
package org.opensearch.neuralsearch.processor.dto;

import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.search.Sort;
import org.opensearch.common.Nullable;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.search.query.QuerySearchResult;

/**
Expand All @@ -30,5 +32,5 @@ public class CombineScoresDto {
@Nullable
private Sort sort;
private int fromValueForSingleShard;
private boolean isSingleShard;
private boolean isFetchResultsPresent;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.dto;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data in NormalizationProcessorWorkflow class
* in NormalizationProcessorWorkflow.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDto {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
@NonNull
private SearchPhaseContext searchPhaseContext;
}
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Objects;
import java.util.concurrent.Callable;

import lombok.Getter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -31,21 +32,25 @@
* Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual
* scores for each sub-query.
*/
@Getter
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private int paginationDepth;
private Integer paginationDepth;

/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, int paginationDepth) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, Integer paginationDepth) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
}
if (paginationDepth != null && paginationDepth == 0) {
throw new IllegalArgumentException("pagination depth must not be zero");
}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
Expand All @@ -61,7 +66,7 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
this.paginationDepth = paginationDepth;
}

public HybridQuery(final Collection<Query> subQueries, final int paginationDepth) {
public HybridQuery(final Collection<Query> subQueries, final Integer paginationDepth) {
this(subQueries, List.of(), paginationDepth);
}

Expand Down Expand Up @@ -192,10 +197,6 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

public int getPaginationDepth() {
return paginationDepth;
}

/**
* Create the Weight used to score this query
*
Expand Down
Loading

0 comments on commit e0ed9cb

Please sign in to comment.