Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding search processor for score normalization and combination #227

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreCombinationFactory scoreCombinationFactory;

@Override
public Collection<Object> createComponents(
Expand All @@ -82,6 +84,7 @@ public Collection<Object> createComponents(
) {
NeuralQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
scoreCombinationFactory = new ScoreCombinationFactory();
return List.of(clientAccessor);
}

Expand Down Expand Up @@ -114,6 +117,9 @@ public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(
Parameters parameters
) {
return Map.of(NormalizationProcessor.TYPE, new NormalizationProcessorFactory(normalizationProcessorWorkflow));
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreCombinationFactory)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ public class NormalizationProcessorWorkflow {
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param combinationMethod technique for score combination
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationMethod
) {
// pre-process data
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
Expand All @@ -49,7 +49,7 @@ public void execute(
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

// combine
List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, combinationTechnique);
List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, combinationMethod);

// post-process data
updateOriginalQueryResults(querySearchResults, queryTopDocs, combinedMaxScores);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class ArithmeticMeanScoreCombinationMethod implements ScoreCombinationMethod {
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

private static final ArithmeticMeanScoreCombinationMethod INSTANCE = new ArithmeticMeanScoreCombinationMethod();
private static final ArithmeticMeanScoreCombinationTechnique INSTANCE = new ArithmeticMeanScoreCombinationTechnique();
public static final String TECHNIQUE_NAME = "arithmetic_mean";
private static final Float ZERO_SCORE = 0.0f;

public static ArithmeticMeanScoreCombinationMethod getInstance() {
public static ArithmeticMeanScoreCombinationTechnique getInstance() {
return INSTANCE;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {

private static final ScoreCombinationTechnique DEFAULT_COMBINATION_METHOD = ArithmeticMeanScoreCombinationTechnique.getInstance();

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
ArithmeticMeanScoreCombinationTechnique.getInstance()
);

/**
* Get score combination method by technique name
* @param technique name of technique
* @return
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
}

/**
* Default combination method
* @return combination method that is used in case user did not provide combination technique name
*/
public ScoreCombinationTechnique defaultCombination() {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
return DEFAULT_COMBINATION_METHOD;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,11 @@

package org.opensearch.neuralsearch.processor.combination;

import lombok.AllArgsConstructor;

/**
* Collection of techniques for score combination
*/
@AllArgsConstructor
public enum ScoreCombinationTechnique {

public interface ScoreCombinationTechnique {
/**
* Arithmetic mean method for combining scores.
* Defines combination function specific to this technique
* @param scores array of collected original scores
* @return combined score
*/
ARITHMETIC_MEAN(ArithmeticMeanScoreCombinationMethod.getInstance());

public static final ScoreCombinationTechnique DEFAULT = ARITHMETIC_MEAN;
private final ScoreCombinationMethod method;

public float combine(final float[] scores) {
return method.combine(scores);
}
float combine(final float[] scores);
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class ScoreCombiner {
* Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score",
* other steps are same for all techniques.
* @param queryTopDocs query results that need to be normalized, mutated by method execution
* @param scoreCombinationTechnique exact combination technique that should be applied
* @param scoreCombinationTechnique exact combination method that should be applied
* @return list of max combined scores for each shard
*/
public List<Float> combineScores(final List<CompoundTopDocs> queryTopDocs, final ScoreCombinationTechnique scoreCombinationTechnique) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.pipeline.Processor;
Expand All @@ -28,6 +29,7 @@
@AllArgsConstructor
public class NormalizationProcessorFactory implements Processor.Factory<SearchPhaseResultsProcessor> {
private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreCombinationFactory scoreCombinationFactory;

@Override
public SearchPhaseResultsProcessor create(
Expand All @@ -54,22 +56,25 @@ public SearchPhaseResultsProcessor create(
config,
NormalizationProcessor.COMBINATION_CLAUSE
);
String combinationTechnique = Objects.isNull(combinationClause)
? ScoreCombinationTechnique.DEFAULT.name()
: (String) combinationClause.getOrDefault(NormalizationProcessor.TECHNIQUE, "");

validateParameters(normalizationTechnique, combinationTechnique, tag);
ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.defaultCombination();
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = (String) combinationClause.getOrDefault(NormalizationProcessor.TECHNIQUE, "");
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
}

validateParameters(normalizationTechnique, tag);

return new NormalizationProcessor(
tag,
description,
ScoreNormalizationTechnique.valueOf(normalizationTechnique),
ScoreCombinationTechnique.valueOf(combinationTechnique),
scoreCombinationTechnique,
normalizationProcessorWorkflow
);
}

protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName, final String tag) {
protected void validateParameters(final String normalizationTechniqueName, final String tag) {
if (StringUtils.isEmpty(normalizationTechniqueName)) {
throw newConfigurationException(
NormalizationProcessor.TYPE,
Expand All @@ -78,14 +83,6 @@ protected void validateParameters(final String normalizationTechniqueName, final
"normalization technique cannot be empty"
);
}
if (StringUtils.isEmpty(combinationTechniqueName)) {
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"combination technique cannot be empty"
);
}
if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) {
throw newConfigurationException(
NormalizationProcessor.TYPE,
Expand All @@ -94,13 +91,5 @@ protected void validateParameters(final String normalizationTechniqueName, final
"provided normalization technique is not supported"
);
}
if (!EnumUtils.isValidEnum(ScoreCombinationTechnique.class, combinationTechniqueName)) {
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"provided combination technique is not supported"
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -561,7 +561,7 @@ protected void createSearchPipelineWithResultsPostProcessor(final String pipelin
createSearchPipeline(
pipelineId,
ScoreNormalizationTechnique.MIN_MAX.name(),
ScoreCombinationTechnique.ARITHMETIC_MEAN.name(),
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
Map.of()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
Expand Down Expand Up @@ -99,7 +100,7 @@ public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() {
PROCESSOR_TAG,
DESCRIPTION,
ScoreNormalizationTechnique.MIN_MAX,
ScoreCombinationTechnique.ARITHMETIC_MEAN,
new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME),
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

Expand All @@ -118,7 +119,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio
PROCESSOR_TAG,
DESCRIPTION,
ScoreNormalizationTechnique.MIN_MAX,
ScoreCombinationTechnique.ARITHMETIC_MEAN,
new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME),
normalizationProcessorWorkflow
);

Expand Down Expand Up @@ -177,7 +178,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl
PROCESSOR_TAG,
DESCRIPTION,
ScoreNormalizationTechnique.MIN_MAX,
ScoreCombinationTechnique.ARITHMETIC_MEAN,
new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME),
normalizationProcessorWorkflow
);
SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
Expand All @@ -194,7 +195,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul
PROCESSOR_TAG,
DESCRIPTION,
ScoreNormalizationTechnique.MIN_MAX,
ScoreCombinationTechnique.ARITHMETIC_MEAN,
new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME),
normalizationProcessorWorkflow
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.opensearch.action.OriginalIndices;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
Expand Down Expand Up @@ -60,7 +60,11 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio
querySearchResults.add(querySearchResult);
}

normalizationProcessorWorkflow.execute(querySearchResults, ScoreNormalizationTechnique.DEFAULT, ScoreCombinationTechnique.DEFAULT);
normalizationProcessorWorkflow.execute(
querySearchResults,
ScoreNormalizationTechnique.DEFAULT,
new ScoreCombinationFactory().defaultCombination()
);

verify(normalizationProcessorWorkflow, times(1)).updateOriginalQueryResults(any(), any(), any());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.test.OpenSearchTestCase;

public class ScoreCombinationMethodTests extends OpenSearchTestCase {
public class ScoreCombinationTechniqueTests extends OpenSearchTestCase {

public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() {
ScoreCombiner scoreCombiner = new ScoreCombiner();
List<Float> maxScores = scoreCombiner.combineScores(List.of(), ScoreCombinationTechnique.ARITHMETIC_MEAN);
List<Float> maxScores = scoreCombiner.combineScores(List.of(), new ScoreCombinationFactory().defaultCombination());
assertNotNull(maxScores);
assertEquals(0, maxScores.size());
}
Expand Down Expand Up @@ -60,7 +60,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc
)
);

List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, ScoreCombinationTechnique.ARITHMETIC_MEAN);
List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, new ScoreCombinationFactory().defaultCombination());

assertNotNull(queryTopDocs);
assertEquals(3, queryTopDocs.size());
Expand Down
Loading