diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java
new file mode 100644
index 000000000..c17d641cc
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.combination;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Abstracts combination of scores based on geometric mean method
+ */
+public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique {
+
+    public static final String TECHNIQUE_NAME = "geometric_mean";
+    public static final String PARAM_NAME_WEIGHTS = "weights";
+    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
+    private static final Float ZERO_SCORE = 0.0f;
+    private final List<Float> weights;
+    private final ScoreCombinationUtil scoreCombinationUtil;
+
+    public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
+        scoreCombinationUtil = combinationUtil;
+        scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
+        weights = scoreCombinationUtil.getWeights(params);
+    }
+
+    /**
+     * Weighted geometric mean method for combining scores.
+     *
+     * We use the formula below to calculate the mean. It relies on the fact that the logarithm of the geometric mean is
+     * the weighted arithmetic mean of the logarithms of the individual scores.
+     *
+     * geometric_mean = exp((weight_1*ln(score_1) + ... + weight_n*ln(score_n)) / (weight_1 + ... + weight_n))
+     */
+    @Override
+    public float combine(final float[] scores) {
+        float weightedLnSum = 0;
+        float sumOfWeights = 0;
+        for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
+            float score = scores[indexOfSubQuery];
+            if (score <= 0) {
+                // scores of 0.0 (or below) need to be skipped, ln() is not defined for them
+                continue;
+            }
+            float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
+            sumOfWeights += weight;
+            weightedLnSum += weight * Math.log(score);
+        }
+        return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
+    }
+}
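For readers who want to sanity-check the formula above, the following self-contained snippet (an illustration only, not part of this change) reproduces the combination logic with plain lists instead of ScoreCombinationUtil and prints two worked examples:

import java.util.List;

// Illustration of the weighted geometric mean used by
// GeometricMeanScoreCombinationTechnique; a sketch, not the plugin class.
public class GeometricMeanExample {

    static float combine(float[] scores, List<Float> weights) {
        float weightedLnSum = 0;
        float sumOfWeights = 0;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] <= 0) {
                continue; // ln() is undefined for non-positive scores, skip them
            }
            sumOfWeights += weights.get(i);
            weightedLnSum += weights.get(i) * Math.log(scores[i]);
        }
        return sumOfWeights == 0 ? 0.0f : (float) Math.exp(weightedLnSum / sumOfWeights);
    }

    public static void main(String[] args) {
        // exp((0.8*ln(1.0) + 0.7*ln(0.5)) / 1.5) is approximately 0.7236
        System.out.println(combine(new float[] { 1.0f, 0.5f }, List.of(0.8f, 0.7f)));
        // a zero score drops both the score and its weight, so the result is 0.5
        System.out.println(combine(new float[] { 0.0f, 0.5f }, List.of(0.8f, 0.7f)));
    }
}

Note the design choice the second example highlights: a sub-query that returned no score is excluded from the mean entirely rather than forcing the combined score to zero, which a plain geometric mean would do.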
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java
index d034ede16..f05d24823 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java
@@ -24,7 +24,9 @@ public class ScoreCombinationFactory {
         ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
         params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
         HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
-        params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
+        params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil),
+        GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
+        params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil)
     );
 
     /**
diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java
index 97a289bcd..ff221bf20 100644
--- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java
+++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java
@@ -7,12 +7,18 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 import static org.opensearch.test.OpenSearchTestCase.randomFloat;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.Range;
 import org.opensearch.common.xcontent.XContentHelper;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.core.xcontent.XContentBuilder;
@@ -20,6 +26,8 @@
 public class TestUtils {
 
+    private final static String RELATION_EQUAL_TO = "eq";
+
     /**
      * Convert an xContentBuilder to a map
      * @param xContentBuilder to produce map from
@@ -58,7 +66,7 @@ public static float[] createRandomVector(int dimension) {
     }
 
     /**
-     * Assert results of hyrdir query after score normalization and combination
+     * Assert results of hybrid query after score normalization and combination
      * @param querySearchResults collection of query search results after they processed by normalization processor
      */
     public static void assertQueryResultScores(List<QuerySearchResult> querySearchResults) {
@@ -94,4 +102,92 @@ public static void assertQueryResultScores(List<QuerySearchResult> querySearchResults) {
             .orElse(Float.MAX_VALUE);
         assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f);
     }
+
+    /**
+     * Assert results of hybrid query after score normalization and combination
+     * @param searchResponseWithWeightsAsMap collection of query search results after they are processed by the normalization processor
+     * @param expectedMaxScore expected maximum score
+     * @param expectedMaxMinusOneScore second highest expected score
+     * @param expectedMinScore expected minimum score
+     */
+    public static void assertWeightedScores(
+        Map<String, Object> searchResponseWithWeightsAsMap,
+        double expectedMaxScore,
+        double expectedMaxMinusOneScore,
+        double expectedMinScore
+    ) {
+        assertNotNull(searchResponseWithWeightsAsMap);
+        Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
+        assertNotNull(totalWeights.get("value"));
+        assertEquals(4, totalWeights.get("value"));
+        assertNotNull(totalWeights.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
+        assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
+        assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);
+
+        List<Double> scoresWeights = new ArrayList<>();
+        for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
+            scoresWeights.add((Double) oneHit.get("_score"));
+        }
+        // verify scores order
+        assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
+        // verify the scores are normalized with inclusion of weights
+        assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
+        assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
+        assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
+    }
+
+    /**
+     * Assert results of hybrid query after score normalization and combination
+     * @param searchResponseAsMap collection of query search results after they are processed by the normalization processor
+     * @param totalExpectedDocQty expected total document quantity
+     * @param minMaxScoreRange range of scores from min to max inclusive
+     */
+    public static void assertHybridSearchResults(
+        Map<String, Object> searchResponseAsMap,
+        int totalExpectedDocQty,
+        float[] minMaxScoreRange
+    ) {
+        assertNotNull(searchResponseAsMap);
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(totalExpectedDocQty, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+        assertTrue(getMaxScore(searchResponseAsMap).isPresent());
+        assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get()));
+
+        List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
+        for (Map<String, Object> oneHit : hitsNestedList) {
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
+        }
+        // verify scores order
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
+        // verify the scores are normalized. For l2 normalization the max score will not be 1.0, so we check against a range
+        assertTrue(
+            Range.between(minMaxScoreRange[0], minMaxScoreRange[1])
+                .contains(scores.stream().map(Double::floatValue).max(Double::compare).get())
+        );
+
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+    }
+
+    private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (List<Map<String, Object>>) hitsMap.get("hits");
+    }
+
+    private static Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (Map<String, Object>) hitsMap.get("total");
+    }
+
+    private static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return hitsMap.get("max_score") == null ?
Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java similarity index 59% rename from src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java rename to src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 54271d042..5938a04a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -9,7 +9,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -32,12 +31,11 @@ import com.google.common.primitives.Floats; -public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { +public class NormalizationProcessorIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; private static final String TEST_QUERY_TEXT3 = "hello"; private static final String TEST_QUERY_TEXT4 = "place"; - private static final String TEST_QUERY_TEXT5 = "welcome"; private static final String TEST_QUERY_TEXT6 = "notexistingword"; private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; private static final String TEST_DOC_TEXT1 = "Hello world"; @@ -56,11 +54,6 @@ public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private final float[] testVector4 = createRandomVector(TEST_DIMENSION); private final static String RELATION_EQUAL_TO = "eq"; - private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; - - private static final String L2_NORMALIZATION_METHOD = "l2"; - private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; - private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; @Before public void setUp() throws Exception { @@ -250,220 +243,6 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf assertQueryResults(searchResponseAsMap, 4, true); } - /** - * Using search pipelines with result processor configs like below: - * { - * "description": "Post processor for hybrid search", - * "phase_results_processors": [ - * { - * "normalization-processor": { - * "normalization": { - * "technique": "min-max" - * }, - * "combination": { - * "technique": "arithmetic_mean", - * "parameters": { - * "weights": [ - * 0.4, 0.7 - * ] - * } - * } - * } - * } - * ] - * } - */ - @SneakyThrows - public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); - // check case when number of weights and sub-queries are same - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) - ); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, 
TEST_QUERY_TEXT4)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); - - Map searchResponseWithWeights1AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); - - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) - ); - - Map searchResponseWithWeights2AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); - - // check case when number of weights is less than number of sub-queries - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) - ); - - Map searchResponseWithWeights3AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); - - // check case when number of weights is more than number of sub-queries - // delete existing pipeline and create a new one with another set of weights - deleteSearchPipeline(SEARCH_PIPELINE); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) - ); - - Map searchResponseWithWeights4AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - - assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); - } - - /** - * Using search pipelines with config for l2 norm: - * { - * "description": "Post processor for hybrid search", - * "phase_results_processors": [ - * { - * "normalization-processor": { - * "normalization": { - * "technique": "l2" - * }, - * "combination": { - * "technique": "arithmetic_mean" - * } - * } - * } - * ] - * } - */ - @SneakyThrows - public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - L2_NORMALIZATION_METHOD, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; - 
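For context on what the relocated weighted-mean test above verifies (it reappears in ScoreCombinationIT below): the default arithmetic_mean combination computes a weighted arithmetic mean of the normalized sub-query scores. A minimal sketch of that formula, assuming sub-queries are weighted exactly as passed (the combination class itself is not part of this diff):

// Illustration only: the weighted arithmetic mean behind the default
// "arithmetic_mean" combination, sum(w_i * s_i) / sum(w_i).
public class ArithmeticMeanExample {

    static float combine(float[] scores, float[] weights) {
        float weightedSum = 0;
        float sumOfWeights = 0;
        for (int i = 0; i < scores.length; i++) {
            weightedSum += weights[i] * scores[i];
            sumOfWeights += weights[i];
        }
        return sumOfWeights == 0 ? 0.0f : weightedSum / sumOfWeights;
    }

    public static void main(String[] args) {
        // (0.6*1.0 + 0.5*0.5) / (0.6 + 0.5) is approximately 0.773
        System.out.println(combine(new float[] { 1.0f, 0.5f }, new float[] { 0.6f, 0.5f }));
    }
}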
assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - - @SneakyThrows - public void testMinMaxNormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - DEFAULT_NORMALIZATION_METHOD, - HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.6f, 1.0f }; - assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - - @SneakyThrows - public void testL2NormHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - L2_NORMALIZATION_METHOD, - HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) - ); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - hybridQueryBuilder.add(termQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) - ); - int totalExpectedDocQty = 5; - float[] minMaxExpectedScoresRange = { 0.5f, 1.0f }; - assertHybridSearchResults(searchResponseAsMap, totalExpectedDocQty, minMaxExpectedScoresRange); - } - private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { prepareKnnIndex( @@ -616,60 +395,4 @@ private void assertQueryResults(Map searchResponseAsMap, int tot // verify that all ids are unique assertEquals(Set.copyOf(ids).size(), ids.size()); } - - private void assertWeightedScores( - Map searchResponseWithWeightsAsMap, - double expectedMaxScore, - double expectedMaxMinusOneScore, - double expectedMinScore - ) { - assertNotNull(searchResponseWithWeightsAsMap); - Map totalWeights = getTotalHits(searchResponseWithWeightsAsMap); - assertNotNull(totalWeights.get("value")); - assertEquals(4, totalWeights.get("value")); - assertNotNull(totalWeights.get("relation")); - assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation")); - assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent()); - assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f); - - List scoresWeights = new ArrayList<>(); - for (Map oneHit : getNestedHits(searchResponseWithWeightsAsMap)) { - scoresWeights.add((Double) oneHit.get("_score")); - } - // 
verify scores order - assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1))); - // verify the scores are normalized with inclusion of weights - assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001); - assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); - assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); - } - - private void assertHybridSearchResults(Map searchResponseAsMap, int totalExpectedDocQty, float[] minMaxScoreRange) { - assertNotNull(searchResponseAsMap); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(totalExpectedDocQty, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(Range.between(minMaxScoreRange[0], minMaxScoreRange[1]).contains(getMaxScore(searchResponseAsMap).get())); - - List> hitsNestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - // verify scores order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range - assertTrue( - Range.between(minMaxScoreRange[0], minMaxScoreRange[1]) - .contains(scores.stream().map(Double::floatValue).max(Double::compare).get()) - ); - - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java new file mode 100644 index 000000000..23b5daa90 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -0,0 +1,423 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreCombinationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String 
TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final AtomicReference modelId = new AtomicReference<>(); + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + private final static String RELATION_EQUAL_TO = "eq"; + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + modelId.compareAndSet(null, prepareModel()); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "arithmetic_mean", + * "parameters": { + * "weights": [ + * 0.4, 0.7 + * ] + * } + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + // check case when number of weights and sub-queries are same + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseWithWeights1AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); + + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + ); + + Map 
searchResponseWithWeights2AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is less than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + ); + + Map searchResponseWithWeights3AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); + + // check case when number of weights is more than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + ); + + Map searchResponseWithWeights4AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); + } + + /** + * Using search pipelines with config for harmonic mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "harmonic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + HARMONIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + 
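A note on the harmonic_mean technique exercised by this test: its implementation is not shown in this diff. As a reference point, a weighted harmonic mean is conventionally sum(w_i) / sum(w_i / s_i); the sketch below uses that standard definition, and its zero-score skipping is an assumption that mirrors the geometric technique above.

// Illustration only: a conventional weighted harmonic mean,
// sum(w_i) / sum(w_i / s_i).
public class HarmonicMeanExample {

    static float combine(float[] scores, float[] weights) {
        float sumOfWeights = 0;
        float sumOfWeightedReciprocals = 0;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] <= 0) {
                continue; // avoid division by zero; assumed to mirror the geometric technique
            }
            sumOfWeights += weights[i];
            sumOfWeightedReciprocals += weights[i] / scores[i];
        }
        return sumOfWeightedReciprocals == 0 ? 0.0f : sumOfWeights / sumOfWeightedReciprocals;
    }

    public static void main(String[] args) {
        // 1.5 / (0.8/1.0 + 0.7/0.5) = 1.5 / 2.2, approximately 0.682
        System.out.println(combine(new float[] { 1.0f, 0.5f }, new float[] { 0.8f, 0.7f }));
    }
}

The harmonic mean is dominated by the smallest score, which is why the expected ranges in these tests start lower (0.5) than for the arithmetic mean.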
assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + /** + * Using search pipelines with config for geometric mean: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "geometric_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapDefaultNorm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderDefaultNorm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertHybridSearchResults(searchResponseAsMapDefaultNorm, 5, new float[] { 0.5f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapL2Norm = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderL2Norm, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapL2Norm, 5, new float[] { 0.5f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + 
Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java new file mode 100644 index 000000000..e9324aa9b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -0,0 +1,356 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import 
org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreNormalizationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final AtomicReference modelId = new AtomicReference<>(); + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + + private static final String L2_NORMALIZATION_METHOD = "l2"; + private static final String HARMONIC_MEAN_COMBINATION_METHOD = "harmonic_mean"; + private static final String GEOMETRIC_MEAN_COMBINATION_METHOD = "geometric_mean"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + modelId.compareAndSet(null, prepareModel()); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + /** + * Using search pipelines with config for l2 norm: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "l2" + * }, + * "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2_NORMALIZATION_METHOD, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); + hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapArithmeticMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderArithmeticMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 0.6f, 1.0f }); + + 
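An aside on the l2 technique this test configures: l2 normalization divides each sub-query score by the L2 norm of that sub-query's score list, which is why these assertions accept a range rather than an exact 1.0 maximum; the top score lands strictly below 1.0 whenever other documents also scored. A sketch based on the technique name (the normalization classes themselves are outside this diff):

// Illustration only: l2 normalization of one sub-query's scores,
// normalized_i = score_i / sqrt(score_1^2 + ... + score_m^2).
public class L2NormalizeExample {

    static float[] normalize(float[] scores) {
        double sumOfSquares = 0;
        for (float score : scores) {
            sumOfSquares += score * score;
        }
        float norm = (float) Math.sqrt(sumOfSquares);
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            normalized[i] = norm == 0 ? 0.0f : scores[i] / norm;
        }
        return normalized;
    }

    public static void main(String[] args) {
        // norm = sqrt(1.0 + 0.64 + 0.36) = sqrt(2.0), so the top score maps to ~0.707, not 1.0
        for (float score : normalize(new float[] { 1.0f, 0.8f, 0.6f })) {
            System.out.println(score);
        }
    }
}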
+        deleteSearchPipeline(SEARCH_PIPELINE);
+
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            L2_NORMALIZATION_METHOD,
+            HARMONIC_MEAN_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
+        );
+
+        HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
+        hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null));
+        hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
+
+        Map<String, Object> searchResponseAsMapHarmonicMean = search(
+            TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
+            hybridQueryBuilderHarmonicMean,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+        assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.5f, 1.0f });
+
+        deleteSearchPipeline(SEARCH_PIPELINE);
+
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            L2_NORMALIZATION_METHOD,
+            GEOMETRIC_MEAN_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
+        );
+
+        HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
+        hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null));
+        hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
+
+        Map<String, Object> searchResponseAsMapGeometricMean = search(
+            TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
+            hybridQueryBuilderGeometricMean,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+        assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.5f, 1.0f });
+    }
+
+    /**
+     * Using search pipelines with config for min-max norm:
+     * {
+     *     "description": "Post processor for hybrid search",
+     *     "phase_results_processors": [
+     *         {
+     *             "normalization-processor": {
+     *                 "normalization": {
+     *                     "technique": "min-max"
+     *                 },
+     *                 "combination": {
+     *                     "technique": "arithmetic_mean"
+     *                 }
+     *             }
+     *         }
+     *     ]
+     * }
+     */
+    @SneakyThrows
+    public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            DEFAULT_NORMALIZATION_METHOD,
+            DEFAULT_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
+        );
+
+        HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder();
+        hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null));
+        hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
+
+        Map<String, Object> searchResponseAsMapArithmeticMean = search(
+            TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
+            hybridQueryBuilderArithmeticMean,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+        assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 1.0f, 1.0f });
+
+        deleteSearchPipeline(SEARCH_PIPELINE);
+
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            DEFAULT_NORMALIZATION_METHOD,
+            HARMONIC_MEAN_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
+        );
+
+        HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
+        hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "",
modelId.get(), 5, null, null)); + hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapHarmonicMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderHarmonicMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapHarmonicMean, 5, new float[] { 0.6f, 1.0f }); + + deleteSearchPipeline(SEARCH_PIPELINE); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + GEOMETRIC_MEAN_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); + hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null)); + hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + Map searchResponseAsMapGeometricMean = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilderGeometricMean, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertHybridSearchResults(searchResponseAsMapGeometricMean, 5, new float[] { 0.6f, 1.0f }); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), 
+ Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 7cea20a41..d256f9974 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -26,6 +26,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private static final String PIPELINE_NAME = "pipeline-hybrid"; public void testTextEmbeddingProcessor() throws Exception { + updateClusterSettings(); String modelId = uploadTextEmbeddingModel(); loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 3c3ca3776..9ee3a9bbf 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -12,17 +12,19 @@ public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + public ArithmeticMeanScoreCombinationTechniqueTests() { this.expectedScoreFunction = this::arithmeticMean; } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil); 
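Before the combination unit tests continue, one note on the default min-max normalization used by the ScoreNormalizationIT cases above: min-max rescales each sub-query's scores to [0, 1] via (score - min) / (max - min), which is why testMinMaxNorm can assert an exact 1.0 maximum. A sketch under that standard definition (the implementation is not shown in this diff; note the ITs assert a 0.001 floor, so the plugin evidently applies a small lower bound that this sketch omits):

// Illustration only: min-max normalization, (score - min) / (max - min),
// which maps a sub-query's best hit to exactly 1.0.
public class MinMaxNormalizeExample {

    static float[] normalize(float[] scores) {
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float score : scores) {
            min = Math.min(min, score);
            max = Math.max(max, score);
        }
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // the equal-min-max fallback here is illustrative, not the plugin's exact rule
            normalized[i] = max == min ? 1.0f : (scores[i] - min) / (max - min);
        }
        return normalized;
    }

    public static void main(String[] args) {
        // {1.0, 0.8, 0.6} maps to {1.0, 0.5, 0.0}
        for (float score : normalize(new float[] { 1.0f, 0.8f, 0.6f })) {
            System.out.println(score);
        }
    }
}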
         testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
     }
 
     public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
-        ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
+        ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
         testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
     }
 
@@ -30,7 +32,7 @@ public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
         List<Double> weights = List.of(0.9, 0.2, 0.7);
         ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
             Map.of(PARAM_NAME_WEIGHTS, weights),
-            new ScoreCombinationUtil()
+            scoreCombinationUtil
         );
         testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
     }
 
@@ -39,7 +41,7 @@ public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
         List<Double> weights = List.of(0.9, 0.2, 0.7);
         ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
             Map.of(PARAM_NAME_WEIGHTS, weights),
-            new ScoreCombinationUtil()
+            scoreCombinationUtil
         );
         testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java
new file mode 100644
index 000000000..b0d6509bf
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.combination;
+
+import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS;
+
+import java.util.List;
+import java.util.Map;
+
+public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {
+
+    private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
+
+    public GeometricMeanScoreCombinationTechniqueTests() {
+        this.expectedScoreFunction = this::geometricMean;
+    }
+
+    public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
+        ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
+        testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
+    }
+
+    public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
+        ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
+        testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
+    }
+
+    public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
+        List<Double> weights = List.of(0.9, 0.2, 0.7);
+        ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
+            Map.of(PARAM_NAME_WEIGHTS, weights),
+            scoreCombinationUtil
+        );
+        testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
+    }
+
+    public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
+        List<Double> weights = List.of(0.9, 0.2, 0.7);
+        ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique(
+            Map.of(PARAM_NAME_WEIGHTS, weights),
+            scoreCombinationUtil
+        );
+        testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
+    }
+
+    private float geometricMean(List<Float> scores, List<Double> weights) {
+        assertEquals(scores.size(), weights.size());
+        float sumOfWeights = 0;
+        float weightedSumOfLn = 0;
+        for (int i = 0; i < scores.size(); i++) {
+            float score = scores.get(i), weight = weights.get(i).floatValue();
+            if (score > 0) {
+                sumOfWeights += weight;
+                weightedSumOfLn += weight * Math.log(score);
+            }
+        }
+        return sumOfWeights == 0 ? 0f : (float) Math.exp(weightedSumOfLn / sumOfWeights);
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java
index 02a8084ef..3c45b7877 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java
@@ -12,17 +12,19 @@
 public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {
 
+    private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
+
     public HarmonicMeanScoreCombinationTechniqueTests() {
         this.expectedScoreFunction = (scores, weights) -> harmonicMean(scores, weights);
     }
 
     public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
-        ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
+        ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
         testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
     }
 
     public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
-        ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
+        ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), scoreCombinationUtil);
         testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
     }
 
@@ -30,7 +32,7 @@ public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
         List<Double> weights = List.of(0.9, 0.2, 0.7);
         ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
             Map.of(PARAM_NAME_WEIGHTS, weights),
-            new ScoreCombinationUtil()
+            scoreCombinationUtil
         );
         testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
     }
 
@@ -39,7 +41,7 @@ public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
         List<Double> weights = List.of(0.9, 0.2, 0.7);
         ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
             Map.of(PARAM_NAME_WEIGHTS, weights),
-            new ScoreCombinationUtil()
+            scoreCombinationUtil
         );
         testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java
index 9f164b3d3..ee46bf0a1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java
@@ -27,6 +27,14 @@ public void testHarmonicWeightedMean_whenCreatingByName_thenReturnCorrectInstance() {
assertTrue(scoreCombinationTechnique instanceof HarmonicMeanScoreCombinationTechnique); } + public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("geometric_mean"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows(