Skip to content

Commit

Permalink
Add workflow, refactor techniques, added s/n methods
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 21, 2023
1 parent 328620f commit 319d4bb
Show file tree
Hide file tree
Showing 22 changed files with 675 additions and 393 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
Expand All @@ -20,7 +19,8 @@
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
Expand Down Expand Up @@ -48,6 +48,7 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
final ScoreNormalizationTechnique normalizationTechnique;
@Getter(AccessLevel.PACKAGE)
final ScoreCombinationTechnique combinationTechnique;
final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
Expand All @@ -64,17 +65,8 @@ public <Result extends SearchPhaseResult> void process(
if (shouldSearchResultsBeIgnored(searchPhaseResult)) {
return;
}

TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsFromSearchPhaseResult(searchPhaseResult);
List<CompoundTopDocs> queryTopDocs = Arrays.stream(topDocsAndMaxScores)
.map(td -> td != null ? (CompoundTopDocs) td.topDocs : null)
.collect(Collectors.toList());

ScoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

List<Float> combinedMaxScores = ScoreCombiner.combineScores(queryTopDocs, combinationTechnique);

updateOriginalQueryResults(searchPhaseResult, queryTopDocs, topDocsAndMaxScores, combinedMaxScores);
List<QuerySearchResult> querySearchResults = getQuerySearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
}

@Override
Expand Down Expand Up @@ -128,46 +120,16 @@ private boolean isNotHybridQuery(final Optional<SearchPhaseResult> maybeResult)
|| !(maybeResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
}

private <Result extends SearchPhaseResult> TopDocsAndMaxScore[] getCompoundQueryTopDocsFromSearchPhaseResult(
final SearchPhaseResults<Result> results
) {
List<Result> preShardResultList = results.getAtomicArray().asList();
TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[preShardResultList.size()];
for (int idx = 0; idx < preShardResultList.size(); idx++) {
Result shardResult = preShardResultList.get(idx);
private <Result extends SearchPhaseResult> List<QuerySearchResult> getQuerySearchResults(final SearchPhaseResults<Result> results) {
List<Result> resultsPerShard = results.getAtomicArray().asList();
List<QuerySearchResult> querySearchResults = new ArrayList<>();
for (Result shardResult : resultsPerShard) {
if (shardResult == null) {
querySearchResults.add(null);
continue;
}
QuerySearchResult querySearchResult = shardResult.queryResult();
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
if (!(topDocsAndMaxScore.topDocs instanceof CompoundTopDocs)) {
continue;
}
result[idx] = topDocsAndMaxScore;
}
return result;
}

@VisibleForTesting
protected <Result extends SearchPhaseResult> void updateOriginalQueryResults(
final SearchPhaseResults<Result> results,
final List<CompoundTopDocs> queryTopDocs,
TopDocsAndMaxScore[] topDocsAndMaxScores,
List<Float> combinedMaxScores
) {
List<Result> preShardResultList = results.getAtomicArray().asList();
for (int i = 0; i < preShardResultList.size(); i++) {
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
if (Objects.isNull(updatedTopDocs)) {
continue;
}
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
QuerySearchResult querySearchResult = preShardResultList.get(i).queryResult();
querySearchResult.topDocs(topDocsAndMaxScore, null);
if (topDocsAndMaxScores[i] != null) {
topDocsAndMaxScores[i].maxScore = combinedMaxScores.get(i);
}
querySearchResults.add(shardResult.queryResult());
}
return querySearchResults;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;

import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.query.QuerySearchResult;

import com.google.common.annotations.VisibleForTesting;

/**
 * Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data
 * and post-processing for final results
 */
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class NormalizationProcessorWorkflow {

    /**
     * Return instance of workflow class. Making default constructor private for now
     * as we may use singleton pattern here and share one instance among processors
     * @return instance of NormalizationProcessorWorkflow
     */
    public static NormalizationProcessorWorkflow create() {
        return new NormalizationProcessorWorkflow();
    }

    /**
     * Start execution of this workflow
     * @param querySearchResults input data with QuerySearchResult from multiple shards; entries may be null for
     *                           shards that returned no result
     * @param normalizationTechnique technique for score normalization
     * @param combinationTechnique technique for score combination
     */
    public void execute(
        final List<QuerySearchResult> querySearchResults,
        final ScoreNormalizationTechnique normalizationTechnique,
        final ScoreCombinationTechnique combinationTechnique
    ) {
        // pre-process data: one entry per shard, null when the shard has no hybrid-query result
        List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

        // normalize
        ScoreNormalizer scoreNormalizer = new ScoreNormalizer();
        scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

        // combine
        ScoreCombiner scoreCombiner = new ScoreCombiner();
        List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, combinationTechnique);

        // post-process data
        updateOriginalQueryResults(querySearchResults, queryTopDocs, combinedMaxScores);
    }

    /**
     * Extracts CompoundTopDocs from each shard result, keeping the output aligned index-by-index with the input
     * list: a null placeholder is kept for shards with a null result or a non-hybrid (non-CompoundTopDocs) result.
     * Filtering such entries out (instead of mapping them to null) would both NPE on null shard results and shift
     * indices, making later per-shard updates pair the wrong shard's docs. Null entries are expected downstream —
     * the pre-refactor implementation passed the same null-placeholder list to the normalizer/combiner.
     * @param querySearchResults results from all shards, may contain null entries
     * @return list of CompoundTopDocs with the same size and order as querySearchResults
     */
    private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
        return querySearchResults.stream()
            .map(
                searchResult -> (Objects.nonNull(searchResult) && searchResult.topDocs().topDocs instanceof CompoundTopDocs)
                    ? (CompoundTopDocs) searchResult.topDocs().topDocs
                    : null
            )
            .collect(Collectors.toList());
    }

    /**
     * Writes normalized and combined scores back into the original QuerySearchResult objects.
     * All three lists are expected to be aligned by shard index (see getQueryTopDocs).
     * @param querySearchResults original per-shard results to update in place, may contain null entries
     * @param queryTopDocs normalized/combined top docs, same size and order as querySearchResults
     * @param combinedMaxScores combined max score per shard, same order as queryTopDocs
     */
    @VisibleForTesting
    protected void updateOriginalQueryResults(
        final List<QuerySearchResult> querySearchResults,
        final List<CompoundTopDocs> queryTopDocs,
        final List<Float> combinedMaxScores
    ) {
        for (int idx = 0; idx < querySearchResults.size(); idx++) {
            QuerySearchResult querySearchResult = querySearchResults.get(idx);
            CompoundTopDocs updatedTopDocs = queryTopDocs.get(idx);
            // skip shards that had no result or whose top docs were not produced by a hybrid query
            if (Objects.isNull(querySearchResult) || Objects.isNull(updatedTopDocs)) {
                continue;
            }
            TopDocsAndMaxScore originalTopDocsAndMaxScore = querySearchResult.topDocs();
            float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
            querySearchResult.topDocs(new TopDocsAndMaxScore(updatedTopDocs, maxScore), null);
            // also set the combined max score on the original TopDocsAndMaxScore instance in case other
            // components still hold a reference to it (mirrors the pre-refactor implementation)
            originalTopDocsAndMaxScore.maxScore = combinedMaxScores.get(idx);
        }
    }
}

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit 319d4bb

Please sign in to comment.