diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java
new file mode 100644
index 000000000..3fff2db2b
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.combination;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Abstracts combination of scores based on arithmetic mean method
+ */
+public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {
+
+    public static final String TECHNIQUE_NAME = "arithmetic_mean";
+    public static final String PARAM_NAME_WEIGHTS = "weights";
+    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
+    private static final Float ZERO_SCORE = 0.0f;
+    private final List<Float> weights;
+
+    public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
+        validateParams(params);
+        weights = getWeights(params);
+    }
+
+    private List<Float> getWeights(final Map<String, Object> params) {
+        if (Objects.isNull(params) || params.isEmpty()) {
+            return List.of();
+        }
+        // get weights, we don't need to check for instance as it's done during validation
+        return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
+            .map(Double::floatValue)
+            .collect(Collectors.toUnmodifiableList());
+    }
+
+    /**
+     * Arithmetic mean method for combining scores.
+     * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
+     *
+     * Sub-query scores below 0.0 are excluded from both the weighted sum and the number of scores N
+     */
+    @Override
+    public float combine(final float[] scores) {
+        float combinedScore = 0.0f;
+        float sumOfWeights = 0;
+        for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
+            float score = scores[indexOfSubQuery];
+            if (score >= 0.0) {
+                float weight = getWeightForSubQuery(indexOfSubQuery);
+                score = score * weight;
+                combinedScore += score;
+                sumOfWeights += weight;
+            }
+        }
+        if (sumOfWeights == 0.0f) {
+            return ZERO_SCORE;
+        }
+        return combinedScore / sumOfWeights;
+    }
+
+    private void validateParams(final Map<String, Object> params) {
+        if (Objects.isNull(params) || params.isEmpty()) {
+            return;
+        }
+        // check if only supported params are passed
+        Optional<String> optionalNotSupportedParam = params.keySet()
+            .stream()
+            .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
+            .findFirst();
+        if (optionalNotSupportedParam.isPresent()) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "provided parameter for combination technique is not supported. supported parameters are [%s]",
+                    SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
+                )
+            );
+        }
+
+        // check param types
+        if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
+            if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
+                throw new IllegalArgumentException(
+                    String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
+                );
+            }
+        }
+    }
+
+    /**
+     * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
+     * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
+     * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
+     */
+    private float getWeightForSubQuery(int indexOfSubQuery) {
+        return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
+    }
+}
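For a quick sanity check of the combination formula above, here is a minimal standalone sketch (illustrative only, not part of the patch; it assumes the ArithmeticMeanScoreCombinationTechnique and ScoreCombinationTechnique classes introduced above are on the classpath) that combines two sub-query scores with the same weights used by the integration tests:

    import java.util.List;
    import java.util.Map;

    import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
    import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;

    public class ArithmeticMeanExample {
        public static void main(String[] args) {
            // weights 0.4 and 0.7 for two sub-queries, raw scores 1.0 and 0.5
            ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
                Map.of("weights", List.of(0.4d, 0.7d))
            );
            float combined = technique.combine(new float[] { 1.0f, 0.5f });
            // (0.4 * 1.0 + 0.7 * 0.5) / (0.4 + 0.7) = 0.75 / 1.1 ≈ 0.6818
            System.out.println(combined);
        }
    }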
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
new file mode 100644
index 000000000..0007a3ef3
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.opensearch.neuralsearch.search.CompoundTopDocs;
+
+/**
+ * Abstracts normalization of scores based on L2 method
+ */
+public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {
+
+    public static final String TECHNIQUE_NAME = "l2";
+    private static final float MIN_SCORE = 0.001f;
+
+    /**
+     * L2 normalization method.
+     * n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2)
+     * Main algorithm steps:
+     * - calculate sum of squares of all scores
+     * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query
+     */
+    @Override
+    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
+        // get l2 norms for each sub-query
+        List<Float> normsPerSubquery = getL2Norm(queryTopDocs);
+
+        // do normalization using actual score and l2 norm
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
+                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
+                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
+                }
+            }
+        }
+    }
+
+    private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
+        // find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries,
+        // or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be
+        // TopDocs for rest of sub-queries with zero total hits
+        int numOfSubqueries = queryTopDocs.stream()
+            .filter(Objects::nonNull)
+            .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
+            .findAny()
+            .get()
+            .getCompoundTopDocs()
+            .size();
+        float[] l2Norms = new float[numOfSubqueries];
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
+            int bound = topDocsPerSubQuery.size();
+            for (int index = 0; index < bound; index++) {
+                for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
+                    l2Norms[index] += scoreDocs.score * scoreDocs.score;
+                }
+            }
+        }
+        for (int index = 0; index < l2Norms.length; index++) {
+            l2Norms[index] = (float) Math.sqrt(l2Norms[index]);
+        }
+        List<Float> l2NormList = new ArrayList<>();
+        for (int index = 0; index < numOfSubqueries; index++) {
+            l2NormList.add(l2Norms[index]);
+        }
+        return l2NormList;
+    }
+
+    private float normalizeSingleScore(final float score, final float l2Norm) {
+        return l2Norm == 0 ? MIN_SCORE : score / l2Norm;
+    }
+}
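As an illustration of the L2 formula above (a standalone sketch, not part of the patch; it assumes the L2ScoreNormalizationTechnique and CompoundTopDocs classes shown in this diff are available), raw scores 3.0 and 4.0 from one sub-query have an L2 norm of sqrt(3^2 + 4^2) = 5, so they normalize to 0.6 and 0.8:

    import java.util.List;

    import org.apache.lucene.search.ScoreDoc;
    import org.apache.lucene.search.TopDocs;
    import org.apache.lucene.search.TotalHits;
    import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
    import org.opensearch.neuralsearch.search.CompoundTopDocs;

    public class L2NormalizationExample {
        public static void main(String[] args) {
            // one sub-query with two hits scored 3.0 and 4.0
            TopDocs subQueryTopDocs = new TopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                new ScoreDoc[] { new ScoreDoc(1, 3.0f), new ScoreDoc(2, 4.0f) }
            );
            CompoundTopDocs compound = new CompoundTopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                List.of(subQueryTopDocs)
            );
            new L2ScoreNormalizationTechnique().normalize(List.of(compound));
            // l2 norm = sqrt(3^2 + 4^2) = 5, so the scores become 0.6 and 0.8
            for (ScoreDoc scoreDoc : subQueryTopDocs.scoreDocs) {
                System.out.println(scoreDoc.score);
            }
        }
    }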
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
index b469e241b..667c237c7 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
@@ -17,7 +17,9 @@ public class ScoreNormalizationFactory {
 
     private final Map<String, ScoreNormalizationTechnique> scoreNormalizationMethodsMap = Map.of(
         MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
-        new MinMaxScoreNormalizationTechnique()
+        new MinMaxScoreNormalizationTechnique(),
+        L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
+        new L2ScoreNormalizationTechnique()
     );
 
     /**
diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
index 48c694fb0..6aa9b5a5a 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
@@ -107,7 +107,7 @@ protected String uploadModel(String requestBody) throws Exception {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
         Map<String, Object> uploadResJson = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(uploadResponse.getEntity()),
             false
         );
@@ -136,7 +136,7 @@ protected void loadModel(String modelId) throws Exception {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
         Map<String, Object> uploadResJson = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(uploadResponse.getEntity()),
             false
         );
@@ -185,7 +185,7 @@ protected float[] runInference(String modelId, String queryText) {
         );
 
         Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(inferenceResponse.getEntity()),
             false
         );
@@ -215,7 +215,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
         Map<String, Object> node = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(response.getEntity()),
             false
         );
@@ -239,7 +239,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
         Map<String, Object> node = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(pipelineCreateResponse.getEntity()),
             false
         );
@@ -329,7 +329,7 @@ protected Map<String, Object> search(
 
         String responseBody = EntityUtils.toString(response.getEntity());
 
-        return XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), responseBody, false);
+        return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
     }
 
     /**
@@ -445,11 +445,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Exception {
             toHttpEntity(""),
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
         );
-        return XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
-            EntityUtils.toString(taskQueryResponse.getEntity()),
-            false
-        );
+        return XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(taskQueryResponse.getEntity()), false);
     }
 
     protected boolean checkComplete(Map<String, Object> node) {
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java
index 126275a47..273dbd522 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java
@@ -27,6 +27,7 @@
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.knn.index.SpaceType;
 import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
+import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
 
@@ -93,12 +94,7 @@ protected boolean preserveClusterUponCompletion() {
      *                 "technique": "min-max"
      *             },
      *             "combination": {
-     *                 "technique": "sum",
-     *                 "parameters": {
-     *                     "weights": [
-     *                         0.4, 0.7
-     *                     ]
-     *                 }
+     *                 "technique": "arithmetic_mean"
      *             }
      *         }
      *     }
@@ -251,6 +247,29 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessful() {
         assertQueryResults(searchResponseAsMap, 4, true);
     }
 
+    /**
+     * Using search pipelines with result processor configs like below:
+     * {
+     *     "description": "Post processor for hybrid search",
+     *     "phase_results_processors": [
+     *         {
+     *             "normalization-processor": {
+     *                 "normalization": {
+     *                     "technique": "min-max"
+     *                 },
+     *                 "combination": {
+     *                     "technique": "arithmetic_mean",
+     *                     "parameters": {
+     *                         "weights": [
+     *                             0.4, 0.7
+     *                         ]
+     *                     }
+     *                 }
+     *             }
+     *         }
+     *     ]
+     * }
+     */
     @SneakyThrows
     public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
@@ -337,6 +356,74 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
         assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001);
     }
 
+    /**
+     * Using search pipelines with config for l2 norm:
+     * {
+     *     "description": "Post processor for hybrid search",
+     *     "phase_results_processors": [
+     *         {
+     *             "normalization-processor": {
+     *                 "normalization": {
+     *                     "technique": "l2"
+     *                 },
* "combination": { + * "technique": "arithmetic_mean" + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + L2ScoreNormalizationTechnique.TECHNIQUE_NAME, + COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + ); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + int totalExpectedDocQty = 5; + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(.6f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized. 
+        assertTrue(Range.between(.6f, 1.0f).contains((float) scores.stream().map(Double::floatValue).max(Double::compare).get()));
+
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+    }
+
     private void initializeIndexIfNotExist(String indexName) throws IOException {
         if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
             prepareKnnIndex(
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java
index 226a3878d..7cea20a41 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java
@@ -13,7 +13,6 @@
 import org.apache.hc.core5.http.io.entity.EntityUtils;
 import org.apache.hc.core5.http.message.BasicHeader;
 import org.opensearch.client.Response;
-import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.common.xcontent.XContentHelper;
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
@@ -71,7 +70,7 @@ private void ingestDocument() throws Exception {
             ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
         );
         Map<String, Object> map = XContentHelper.convertToMap(
-            XContentFactory.xContent(XContentType.JSON),
+            XContentType.JSON.xContent(),
             EntityUtils.toString(response.getEntity()),
             false
         );
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java
new file mode 100644
index 000000000..c5f8c4860
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+import org.opensearch.neuralsearch.search.CompoundTopDocs;
+
+/**
+ * Tests for normalization of scores based on L2 method
+ */
+public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
+    private static final float DELTA_FOR_ASSERTION = 0.0001f;
+
+    public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
+        L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique();
+        Float[] scores = { 0.5f, 0.2f };
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(2, l2Norm(scores[0], Arrays.asList(scores))),
+                        new ScoreDoc(4, l2Norm(scores[1], Arrays.asList(scores))) }
+                )
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs());
+        assertCompoundTopDocs(expectedCompoundDocs, compoundTopDocs.get(0).getCompoundTopDocs().get(0));
+    }
+
+    public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
+        L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique();
+        Float[] scoresQuery1 = { 0.5f, 0.2f };
+        Float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f };
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] {
+                            new ScoreDoc(3, scoresQuery2[0]),
+                            new ScoreDoc(4, scoresQuery2[1]),
+                            new ScoreDoc(2, scoresQuery2[2]) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(2, l2Norm(scoresQuery1[0], Arrays.asList(scoresQuery1))),
+                        new ScoreDoc(4, l2Norm(scoresQuery1[1], Arrays.asList(scoresQuery1))) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(3, l2Norm(scoresQuery2[0], Arrays.asList(scoresQuery2))),
+                        new ScoreDoc(4, l2Norm(scoresQuery2[1], Arrays.asList(scoresQuery2))),
+                        new ScoreDoc(2, l2Norm(scoresQuery2[2], Arrays.asList(scoresQuery2))) }
+                )
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs());
+        for (int i = 0; i < expectedCompoundDocs.getCompoundTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocs.getCompoundTopDocs().get(i), compoundTopDocs.get(0).getCompoundTopDocs().get(i));
+        }
+    }
+
+    public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
+        L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique();
+        Float[] scoresShard1Query1 = { 0.5f, 0.2f };
+        Float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f };
+        Float[] scoresShard2Query2 = { 2.9f, 0.7f };
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] {
+                            new ScoreDoc(3, scoresShard1and2Query3[0]),
+                            new ScoreDoc(4, scoresShard1and2Query3[1]),
+                            new ScoreDoc(2, scoresShard1and2Query3[2]) }
+                    )
+                )
+            ),
+            new CompoundTopDocs(
+                new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) }
+                    ),
+                    new TopDocs(
+                        new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] {
+                            new ScoreDoc(3, scoresShard1and2Query3[3]),
+                            new ScoreDoc(9, scoresShard1and2Query3[4]),
+                            new ScoreDoc(10, scoresShard1and2Query3[5]),
+                            new ScoreDoc(15, scoresShard1and2Query3[6]) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(2, l2Norm(scoresShard1Query1[0], Arrays.asList(scoresShard1Query1))),
+                        new ScoreDoc(4, l2Norm(scoresShard1Query1[1], Arrays.asList(scoresShard1Query1))) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(3, l2Norm(scoresShard1and2Query3[0], Arrays.asList(scoresShard1and2Query3))),
+                        new ScoreDoc(4, l2Norm(scoresShard1and2Query3[1], Arrays.asList(scoresShard1and2Query3))),
+                        new ScoreDoc(2, l2Norm(scoresShard1and2Query3[2], Arrays.asList(scoresShard1and2Query3))) }
+                )
+            )
+        );
+
+        CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs(
+            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(7, l2Norm(scoresShard2Query2[0], Arrays.asList(scoresShard2Query2))),
+                        new ScoreDoc(9, l2Norm(scoresShard2Query2[1], Arrays.asList(scoresShard2Query2))) }
+                ),
+                new TopDocs(
+                    new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        new ScoreDoc(3, l2Norm(scoresShard1and2Query3[3], Arrays.asList(scoresShard1and2Query3))),
+                        new ScoreDoc(9, l2Norm(scoresShard1and2Query3[4], Arrays.asList(scoresShard1and2Query3))),
+                        new ScoreDoc(10, l2Norm(scoresShard1and2Query3[5], Arrays.asList(scoresShard1and2Query3))),
+                        new ScoreDoc(15, l2Norm(scoresShard1and2Query3[6], Arrays.asList(scoresShard1and2Query3))) }
+                )
+            )
+        );
+
+        assertNotNull(compoundTopDocs);
+        assertEquals(2, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard1.getCompoundTopDocs().size(); i++) {
+            assertCompoundTopDocs(
+                expectedCompoundDocsShard1.getCompoundTopDocs().get(i),
+                compoundTopDocs.get(0).getCompoundTopDocs().get(i)
+            );
+        }
+        assertNotNull(compoundTopDocs.get(1).getCompoundTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard2.getCompoundTopDocs().size(); i++) {
+            assertCompoundTopDocs(
+                expectedCompoundDocsShard2.getCompoundTopDocs().get(i),
+                compoundTopDocs.get(1).getCompoundTopDocs().get(i)
+            );
+        }
+    }
+
+    private float l2Norm(float score, List<Float> scores) {
+        return score / (float) Math.sqrt(scores.stream().map(Float::doubleValue).map(s -> s * s).mapToDouble(Double::doubleValue).sum());
+    }
+
+    private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
+        assertEquals(expected.totalHits.value, actual.totalHits.value);
+        assertEquals(expected.totalHits.relation, actual.totalHits.relation);
+        assertEquals(expected.scoreDocs.length, actual.scoreDocs.length);
+        for (int i = 0; i < expected.scoreDocs.length; i++) {
+            assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION);
+            assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc);
+            assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex);
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java
new file mode 100644
index 000000000..bf1489fe3
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import static org.hamcrest.Matchers.containsString;
+
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+
+public class ScoreNormalizationFactoryTests extends OpenSearchQueryTestCase {
+
+    public void testMinMaxNorm_whenCreatingByName_thenReturnCorrectInstance() {
+        ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
+        ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("min_max");
+
+        assertNotNull(scoreNormalizationTechnique);
+        assertTrue(scoreNormalizationTechnique instanceof MinMaxScoreNormalizationTechnique);
+    }
+
+    public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() {
+        ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
+        ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("l2");
+
+        assertNotNull(scoreNormalizationTechnique);
+        assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique);
+    }
+
+    public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() {
+        ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
+        IllegalArgumentException illegalArgumentException = expectThrows(
+            IllegalArgumentException.class,
+            () -> scoreNormalizationFactory.createNormalization("randomname")
+        );
+        assertThat(illegalArgumentException.getMessage(), containsString("provided normalization technique is not supported"));
+    }
+}
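Taken together, the factory change wires the new technique in by name. A minimal usage sketch (illustrative only, mirroring the factory unit tests above; it assumes the classes from this diff are on the classpath):

    import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
    import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
    import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

    public class NormalizationFactoryUsageExample {
        public static void main(String[] args) {
            ScoreNormalizationFactory factory = new ScoreNormalizationFactory();
            // "l2" resolves to the new technique; unknown names throw IllegalArgumentException
            ScoreNormalizationTechnique technique = factory.createNormalization(L2ScoreNormalizationTechnique.TECHNIQUE_NAME);
            System.out.println(technique instanceof L2ScoreNormalizationTechnique); // true
        }
    }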