Skip to content

Commit

Permalink
Use weights instead of count for arithmetic mean
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 27, 2023
1 parent 376af08 commit 0c36569
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,27 @@ private List<Float> getWeights(final Map<String, Object> params) {

/**
* Arithmetic mean method for combining scores.
* cscore = (score1 + score2 +...+ scoreN)/N
* score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
*
* Zero (0.0) scores are excluded from number of scores N
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
float weights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
count++;
weights += weight;
}
}
if (count == 0) {
if (weights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / count;
return combinedScore / weights;
}

private void validateParams(final Map<String, Object> params) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
}

@SneakyThrows
public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
// check case when number of weights and sub-queries are same
createSearchPipeline(
Expand All @@ -275,7 +275,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001);
assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
Expand All @@ -294,7 +294,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001);
assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -314,7 +314,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001);
assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -334,7 +334,7 @@ public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001);
assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
Expand Down

0 comments on commit 0c36569

Please sign in to comment.