Skip to content

Commit

Permalink
Add harmonic mean combination (#238)
Browse files Browse the repository at this point in the history
* Add harmonic mean combination

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Aug 1, 2023
1 parent 6ad641a commit 1f67b94
Show file tree
Hide file tree
Showing 12 changed files with 412 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
Expand All @@ -23,20 +19,12 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
Expand All @@ -48,57 +36,19 @@ private List<Float> getWeights(final Map<String, Object> params) {
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
float weights = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
score = score * weight;
combinedScore += score;
weights += weight;
sumOfWeights += weight;
}
}
if (weights == 0.0f) {
if (sumOfWeights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / weights;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
return combinedScore / sumOfWeights;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,99 +6,46 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
* Abstracts combination of scores based on harmonic mean method
*/
public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String TECHNIQUE_NAME = "harmonic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
* Arithmetic mean method for combining scores.
* score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
* Weighted harmonic mean method for combining scores.
* score = sum(weight_1 + ... + weight_n)/sum(weight_1/score_1 + ... + weight_n/score_n)
*
* Zero (0.0) and negative scores are skipped and do not contribute to either sum
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
float weights = 0;
float sumOfWeights = 0;
float sumOfHarmonics = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
weights += weight;
if (score <= 0) {
continue;
}
float weightOfSubQuery = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
sumOfWeights += weightOfSubQuery;
sumOfHarmonics += weightOfSubQuery / score;
}
if (weights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / weights;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {
private static final ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(
Map.of(),
scoreCombinationUtil
);

private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
ArithmeticMeanScoreCombinationTechnique::new
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.combination;

public interface ScoreCombinationTechnique {

/**
* Defines combination function specific to this technique
* @param scores array of collected original scores
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Collection of utility methods for score combination technique classes.
 * The class holds no state; all methods operate only on their arguments.
 */
class ScoreCombinationUtil {
    private static final String PARAM_NAME_WEIGHTS = "weights";
    // weight applied to a sub-query when no explicit weight is configured for its position
    private static final float DEFAULT_WEIGHT = 1.0f;

    /**
     * Get collection of weights based on user provided config.
     * Any {@link Number} element is accepted (e.g. a list written as [1, 0.5] that mixes integers and doubles)
     * and converted to float.
     * @param params map of named parameters and their values, can be null or empty
     * @return unmodifiable collection of weights, empty if params are null/empty or weights are not set
     * @throws IllegalArgumentException if weights value is not a list or contains a non-numeric element
     */
    public List<Float> getWeights(final Map<String, Object> params) {
        if (Objects.isNull(params) || params.isEmpty()) {
            return List.of();
        }
        final Object rawWeights = params.get(PARAM_NAME_WEIGHTS);
        // key is absent or explicitly mapped to null, both mean "no weights configured"
        if (Objects.isNull(rawWeights)) {
            return List.of();
        }
        // validateParams normally rejects this case earlier; re-check so a direct call never fails on a blind cast
        if (!(rawWeights instanceof List)) {
            throw new IllegalArgumentException(weightsTypeErrorMessage());
        }
        return ((List<?>) rawWeights).stream().map(ScoreCombinationUtil::toFloat).collect(Collectors.toUnmodifiableList());
    }

    /**
     * Validate config parameters for this technique
     * @param actualParams map of parameters in form of name-value, null or empty map is considered valid
     * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique
     * @throws IllegalArgumentException if unsupported parameter is passed or weights parameter is not a list
     */
    public void validateParams(final Map<String, Object> actualParams, final Set<String> supportedParams) {
        if (Objects.isNull(actualParams) || actualParams.isEmpty()) {
            return;
        }
        // check if only supported params are passed
        Optional<String> optionalNotSupportedParam = actualParams.keySet()
            .stream()
            .filter(paramName -> !supportedParams.contains(paramName))
            .findFirst();
        if (optionalNotSupportedParam.isPresent()) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "provided parameter for combination technique is not supported. supported parameters are [%s]",
                    supportedParams.stream().collect(Collectors.joining(","))
                )
            );
        }

        // check param types: key is matched ignoring case, but value is looked up by exact name,
        // so a differently-cased key resolves to null and is rejected as well
        if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
            if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) {
                throw new IllegalArgumentException(weightsTypeErrorMessage());
            }
        }
    }

    /**
     * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
     * @param weights collection of weights for sub-queries
     * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
     * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
     */
    public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
        return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : DEFAULT_WEIGHT;
    }

    /**
     * Convert single element of weights list to float
     * @param weight element of the user provided weights list
     * @return float value of the weight
     * @throws IllegalArgumentException if element is null or not a number
     */
    private static Float toFloat(final Object weight) {
        if (!(weight instanceof Number)) {
            throw new IllegalArgumentException(weightsTypeErrorMessage());
        }
        return ((Number) weight).floatValue();
    }

    /**
     * Build the message used whenever weights parameter has a wrong type
     * @return error message that names the weights parameter
     */
    private static String weightsTypeErrorMessage() {
        return String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";
protected static final String NORMALIZATION_METHOD = "min_max";
protected static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max";
protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean";
protected static final String PARAM_NAME_WEIGHTS = "weights";

protected final ClassLoader classLoader = this.getClass().getClassLoader();
Expand Down Expand Up @@ -556,7 +556,7 @@ public boolean isUpdateClusterSettings() {

@SneakyThrows
protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) {
createSearchPipeline(pipelineId, NORMALIZATION_METHOD, COMBINATION_METHOD, Map.of());
createSearchPipeline(pipelineId, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of());
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchProgressListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.breaker.CircuitBreaker;
import org.opensearch.common.breaker.NoopCircuitBreaker;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.TestUtils;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
Expand Down
Loading

0 comments on commit 1f67b94

Please sign in to comment.