Skip to content

Commit

Permalink
Adding search processor for score normalization and combination
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 19, 2023
1 parent c4d0a0c commit 3ebe720
Show file tree
Hide file tree
Showing 15 changed files with 1,722 additions and 5 deletions.
6 changes: 5 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,12 @@ testClusters.integTest {
// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")
// enable hybrid search for testing

// enable features for testing
// hybrid search
systemProperty('neural_search_hybrid_search_enabled', 'true')
// search pipelines
systemProperty('opensearch.experimental.feature.search_pipeline.enabled', 'true')
}

// Remote Integration Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.Optional;
import java.util.function.Supplier;

import lombok.extern.log4j.Log4j2;

import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -24,7 +26,9 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
Expand All @@ -33,9 +37,11 @@
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand All @@ -45,7 +51,8 @@
/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {
@Log4j2
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
Expand Down Expand Up @@ -90,9 +97,18 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
/**
 * Supplies the hybrid query phase searcher when the hybrid search feature flag is switched on.
 *
 * @return {@link Optional} holding a new {@link HybridQueryPhaseSearcher} when the
 *         {@code NEURAL_SEARCH_HYBRID_SEARCH_ENABLED} flag is enabled, otherwise an empty {@link Optional}
 */
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
    final boolean hybridSearchEnabled = FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
    if (!hybridSearchEnabled) {
        // The feature stays off unless explicitly enabled, because the hybrid query phase searcher
        // risks colliding with and breaking concurrent search in core.
        log.info("Not registering hybrid query phase searcher because feature flag is disabled");
        return Optional.empty();
    }
    log.info("Registering hybrid query phase searcher");
    return Optional.of(new HybridQueryPhaseSearcher());
}

/**
 * Registers the search phase results processors contributed by this plugin.
 *
 * @param parameters processor parameters handed over by the caller; not used by this implementation
 * @return single-entry map from {@link NormalizationProcessor#TYPE} to a new {@link NormalizationProcessorFactory}
 */
