Skip to content

Commit

Permalink
Add geometric mean combination technique for scores
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 1, 2023
1 parent 1f67b94 commit 6e3b996
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Score combination technique that merges the scores of individual sub-queries
 * into a single score using a weighted geometric mean.
 */
public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

    public static final String TECHNIQUE_NAME = "geometric_mean";
    public static final String PARAM_NAME_WEIGHTS = "weights";
    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
    // Primitive on purpose: a boxed Float here would be unboxed on every combine() call.
    private static final float ZERO_SCORE = 0.0f;
    private final List<Float> weights;
    private final ScoreCombinationUtil scoreCombinationUtil;

    /**
     * Creates the technique from user-supplied parameters.
     *
     * @param params          technique parameters; {@value #PARAM_NAME_WEIGHTS} is the only name listed as supported
     * @param combinationUtil helper used to validate {@code params} and to resolve weights
     */
    public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
        scoreCombinationUtil = combinationUtil;
        // Validation runs before the weights are read from the parameters.
        scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
        weights = scoreCombinationUtil.getWeights(params);
    }

    /**
     * Weighted geometric mean method for combining scores.
     *
     * The logarithm of a weighted geometric mean equals the weighted arithmetic mean of the
     * logarithms of the individual scores, so the result is computed as:
     *
     * geometric_mean = exp((weight_1*ln(score_1) + ... + weight_n*ln(score_n)) / (weight_1 + ... + weight_n))
     *
     * Scores that are zero or negative are left out of both the numerator and the sum of weights.
     *
     * @param scores scores to combine; the position in the array is used as the sub-query index
     *               when the weight for that score is looked up
     * @return the weighted geometric mean of the positive scores, or 0.0 when the weights of the
     *         contributing scores sum to zero (which includes the case where every score was skipped)
     */
    @Override
    public float combine(final float[] scores) {
        float weightedLnSum = 0;
        float sumOfWeights = 0;
        for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
            float score = scores[indexOfSubQuery];
            if (score <= 0) {
                // ln() is undefined for 0 and for negative values, so such scores do not take part in the mean.
                // NOTE(review): a NaN score is not caught by this check and would propagate into the result.
                continue;
            }
            float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
            sumOfWeights += weight;
            // Math.log returns double; the compound assignment narrows the running sum back to float.
            weightedLnSum += weight * Math.log(score);
        }
        return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public class ScoreCombinationFactory {
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,64 @@ public void testL2NormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSu
assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange);
}

/**
 * Runs a hybrid (neural + term) query on the one-shard multi-document index through a search
 * pipeline configured with the default normalization method and geometric mean combination,
 * then checks the number of returned documents and the range their scores fall into.
 */
@SneakyThrows
public void testMinMaxNormGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() {
    // The index and the pipeline are set up before any query object is built.
    initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
    final String subQueryWeights = Arrays.toString(new float[] { 0.8f, 0.7f });
    createSearchPipeline(
        SEARCH_PIPELINE,
        DEFAULT_NORMALIZATION_METHOD,
        GEOMETRIC_MEAN_COMBINATION_METHOD,
        Map.of(PARAM_NAME_WEIGHTS, subQueryWeights)
    );

    // Sub-queries are added in the same order as the weights above: neural first, term second.
    final NeuralQueryBuilder neuralSubQuery = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null);
    final TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    final HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
    hybridQuery.add(neuralSubQuery);
    hybridQuery.add(termSubQuery);

    final Map<String, Object> pipelineRequestParams = Map.of("search_pipeline", SEARCH_PIPELINE);
    final Map<String, Object> response = search(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, hybridQuery, null, 5, pipelineRequestParams);

    final int expectedDocCount = 5;
    final float[] expectedScoreBounds = { 0.6f, 1.0f };
    assertHybridSearchResults(response, expectedDocCount, expectedScoreBounds);
}

/**
 * Runs a hybrid (neural + term) query on the one-shard multi-document index through a search
 * pipeline configured with L2 normalization and geometric mean combination, then checks the
 * number of returned documents and the range their scores fall into.
 */
