Skip to content

Commit

Permalink
Adding weights param for combination technique
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 27, 2023
1 parent 2224f1f commit 96d0b77
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,34 @@

package org.opensearch.neuralsearch.processor.combination;

import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Float ZERO_SCORE = 0.0f;
private final List<Double> weights;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
if (Objects.isNull(params) || params.isEmpty()) {
weights = List.of();
return;
}
// get weights, we don't need to check for instance as it's done during validation
weights = (List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, new ArrayList<>());
}

/**
* Arithmetic mean method for combining scores.
Expand All @@ -26,8 +44,13 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (float score : scores) {
for (int i = 0; i < scores.length; i++) {
float score = scores[i];
if (score >= 0.0) {
// apply weight for this sub-query if it's set for particular sub-query
if (i < weights.size()) {
score = (float) (score * weights.get(i));

Check warning on line 52 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L52

Added line #L52 was not covered by tests
}
combinedScore += score;
count++;
}
Expand All @@ -37,4 +60,26 @@ public float combine(final float[] scores) {
}
return combinedScore / count;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Set<String> supportedParams = Set.of(PARAM_NAME_WEIGHTS);
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !supportedParams.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new OpenSearchParseException("provided parameter for combination technique is not supported");

Check warning on line 75 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L75

Added line #L75 was not covered by tests
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new OpenSearchParseException("parameter {} must be a collection of numbers", PARAM_NAME_WEIGHTS);

Check warning on line 81 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L81

Added line #L81 was not covered by tests
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.opensearch.OpenSearchParseException;

Expand All @@ -15,11 +16,11 @@
*/
public class ScoreCombinationFactory {

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique();
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
new ArithmeticMeanScoreCombinationTechnique()
ArithmeticMeanScoreCombinationTechnique::new
);

/**
Expand All @@ -28,7 +29,17 @@ public class ScoreCombinationFactory {
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return createCombination(technique, Map.of());
}

/**
* Get score combination method by technique name
* @param technique name of technique
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"))
.apply(params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;

import java.util.Map;
import java.util.Objects;
Expand All @@ -29,6 +30,7 @@ public class NormalizationProcessorFactory implements Processor.Factory<SearchPh
public static final String NORMALIZATION_CLAUSE = "normalization";
public static final String COMBINATION_CLAUSE = "combination";
public static final String TECHNIQUE = "technique";
public static final String PARAMETERS = "parameters";

private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreNormalizationFactory scoreNormalizationFactory;
Expand All @@ -53,9 +55,12 @@ public SearchPhaseResultsProcessor create(
Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
Map<String, Object> combinationParams;
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, "");
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE);
// check for optional combination params
combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -57,8 +58,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";
private static final String NORMALIZATION_METHOD = "min_max";
private static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String NORMALIZATION_METHOD = "min_max";
protected static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String PARAM_NAME_WEIGHTS = "weights";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

Expand Down Expand Up @@ -566,30 +568,31 @@ protected void createSearchPipeline(
final String pipelineId,
final String normalizationMethod,
String combinationMethod,
final Map<String, Object> combinationParams
final Map<String, String> combinationParams
) {
StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",")
.append("\"phase_results_processors\": [{ ")
.append("\"normalization-processor\": {")
.append("\"normalization\": {")
.append("\"technique\": \"%s\"")
.append("},")
.append("\"combination\": {")
.append("\"technique\": \"%s\"");
if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) {
stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS));
}
stringBuilderForContentBody.append(" }");
}
stringBuilderForContentBody.append("}").append("}}]}");
makeRequest(
client(),
"PUT",
String.format(LOCALE, "/_search/pipeline/%s", pipelineId),
null,
toHttpEntity(
String.format(
LOCALE,
"{\"description\": \"Post processor pipeline\","
+ "\"phase_results_processors\": [{ "
+ "\"normalization-processor\": {"
+ "\"normalization\": {"
+ "\"technique\": \"%s\""
+ "},"
+ "\"combination\": {"
+ "\"technique\": \"%s\""
+ "}"
+ "}}]}",
normalizationMethod,
combinationMethod
)
),
toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -250,6 +251,92 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
assertQueryResults(searchResponseAsMap, 4, true);
}

@SneakyThrows
public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
// check case when number of weights and sub-queries are same
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));

Map<String, Object> searchResponseWithWeights1AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -402,4 +489,31 @@ private void assertQueryResults(Map<String, Object> searchResponseAsMap, int tot
// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private void assertWeightedScores(
Map<String, Object> searchResponseWithWeightsAsMap,
double expectedMaxScore,
double expectedMaxMinusOneScore,
double expectedMinScore
) {
assertNotNull(searchResponseWithWeightsAsMap);
Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
assertNotNull(totalWeights.get("value"));
assertEquals(4, totalWeights.get("value"));
assertNotNull(totalWeights.get("relation"));
assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);

List<Double> scoresWeights = new ArrayList<>();
for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
scoresWeights.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
// verify the scores are normalized with inclusion of weights
assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
}
}
Loading

0 comments on commit 96d0b77

Please sign in to comment.