Skip to content

Commit

Permalink
Adding weights param for combination technique (#235)
Browse files Browse the repository at this point in the history
* Adding weights param for combination technique

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 3, 2023
1 parent fdec5fa commit 65b6c8c
Show file tree
Hide file tree
Showing 10 changed files with 704 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,100 @@

package org.opensearch.neuralsearch.processor.combination;

import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
}

/**
* Arithmetic mean method for combining scores.
* cscore = (score1 + score2 +...+ scoreN)/N
* score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
*
* Zero (0.0) scores are excluded from number of scores N
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (float score : scores) {
float weights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
count++;
weights += weight;
}
}
if (count == 0) {
if (weights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / count;
return combinedScore / weights;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;
import java.util.function.Function;

/**
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique();
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
new ArithmeticMeanScoreCombinationTechnique()
ArithmeticMeanScoreCombinationTechnique::new
);

/**
Expand All @@ -28,7 +27,18 @@ public class ScoreCombinationFactory {
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return createCombination(technique, Map.of());
}

/**
* Get score combination method by technique name
* @param technique name of technique
* @param params parameters that combination technique may use
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
.orElseThrow(() -> new IllegalArgumentException("provided combination technique is not supported"))
.apply(params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;

import java.util.Map;
import java.util.Objects;
Expand All @@ -14,8 +15,10 @@

import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.pipeline.Processor;
Expand All @@ -29,6 +32,7 @@ public class NormalizationProcessorFactory implements Processor.Factory<SearchPh
public static final String NORMALIZATION_CLAUSE = "normalization";
public static final String COMBINATION_CLAUSE = "combination";
public static final String TECHNIQUE = "technique";
public static final String PARAMETERS = "parameters";

private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreNormalizationFactory scoreNormalizationFactory;
Expand All @@ -46,16 +50,30 @@ public SearchPhaseResultsProcessor create(
Map<String, Object> normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE);
ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD;
if (Objects.nonNull(normalizationClause)) {
String normalizationTechniqueName = (String) normalizationClause.getOrDefault(TECHNIQUE, "");
String normalizationTechniqueName = readStringProperty(
NormalizationProcessor.TYPE,
tag,
normalizationClause,
TECHNIQUE,
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME
);
normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName);
}

Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, "");
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
String combinationTechnique = readStringProperty(
NormalizationProcessor.TYPE,
tag,
combinationClause,
TECHNIQUE,
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME
);
// check for optional combination params
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
*/
public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique {

protected static final String TECHNIQUE_NAME = "min_max";
public static final String TECHNIQUE_NAME = "min_max";
private static final float MIN_SCORE = 0.001f;
private static final float SINGLE_RESULT_SCORE = 1.0f;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts creation of exact score normalization method based on technique name
*/
Expand All @@ -29,6 +27,6 @@ public class ScoreNormalizationFactory {
*/
public ScoreNormalizationTechnique createNormalization(final String technique) {
return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided normalization technique is not supported"));
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -57,8 +58,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";
private static final String NORMALIZATION_METHOD = "min_max";
private static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String NORMALIZATION_METHOD = "min_max";
protected static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String PARAM_NAME_WEIGHTS = "weights";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

Expand Down Expand Up @@ -562,30 +564,31 @@ protected void createSearchPipeline(
final String pipelineId,
final String normalizationMethod,
String combinationMethod,
final Map<String, Object> combinationParams
final Map<String, String> combinationParams
) {
StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",")
.append("\"phase_results_processors\": [{ ")
.append("\"normalization-processor\": {")
.append("\"normalization\": {")
.append("\"technique\": \"%s\"")
.append("},")
.append("\"combination\": {")
.append("\"technique\": \"%s\"");
if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) {
stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS));
}
stringBuilderForContentBody.append(" }");
}
stringBuilderForContentBody.append("}").append("}}]}");
makeRequest(
client(),
"PUT",
String.format(LOCALE, "/_search/pipeline/%s", pipelineId),
null,
toHttpEntity(
String.format(
LOCALE,
"{\"description\": \"Post processor pipeline\","
+ "\"phase_results_processors\": [{ "
+ "\"normalization-processor\": {"
+ "\"normalization\": {"
+ "\"technique\": \"%s\""
+ "},"
+ "\"combination\": {"
+ "\"technique\": \"%s\""
+ "}"
+ "}}]}",
normalizationMethod,
combinationMethod
)
),
toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}
Expand Down
Loading

0 comments on commit 65b6c8c

Please sign in to comment.