@SneakyThrows
public void testL2NormGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() {
    // The index and the pipeline are set up before any query object is built.
    initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
    final String subQueryWeights = Arrays.toString(new float[] { 0.8f, 0.7f });
    createSearchPipeline(
        SEARCH_PIPELINE,
        L2_NORMALIZATION_METHOD,
        GEOMETRIC_MEAN_COMBINATION_METHOD,
        Map.of(PARAM_NAME_WEIGHTS, subQueryWeights)
    );

    // Sub-queries are added in the same order as the weights above: neural first, term second.
    final NeuralQueryBuilder neuralSubQuery = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null);
    final TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    final HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
    hybridQuery.add(neuralSubQuery);
    hybridQuery.add(termSubQuery);

    final Map<String, Object> pipelineRequestParams = Map.of("search_pipeline", SEARCH_PIPELINE);
    final Map<String, Object> response = search(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, hybridQuery, null, 5, pipelineRequestParams);

    final int expectedDocCount = 5;
    final float[] expectedScoreBounds = { 0.5f, 1.0f };
    assertHybridSearchResults(response, expectedDocCount, expectedScoreBounds);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@

public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

public ArithmeticMeanScoreCombinationTechniqueTests() {
this.expectedScoreFunction = this::arithmeticMean;
}

public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
scoreCombinationUtil
);
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}
Expand All @@ -39,7 +41,7 @@ public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
scoreCombinationUtil
);
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS;

import java.util.List;
import java.util.Map;

/**
 * Unit tests for {@link GeometricMeanScoreCombinationTechnique}. Each test builds the technique
 * and hands it to the scenario of the same name inherited from the base test class; the expected
 * values come from the reference implementation {@link #geometricMean(List, List)}.
 */
public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

    private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

    public GeometricMeanScoreCombinationTechniqueTests() {
        // Reference function the inherited scenarios use to compute the expected score.
        this.expectedScoreFunction = this::geometricMean;
    }

    public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
        ScoreCombinationTechnique geometricMeanTechnique = techniqueWith(Map.of());
        testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(geometricMeanTechnique);
    }

    public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
        ScoreCombinationTechnique geometricMeanTechnique = techniqueWith(Map.of());
        testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(geometricMeanTechnique);
    }

    public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
        List<Double> subQueryWeights = List.of(0.9, 0.2, 0.7);
        ScoreCombinationTechnique geometricMeanTechnique = techniqueWith(Map.of(PARAM_NAME_WEIGHTS, subQueryWeights));
        testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(geometricMeanTechnique, subQueryWeights);
    }

    public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
        List<Double> subQueryWeights = List.of(0.9, 0.2, 0.7);
        ScoreCombinationTechnique geometricMeanTechnique = techniqueWith(Map.of(PARAM_NAME_WEIGHTS, subQueryWeights));
        testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(geometricMeanTechnique, subQueryWeights);
    }

    // Builds the technique under test from the given parameters and the shared util instance.
    private ScoreCombinationTechnique techniqueWith(Map<String, Object> params) {
        return new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil);
    }

    /**
     * Reference weighted geometric mean: exp of the weighted mean of ln(score) over the scores
     * that are greater than zero. Returns 0 when the weights of those scores sum to zero.
     */
    private float geometricMean(List<Float> scores, List<Double> weights) {
        assertEquals(scores.size(), weights.size());
        float weightTotal = 0;
        float weightedLnTotal = 0;
        for (int position = 0; position < scores.size(); position++) {
            float currentScore = scores.get(position);
            float currentWeight = weights.get(position).floatValue();
            // Written as a negated "> 0" so that every value the positive check rejects is skipped.
            if (!(currentScore > 0)) {
                continue;
            }
            weightTotal += currentWeight;
            weightedLnTotal += currentWeight * Math.log(currentScore);
        }
        if (weightTotal == 0) {
            return 0f;
        }
        return (float) Math.exp(weightedLnTotal / weightTotal);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@

public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

public HarmonicMeanScoreCombinationTechniqueTests() {
this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights);
}

public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
scoreCombinationUtil
);
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}
Expand All @@ -39,7 +41,7 @@ public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
scoreCombinationUtil
);
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstanc
assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique);
}

/**
 * Asks the factory for the technique named "geometric_mean" and checks that a
 * {@link GeometricMeanScoreCombinationTechnique} instance comes back.
 */
public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstance() {
    // The name is spelled out as a literal so the test pins the exact string the factory accepts.
    ScoreCombinationTechnique createdTechnique = new ScoreCombinationFactory().createCombination("geometric_mean");

    assertNotNull(createdTechnique);
    assertTrue(createdTechnique instanceof GeometricMeanScoreCombinationTechnique);
}

public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() {
ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
IllegalArgumentException illegalArgumentException = expectThrows(
Expand Down

0 comments on commit 6e3b996

Please sign in to comment.