@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(
    Parameters parameters
) {
    final org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor> normalizationFactory =
        new NormalizationProcessorFactory();
    return Map.of(NormalizationProcessor.TYPE, normalizationFactory);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import joptsimple.internal.Strings;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.EnumUtils;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import com.google.common.annotations.VisibleForTesting;

/**
 * Processor for score normalization and combination on post query search results. Updates query results with
 * normalized and combined scores for next phase (typically it's FETCH)
 */
@Log4j2
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
    public static final String TYPE = "normalization-processor";
    public static final String NORMALIZATION_CLAUSE = "normalization";
    public static final String COMBINATION_CLAUSE = "combination";
    public static final String TECHNIQUE = "technique";

    private final String tag;
    private final String description;
    @VisibleForTesting
    @Getter(AccessLevel.PACKAGE)
    final ScoreNormalizationTechnique normalizationTechnique;
    @VisibleForTesting
    @Getter(AccessLevel.PACKAGE)
    final ScoreCombinationTechnique combinationTechnique;

    /**
     * Need all args constructor to validate parameters and fail fast
     * @param tag value returned by {@link #getTag()}
     * @param description value returned by {@link #getDescription()}
     * @param normalizationTechnique exact name of a {@link ScoreNormalizationTechnique} constant
     * @param combinationTechnique exact name of a {@link ScoreCombinationTechnique} constant
     * @throws IllegalArgumentException if either technique name is null, empty or does not match an enum constant
     */
    public NormalizationProcessor(
        final String tag,
        final String description,
        final String normalizationTechnique,
        final String combinationTechnique
    ) {
        this.tag = tag;
        this.description = description;
        // validate first so that valueOf below cannot fail with a less descriptive error
        validateParameters(normalizationTechnique, combinationTechnique);
        this.normalizationTechnique = ScoreNormalizationTechnique.valueOf(normalizationTechnique);
        this.combinationTechnique = ScoreCombinationTechnique.valueOf(combinationTechnique);
    }

    /**
     * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
     * are set as part of class constructor. Does nothing when the results are not a {@link QueryPhaseResultConsumer} or when
     * the first non-null shard result does not carry {@link CompoundTopDocs}.
     * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
     * @param searchPhaseContext {@link SearchPhaseContext}; not read by this implementation
     * @param <Result> concrete type of the per-shard search phase result
     */
    @Override
    public <Result extends SearchPhaseResult> void process(
        final SearchPhaseResults<Result> searchPhaseResult,
        final SearchPhaseContext searchPhaseContext
    ) {
        if (!(searchPhaseResult instanceof QueryPhaseResultConsumer)) {
            return;
        }
        QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
        Optional<SearchPhaseResult> maybeResult = queryPhaseResultConsumer.getAtomicArray()
            .asList()
            .stream()
            .filter(Objects::nonNull)
            .findFirst();
        if (isNotHybridQuery(maybeResult)) {
            return;
        }

        TopDocsAndMaxScore[] topDocsAndMaxScores = getCompoundQueryTopDocsForResult(searchPhaseResult);
        // keep positions aligned with topDocsAndMaxScores: a null slot stays a null slot
        CompoundTopDocs[] queryTopDocs = Arrays.stream(topDocsAndMaxScores)
            .map(td -> td != null ? (CompoundTopDocs) td.topDocs : null)
            .toArray(CompoundTopDocs[]::new);

        ScoreNormalizer scoreNormalizer = new ScoreNormalizer();
        scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

        ScoreCombiner scoreCombinator = new ScoreCombiner();
        List<Float> combinedMaxScores = scoreCombinator.combineScores(queryTopDocs, combinationTechnique);

        updateOriginalQueryResults(searchPhaseResult, queryTopDocs, topDocsAndMaxScores, combinedMaxScores);
    }

    /**
     * @return {@link SearchPhaseName#QUERY}
     */
    @Override
    public SearchPhaseName getBeforePhase() {
        return SearchPhaseName.QUERY;
    }

    /**
     * @return {@link SearchPhaseName#FETCH}
     */
    @Override
    public SearchPhaseName getAfterPhase() {
        return SearchPhaseName.FETCH;
    }

    /**
     * @return processor type, always {@link #TYPE}
     */
    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * @return tag passed to the constructor
     */
    @Override
    public String getTag() {
        return tag;
    }

    /**
     * @return description passed to the constructor
     */
    @Override
    public String getDescription() {
        return description;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isIgnoreFailure() {
        return true;
    }

    /**
     * Checks that both technique names are present and match a constant of the corresponding enum.
     * Unsupported names are logged at error level together with the offending value before the exception is thrown.
     * @param normalizationTechniqueName candidate name for {@link ScoreNormalizationTechnique}
     * @param combinationTechniqueName candidate name for {@link ScoreCombinationTechnique}
     * @throws IllegalArgumentException if a name is null, empty or not a valid enum constant
     */
    protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName) {
        if (Strings.isNullOrEmpty(normalizationTechniqueName)) {
            throw new IllegalArgumentException("normalization technique cannot be empty");
        }
        if (Strings.isNullOrEmpty(combinationTechniqueName)) {
            throw new IllegalArgumentException("combination technique cannot be empty");
        }
        if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) {
            log.error(String.format(Locale.ROOT, "provided normalization technique [%s] is not supported", normalizationTechniqueName));
            throw new IllegalArgumentException("provided normalization technique is not supported");
        }
        if (!EnumUtils.isValidEnum(ScoreCombinationTechnique.class, combinationTechniqueName)) {
            log.error(String.format(Locale.ROOT, "provided combination technique [%s] is not supported", combinationTechniqueName));
            throw new IllegalArgumentException("provided combination technique is not supported");
        }
    }

    /**
     * Decides whether processing must be skipped based on a single sample shard result.
     * @param maybeResult first non-null shard result, if any
     * @return true when there is no result, it has no query result, or its top docs are not {@link CompoundTopDocs}
     */
    private boolean isNotHybridQuery(final Optional<SearchPhaseResult> maybeResult) {
        return maybeResult.isEmpty()
            || Objects.isNull(maybeResult.get().queryResult())
            || !(maybeResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
    }

    /**
     * Collects the top docs of every shard result that carries {@link CompoundTopDocs}.
     * @param results search phase results to read from
     * @return array sized to the length of the underlying atomic array; slot {@code i} holds the top docs of the i-th
     *         element of {@code results.getAtomicArray().asList()}, or null when that element has no compound top docs
     */
    private <Result extends SearchPhaseResult> TopDocsAndMaxScore[] getCompoundQueryTopDocsForResult(
        final SearchPhaseResults<Result> results
    ) {
        List<Result> preShardResultList = results.getAtomicArray().asList();
        TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[results.getAtomicArray().length()];
        for (int idx = 0; idx < preShardResultList.size(); idx++) {
            Result shardResult = preShardResultList.get(idx);
            if (shardResult == null) {
                continue;
            }
            // only the first shard result is null-checked by isNotHybridQuery, so guard every shard here
            QuerySearchResult querySearchResult = shardResult.queryResult();
            if (querySearchResult == null) {
                continue;
            }
            TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
            if (topDocsAndMaxScore != null && topDocsAndMaxScore.topDocs instanceof CompoundTopDocs) {
                result[idx] = topDocsAndMaxScore;
            }
        }
        return result;
    }

    /**
     * Writes processed top docs back into the per-shard query results. Shards whose slot in {@code queryTopDocs} is
     * null are left untouched.
     * @param results search phase results whose query results are replaced in place
     * @param queryTopDocs processed top docs, indexed like {@code results.getAtomicArray().asList()}
     * @param topDocsAndMaxScores original top docs holders, same indexing; their max score is overwritten with the combined one
     * @param combinedMaxScores combined max score per shard, same indexing
     */
    @VisibleForTesting
    protected <Result extends SearchPhaseResult> void updateOriginalQueryResults(
        final SearchPhaseResults<Result> results,
        final CompoundTopDocs[] queryTopDocs,
        TopDocsAndMaxScore[] topDocsAndMaxScores,
        List<Float> combinedMaxScores
    ) {
        List<Result> preShardResultList = results.getAtomicArray().asList();
        for (int i = 0; i < preShardResultList.size(); i++) {
            QuerySearchResult querySearchResult = preShardResultList.get(i).queryResult();
            CompoundTopDocs updatedTopDocs = queryTopDocs[i];
            if (updatedTopDocs == null) {
                continue;
            }
            // a positive hit count does not guarantee that score docs were retained, so check the array
            // itself before reading its first element
            boolean hasScoreDocs = updatedTopDocs.scoreDocs != null && updatedTopDocs.scoreDocs.length > 0;
            float maxScore = updatedTopDocs.totalHits.value > 0 && hasScoreDocs ? updatedTopDocs.scoreDocs[0].score : 0.0f;
            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
            querySearchResult.topDocs(topDocsAndMaxScore, null);
            if (topDocsAndMaxScores[i] != null) {
                topDocsAndMaxScores[i].maxScore = combinedMaxScores.get(i);
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

/**
 * Collection of techniques for score combination
 */
public enum ScoreCombinationTechnique {

    /**
     * Arithmetic mean method for combining scores.
     * cscore = (score1 + score2 +...+ scoreN)/N
     *
     * Negative scores are skipped and do not count towards N; zero (0.0) scores are counted.
     * When no score qualifies (empty input or only negative scores) the combined score is 0.0.
     */
    ARITHMETIC_MEAN {

        @Override
        public float combine(float[] scores) {
            float combinedScore = 0.0f;
            int count = 0;
            for (float score : scores) {
                if (score >= 0.0f) {
                    combinedScore += score;
                    count++;
                }
            }
            // without this guard 0.0f / 0 yields NaN, which would leak into the combined results
            if (count == 0) {
                return 0.0f;
            }
            return combinedScore / count;
        }
    };

    public static final ScoreCombinationTechnique DEFAULT = ARITHMETIC_MEAN;

    /**
     * Defines combination function specific to this technique
     * @param scores array of collected original scores
     * @return combined score
     */
    abstract float combine(float[] scores);
}
Loading

0 comments on commit 3ebe720

Please sign in to comment.