Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding weights param for combination technique #235

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,34 @@

package org.opensearch.neuralsearch.processor.combination;

import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Float ZERO_SCORE = 0.0f;
private final List<Double> weights;
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
if (Objects.isNull(params) || params.isEmpty()) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
weights = List.of();
return;
}
// get weights, we don't need to check for instance as it's done during validation
weights = (List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, new ArrayList<>());
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Arithmetic mean method for combining scores.
Expand All @@ -26,8 +44,13 @@
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (float score : scores) {
for (int i = 0; i < scores.length; i++) {
float score = scores[i];
if (score >= 0.0) {
// apply weight for this sub-query if it's set for particular sub-query
if (i < weights.size()) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
score = (float) (score * weights.get(i));

Check warning on line 52 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L52

Added line #L52 was not covered by tests
}
combinedScore += score;
count++;
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
}
Expand All @@ -37,4 +60,26 @@
}
return combinedScore / count;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Set<String> supportedParams = Set.of(PARAM_NAME_WEIGHTS);
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !supportedParams.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new OpenSearchParseException("provided parameter for combination technique is not supported");

Check warning on line 75 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L75

Added line #L75 was not covered by tests
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new OpenSearchParseException("parameter {} must be a collection of numbers", PARAM_NAME_WEIGHTS);

Check warning on line 81 in src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java#L81

Added line #L81 was not covered by tests
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.opensearch.OpenSearchParseException;

Expand All @@ -15,11 +16,11 @@
*/
public class ScoreCombinationFactory {

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique();
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
new ArithmeticMeanScoreCombinationTechnique()
ArithmeticMeanScoreCombinationTechnique::new
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
);

/**
Expand All @@ -28,7 +29,17 @@ public class ScoreCombinationFactory {
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return createCombination(technique, Map.of());
}

/**
* Get score combination method by technique name
* @param technique name of technique
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"))
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
.apply(params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;

import java.util.Map;
import java.util.Objects;
Expand All @@ -29,6 +30,7 @@ public class NormalizationProcessorFactory implements Processor.Factory<SearchPh
public static final String NORMALIZATION_CLAUSE = "normalization";
public static final String COMBINATION_CLAUSE = "combination";
public static final String TECHNIQUE = "technique";
public static final String PARAMETERS = "parameters";

private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreNormalizationFactory scoreNormalizationFactory;
Expand All @@ -53,9 +55,12 @@ public SearchPhaseResultsProcessor create(
Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
Map<String, Object> combinationParams;
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, "");
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE);
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
// check for optional combination params
combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -57,8 +58,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";
private static final String NORMALIZATION_METHOD = "min_max";
private static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String NORMALIZATION_METHOD = "min_max";
protected static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String PARAM_NAME_WEIGHTS = "weights";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

Expand Down Expand Up @@ -566,30 +568,31 @@ protected void createSearchPipeline(
final String pipelineId,
final String normalizationMethod,
String combinationMethod,
final Map<String, Object> combinationParams
final Map<String, String> combinationParams
) {
StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",")
.append("\"phase_results_processors\": [{ ")
.append("\"normalization-processor\": {")
.append("\"normalization\": {")
.append("\"technique\": \"%s\"")
.append("},")
.append("\"combination\": {")
.append("\"technique\": \"%s\"");
if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) {
stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS));
}
stringBuilderForContentBody.append(" }");
}
stringBuilderForContentBody.append("}").append("}}]}");
makeRequest(
client(),
"PUT",
String.format(LOCALE, "/_search/pipeline/%s", pipelineId),
null,
toHttpEntity(
String.format(
LOCALE,
"{\"description\": \"Post processor pipeline\","
+ "\"phase_results_processors\": [{ "
+ "\"normalization-processor\": {"
+ "\"normalization\": {"
+ "\"technique\": \"%s\""
+ "},"
+ "\"combination\": {"
+ "\"technique\": \"%s\""
+ "}"
+ "}}]}",
normalizationMethod,
combinationMethod
)
),
toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -250,6 +251,92 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
assertQueryResults(searchResponseAsMap, 4, true);
}

@SneakyThrows
public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
// check case when number of weights and sub-queries are same
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));

Map<String, Object> searchResponseWithWeights1AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -402,4 +489,31 @@ private void assertQueryResults(Map<String, Object> searchResponseAsMap, int tot
// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private void assertWeightedScores(
Map<String, Object> searchResponseWithWeightsAsMap,
double expectedMaxScore,
double expectedMaxMinusOneScore,
double expectedMinScore
) {
assertNotNull(searchResponseWithWeightsAsMap);
Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
assertNotNull(totalWeights.get("value"));
assertEquals(4, totalWeights.get("value"));
assertNotNull(totalWeights.get("relation"));
assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);

List<Double> scoresWeights = new ArrayList<>();
for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
scoresWeights.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
// verify the scores are normalized with inclusion of weights
assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
}
}
Loading
Loading