diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 8ac7c63be..4eb9564e6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -41,27 +41,27 @@ private List getWeights(final Map params) { /** * Arithmetic mean method for combining scores. - * cscore = (score1 + score2 +...+ scoreN)/N + * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN) * * Zero (0.0) scores are excluded from number of scores N */ @Override public float combine(final float[] scores) { float combinedScore = 0.0f; - int count = 0; + float weights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; if (score >= 0.0) { float weight = getWeightForSubQuery(indexOfSubQuery); score = score * weight; combinedScore += score; - count++; + weights += weight; } } - if (count == 0) { + if (weights == 0.0f) { return ZERO_SCORE; } - return combinedScore / count; + return combinedScore / weights; } private void validateParams(final Map params) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java index 795ed1e5f..126275a47 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -252,7 +252,7 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf } @SneakyThrows - public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { + public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); // check case when number of weights and sub-queries are same createSearchPipeline( @@ -275,7 +275,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001); + assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); // delete existing pipeline and create a new one with another set of weights deleteSearchPipeline(SEARCH_PIPELINE); @@ -294,7 +294,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001); + assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); // check case when number of weights is less than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -314,7 +314,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001); + assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); // check case when number of weights is more than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -334,7 +334,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001); + assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); } private void initializeIndexIfNotExist(String indexName) throws IOException {