diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java new file mode 100644 index 000000000..16850f04f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Base class for score combination techniques + */ +public abstract class AbstractScoreCombinationTechnique { + private static final String PARAM_NAME_WEIGHTS = "weights"; + + /** + * Each technique must provide collection of supported parameters + * @return set of supported parameter names + */ + abstract Set getSupportedParams(); + + /** + * Get collection of weights based on user provided config + * @param params map of named parameters and their values + * @return collection of weights + */ + protected List getWeights(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return List.of(); + } + // get weights, we don't need to check for instance as it's done during validation + return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() + .map(Double::floatValue) + .collect(Collectors.toUnmodifiableList()); + } + + /** + * Validate config parameters for this technique + * @param params map of parameters in form of name-value + */ + protected void validateParams(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = params.keySet() + .stream() + .filter(paramName -> 
!getSupportedParams().contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. supported parameters are [%s]", + getSupportedParams().stream().collect(Collectors.joining(",")) + ) + ); + } + + // check param types + if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } + + /** + * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise + * @param weights collection of weights for sub-queries + * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query + * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default + */ + protected float getWeightForSubQuery(final List weights, final int indexOfSubQuery) { + return indexOfSubQuery < weights.size() ? 
weights.get(indexOfSubQuery) : 1.0f; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 4eb9564e6..c679608d6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -6,17 +6,13 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Abstracts combination of scores based on arithmetic mean method */ -public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class ArithmeticMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique { public static final String TECHNIQUE_NAME = "arithmetic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; @@ -29,16 +25,6 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params) weights = getWeights(params); } - private List getWeights(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return List.of(); - } - // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() - .map(Double::floatValue) - .collect(Collectors.toUnmodifiableList()); - } - /** * Arithmetic mean method for combining scores. * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... 
+ weightN) @@ -52,7 +38,7 @@ public float combine(final float[] scores) { for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; if (score >= 0.0) { - float weight = getWeightForSubQuery(indexOfSubQuery); + float weight = getWeightForSubQuery(this.weights, indexOfSubQuery); score = score * weight; combinedScore += score; weights += weight; @@ -64,41 +50,8 @@ public float combine(final float[] scores) { return combinedScore / weights; } - private void validateParams(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return; - } - // check if only supported params are passed - Optional optionalNotSupportedParam = params.keySet() - .stream() - .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName)) - .findFirst(); - if (optionalNotSupportedParam.isPresent()) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "provided parameter for combination technique is not supported. supported parameters are [%s]", - SUPPORTED_PARAMS.stream().collect(Collectors.joining(",")) - ) - ); - } - - // check param types - if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { - if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) - ); - } - } - } - - /** - * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise - * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query - * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default - */ - private float getWeightForSubQuery(int indexOfSubQuery) { - return indexOfSubQuery < weights.size() ? 
weights.get(indexOfSubQuery) : 1.0f; + @Override + Set getSupportedParams() { + return SUPPORTED_PARAMS; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 3fff2db2b..69fd12d77 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -6,19 +6,15 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** - * Abstracts combination of scores based on arithmetic mean method + * Abstracts combination of scores based on harmonic mean method */ -public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class HarmonicMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique { - public static final String TECHNIQUE_NAME = "arithmetic_mean"; + public static final String TECHNIQUE_NAME = "harmonic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; @@ -29,76 +25,30 @@ public HarmonicMeanScoreCombinationTechnique(final Map params) { weights = getWeights(params); } - private List getWeights(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return List.of(); - } - // get weights, we don't need to check for instance as it's done during validation - return ((List) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() - .map(Double::floatValue) - 
.collect(Collectors.toUnmodifiableList()); - } - /** - * Arithmetic mean method for combining scores. - * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN) + * Weighted harmonic mean method for combining scores. + * score = sum(weight_1 + .... + weight_n)/sum(weight_1/score_1 + ... + weight_n/score_n) * * Zero (0.0) scores are excluded from number of scores N */ @Override public float combine(final float[] scores) { - float combinedScore = 0.0f; - float weights = 0; + float sumOfWeights = 0; + float sumOfHarmonics = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; - if (score >= 0.0) { - float weight = getWeightForSubQuery(indexOfSubQuery); - score = score * weight; - combinedScore += score; - weights += weight; - } - } - if (weights == 0.0f) { - return ZERO_SCORE; - } - return combinedScore / weights; - } - - private void validateParams(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { - return; - } - // check if only supported params are passed - Optional optionalNotSupportedParam = params.keySet() - .stream() - .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName)) - .findFirst(); - if (optionalNotSupportedParam.isPresent()) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "provided parameter for combination technique is not supported. 
supported parameters are [%s]", - SUPPORTED_PARAMS.stream().collect(Collectors.joining(",")) - ) - ); - } - - // check param types - if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { - if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) - ); + if (score <= 0) { + continue; } + float weightOfSubQuery = getWeightForSubQuery(weights, indexOfSubQuery); + sumOfWeights += weightOfSubQuery; + sumOfHarmonics += weightOfSubQuery / score; } + return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE; } - /** - * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise - * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query - * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default - */ - private float getWeightForSubQuery(int indexOfSubQuery) { - return indexOfSubQuery < weights.size() ? 
weights.get(indexOfSubQuery) : 1.0f; + @Override + Set getSupportedParams() { + return SUPPORTED_PARAMS; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 2f4804eb1..1195b7004 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -18,7 +18,9 @@ public class ScoreCombinationFactory { private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, - ArithmeticMeanScoreCombinationTechnique::new + ArithmeticMeanScoreCombinationTechnique::new, + HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, + HarmonicMeanScoreCombinationTechnique::new ); /** diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 6aa9b5a5a..add0205b0 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -58,8 +58,8 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; private static final String DEFAULT_USER_AGENT = "Kibana"; - protected static final String NORMALIZATION_METHOD = "min_max"; - protected static final String COMBINATION_METHOD = "arithmetic_mean"; + protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; + protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -556,7 +556,7 @@ public boolean 
isUpdateClusterSettings() { @SneakyThrows protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) { - createSearchPipeline(pipelineId, NORMALIZATION_METHOD, COMBINATION_METHOD, Map.of()); + createSearchPipeline(pipelineId, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java index 273dbd522..54271d042 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; -import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -59,6 +58,10 @@ public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { private final static String RELATION_EQUAL_TO = "eq"; private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + @Before public void setUp() throws Exception { super.setUp(); @@ -276,8 +279,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { // check case when number of weights and sub-queries are same createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new 
float[] { 0.6f, 0.5f, 0.5f })) ); @@ -300,8 +303,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) ); @@ -320,8 +323,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) ); @@ -340,8 +343,8 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { deleteSearchPipeline(SEARCH_PIPELINE); createSearchPipeline( SEARCH_PIPELINE, - NORMALIZATION_METHOD, - COMBINATION_METHOD, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) ); @@ -379,8 +382,8 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); createSearchPipeline( SEARCH_PIPELINE, - L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - COMBINATION_METHOD, + L2_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) ); @@ -399,29 +402,66 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); int totalExpectedDocQty = 5; - assertNotNull(searchResponseAsMap); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(totalExpectedDocQty, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - 
assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(Range.between(.6f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); + } - List> hitsNestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - // verify scores order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range - assertTrue(Range.between(.6f, 1.0f).contains((float) scores.stream().map(Double::floatValue).max(Double::compare).get())); + @SneakyThrows + public void testMinMaxNormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + int totalExpectedDocQty = 5; + float[] 
minMaxExpectedScoresRange = { 0.6f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); + } + + @SneakyThrows + public void testL2NormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + int totalExpectedDocQty = 5; + float[] minMaxExpectedScoresRange = { 0.5f, 1.0f }; + assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); } private void initializeIndexIfNotExist(String indexName) throws IOException { @@ -603,4 +643,33 @@ private void assertWeightedScores( assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); } + + private void assertHybridSearchResults(Map searchResponseAsMap, int totalExpectedDocQty, float[] minMaxScoreRange) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + 
assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range + assertTrue( + Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) + .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) + ); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index d9f63d291..79f036fd8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -10,41 +10,46 @@ import java.util.List; import java.util.Map; -import org.opensearch.test.OpenSearchTestCase; +public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { -public class ArithmeticMeanScoreCombinationTechniqueTests extends OpenSearchTestCase { - - private static final float DELTA_FOR_ASSERTION = 0.0001f; + public ArithmeticMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = this::arithmeticMean; + } public void 
testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); - float[] scores = { 1.0f, 0.5f, 0.3f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.6f, actualScore, DELTA_FOR_ASSERTION); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); - float[] scores = { 1.0f, -1.0f, 0.6f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.8f, actualScore, DELTA_FOR_ASSERTION); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, List.of(0.9, 0.2, 0.7)) - ); - float[] scores = { 1.0f, 0.5f, 0.3f }; - float actualScore = combinationTechnique.combine(scores); - assertEquals(0.6722f, actualScore, DELTA_FOR_ASSERTION); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - ArithmeticMeanScoreCombinationTechnique combinationTechnique = new ArithmeticMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, List.of(0.9, 0.15, 0.7)) - ); - float[] scores = { 1.0f, -1.0f, 0.6f }; - float 
actualScore = combinationTechnique.combine(scores); - assertEquals(0.825f, actualScore, DELTA_FOR_ASSERTION); + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float arithmeticMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float sumOfWeightedScores = 0; + float sumOfWeights = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i); + float weight = weights.get(i).floatValue(); + if (score >= 0) { + sumOfWeightedScores += score * weight; + sumOfWeights += weight; + } + } + return sumOfWeights == 0 ? 0f : sumOfWeightedScores / sumOfWeights; } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..cf9d1080f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/BaseScoreCombinationTechniqueTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; + +import lombok.NoArgsConstructor; + +import org.apache.commons.lang.ArrayUtils; +import org.opensearch.test.OpenSearchTestCase; + +@NoArgsConstructor +public class BaseScoreCombinationTechniqueTests extends OpenSearchTestCase { + + protected BiFunction, List, Float> expectedScoreFunction; + + private static final float DELTA_FOR_ASSERTION = 0.0001f; + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, 0.5f, 0.3f }; + float 
actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(final ScoreCombinationTechnique technique) { + float[] scores = { 1.0f, -1.0f, 0.6f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), List.of(1.0, 1.0, 1.0)); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List weights + ) { + float[] scores = { 1.0f, 0.5f, 0.3f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores( + final ScoreCombinationTechnique technique, + List weights + ) { + float[] scores = { 1.0f, -1.0f, 0.6f }; + float actualScore = technique.combine(scores); + float expectedScore = expectedScoreFunction.apply(Arrays.asList(ArrayUtils.toObject(scores)), weights); + assertEquals(expectedScore, actualScore, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..425c46443 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.neuralsearch.processor.combination; + +import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; + +import java.util.List; +import java.util.Map; + +public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + public HarmonicMeanScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of()); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of()); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { + List weights = List.of(0.9, 0.2, 0.7); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); + } + + private float harmonicMean(List scores, List weights) { + assertEquals(scores.size(), weights.size()); + float w = 0, h = 0; + for (int i = 0; i < scores.size(); i++) { + float score = scores.get(i), weight = weights.get(i).floatValue(); + if (score > 0) { + w += weight; + h += weight / score; + } + } + return h == 0 ? 
0f : w / h; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java new file mode 100644 index 000000000..9f164b3d3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.hamcrest.Matchers.containsString; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreCombinationFactoryTests extends OpenSearchQueryTestCase { + + public void testArithmeticWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("arithmetic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof ArithmeticMeanScoreCombinationTechnique); + } + + public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("harmonic_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique); + } + + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + IllegalArgumentException illegalArgumentException = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationFactory.createCombination("randomname") + ); + 
org.hamcrest.MatcherAssert.assertThat( + illegalArgumentException.getMessage(), + containsString("provided combination technique is not supported") + ); + } +}