diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 4cfaf9837..33e971080 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -10,7 +10,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -24,7 +24,7 @@ import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; -import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR; +import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR; /** * Processor to add explanation details to search response @@ -40,19 +40,21 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor { private final boolean ignoreFailure; @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { return processResponse(request, response, null); } @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { - if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) { + if (Objects.isNull(requestContext) + || 
(Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY))) + || requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) { return response; } - ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); - Map explainPayload = explanationResponse.getExplainPayload(); + ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); + Map explainPayload = explanationPayload.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { - Explanation processorExplanation = explanationResponse.getExplanation(); + Explanation processorExplanation = explanationPayload.getExplanation(); if (Objects.isNull(processorExplanation)) { return response; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 7f0314ef7..d2008ae97 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -66,9 +66,9 @@ public void process( */ @Override public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext ) { prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 118d0a25c..f5fb794a7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -26,9 +26,9 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -42,7 +42,7 @@ import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; /** @@ -123,11 +123,11 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs); - Map normalizationExplain = scoreNormalizer.explain( + Map normalizationExplain = scoreNormalizer.explain( queryTopDocs, (ExplainableTechnique) request.getNormalizationTechnique() ); - Map> combinationExplain = scoreCombiner.explain( + Map> combinationExplain = scoreCombiner.explain( queryTopDocs, request.getCombinationTechnique(), sortForQuery @@ -135,24 +135,24 
@@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< Map> combinedExplain = new HashMap<>(); combinationExplain.forEach((searchShard, explainDetails) -> { - for (ExplainDetails explainDetail : explainDetails) { + for (ExplanationDetails explainDetail : explainDetails) { DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard); - ExplainDetails normalizedExplainDetails = normalizationExplain.get(docIdAtSearchShard); + ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard); CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder() - .normalizationExplain(normalizedExplainDetails) + .normalizationExplain(normalizedExplanationDetails) .combinationExplain(explainDetail) .build(); combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails); } }); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(topLevelExplanationForTechniques) - .explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) + .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 1d31b4c31..4055d0377 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 0dcd5c39c..7de4e0499 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 4fd112bc5..f6c68bc7e 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 8194ecf74..7b5a1da5f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -27,9 +27,9 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.processor.SearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument; /** * Abstracts combination of scores in query search results. @@ -318,40 +318,47 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon return new TotalHits(maxHits, totalHits); } - public Map> explain( + /** + * Explain the score combination technique for each document in the given queryTopDocs. 
+ * @param queryTopDocs + * @param combinationTechnique + * @param sort + * @return a map of SearchShard and List of ExplainationDetails for each document + */ + public Map> explain( final List queryTopDocs, final ScoreCombinationTechnique combinationTechnique, final Sort sort ) { // In case of duplicate keys, keep the first value - HashMap> explanations = new HashMap<>(); + Map> explanations = new HashMap<>(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { - for (Map.Entry> docIdAtSearchShardExplainDetailsEntry : explainByShard( - combinationTechnique, - compoundQueryTopDocs, - sort - ).entrySet()) { - explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); - } + explanations.putIfAbsent( + compoundQueryTopDocs.getSearchShard(), + explainByShard(combinationTechnique, compoundQueryTopDocs, sort) + ); } return explanations; } - private Map> explainByShard( + private List explainByShard( final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs, - Sort sort + final Sort sort ) { if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { - return Map.of(); + return List.of(); } - // - create map of normalized scores results returned from the single shard + // create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs()); + // combine scores Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + // sort combined scores as per sorting criteria - either score desc or field sorting Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - List listOfExplainsForShard = sortedDocsIds.stream() + + List listOfExplanations = 
sortedDocsIds.stream() .map( docId -> getScoreCombinationExplainDetailsForDocument( docId, @@ -360,7 +367,7 @@ private Map> explainByShard( ) ) .toList(); - return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard); + return listOfExplanations; } private Collection getSortedDocsIds( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java index 4a9793fd4..dd8a2a9c3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java @@ -15,6 +15,6 @@ @Builder @Getter public class CombinedExplainDetails { - private ExplainDetails normalizationExplain; - private ExplainDetails combinationExplain; + private ExplanationDetails normalizationExplain; + private ExplanationDetails combinationExplain; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java index 6f8dfcf1e..cc2fab6c6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java @@ -28,7 +28,7 @@ default String describe() { * @param queryTopDocs collection of CompoundTopDocs for each shard result * @return map of document per shard and corresponding explanation object */ - default Map explain(final List queryTopDocs) { + default Map explain(final List queryTopDocs) { return Map.of(); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java similarity index 90% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java rename to 
src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java index 1eca4232f..9e9fd4c3a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java @@ -17,7 +17,7 @@ /** * Utility class for explain functionality */ -public class ExplainUtils { +public class ExplainationUtils { /** * Creates map of DocIdAtQueryPhase to String containing source and normalized scores @@ -25,16 +25,16 @@ public class ExplainUtils { * @param sourceScores map of DocIdAtQueryPhase to source scores * @return map of DocIdAtQueryPhase to String containing source and normalized scores */ - public static Map getDocIdAtQueryForNormalization( + public static Map getDocIdAtQueryForNormalization( final Map> normalizedScores, final Map> sourceScores ) { - Map explain = sourceScores.entrySet() + Map explain = sourceScores.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> { List srcScores = entry.getValue(); List normScores = normalizedScores.get(entry.getKey()); - return new ExplainDetails( + return new ExplanationDetails( normScores.stream().reduce(0.0f, Float::max), String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores) ); @@ -49,13 +49,13 @@ public static Map getDocIdAtQueryForNormaliz * @param normalizedScoresPerDoc * @return */ - public static ExplainDetails getScoreCombinationExplainDetailsForDocument( + public static ExplanationDetails getScoreCombinationExplainDetailsForDocument( final Integer docId, final Map combinedNormalizedScoresByDocId, final float[] normalizedScoresPerDoc ) { float combinedScore = combinedNormalizedScoresByDocId.get(docId); - return new ExplainDetails( + return new ExplanationDetails( docId, combinedScore, String.format( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java 
b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java similarity index 74% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java index ca83cc436..594bc4299 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java @@ -10,8 +10,8 @@ * @param value * @param description */ -public record ExplainDetails(int docId, float value, String description) { - public ExplainDetails(float value, String description) { +public record ExplanationDetails(int docId, float value, String description) { + public ExplanationDetails(float value, String description) { this(-1, value, description); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java similarity index 72% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java index f14050214..a1206a1a1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java @@ -17,11 +17,11 @@ @AllArgsConstructor @Builder @Getter -public class ExplanationResponse { - Explanation explanation; - Map explainPayload; +public class ExplanationPayload { + private final Explanation explanation; + private final Map explainPayload; - public enum ExplanationType { + public enum PayloadType { NORMALIZATION_PROCESSOR } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java 
b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 5c5436564..ca68bf563 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -17,10 +17,10 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on L2 method @@ -64,7 +64,7 @@ public String describe() { } @Override - public Map explain(List queryTopDocs) { + public Map explain(List queryTopDocs) { Map> normalizedScores = new HashMap<>(); Map> sourceScores = new HashMap<>(); List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 63efb4332..e3487cbcb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -20,10 +20,10 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import 
org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on min-max method @@ -44,19 +44,7 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech */ @Override public void normalize(final List queryTopDocs) { - int numOfSubqueries = queryTopDocs.stream() - .filter(Objects::nonNull) - .filter(topDocs -> topDocs.getTopDocs().size() > 0) - .findAny() - .get() - .getTopDocs() - .size(); - // get min scores for each sub query - float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); - - // get max scores for each sub query - float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); - + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); // do normalization using actual score and min and max scores for corresponding sub query for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { @@ -66,35 +54,36 @@ public void normalize(final List queryTopDocs) { for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { - scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + scoreDoc.score = normalizeSingleScore( + scoreDoc.score, + minMaxScores.minScoresPerSubquery()[j], + minMaxScores.maxScoresPerSubquery()[j] + ); } } } } + private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { + int numOfSubqueries = getNumOfSubqueries(queryTopDocs); + // get min scores for each sub query + float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); + // get max scores for each sub query + float[] maxScoresPerSubquery = 
getMaxScores(queryTopDocs, numOfSubqueries); + return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery); + } + @Override public String describe() { return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override - public Map explain(final List queryTopDocs) { + public Map explain(final List queryTopDocs) { + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); + Map> normalizedScores = new HashMap<>(); Map> sourceScores = new HashMap<>(); - - int numOfSubqueries = queryTopDocs.stream() - .filter(Objects::nonNull) - .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) - .findAny() - .get() - .getTopDocs() - .size(); - // get min scores for each sub query - float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); - - // get max scores for each sub query - float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); - for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; @@ -104,17 +93,30 @@ public Map explain(final List new ArrayList<>()).add(normalizedScore); sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score); scoreDoc.score = normalizedScore; } } } - return getDocIdAtQueryForNormalization(normalizedScores, sourceScores); } + private int getNumOfSubqueries(final List queryTopDocs) { + return queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) + .findAny() + .get() + .getTopDocs() + .size(); + } + private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { float[] maxScores = new float[numOfSubqueries]; Arrays.fill(maxScores, Float.MIN_VALUE); @@ -165,4 +167,10 @@ private float normalizeSingleScore(final float score, final float minScore, fina float normalizedScore = (score - minScore) / (maxScore - minScore); return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; } + + /** + * Result class to hold min and max scores for each sub query + */ + private record MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) { + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 2dcf5f768..67a17fda2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -10,7 +10,7 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; public class ScoreNormalizer { @@ -30,7 +30,14 @@ private boolean canQueryResultsBeNormalized(final List queryTop return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } - public Map explain( + /** + * Explain normalized scores based on input normalization technique. Does not mutate input object. 
+ * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param scoreNormalizationTechnique normalization technique that is able to explain its score calculation + * @return map of doc id to explanation details + */ + public Map explain( final List queryTopDocs, final ExplainableTechnique scoreNormalizationTechnique ) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java similarity index 90% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java index ce2df0b13..fe0099c87 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java @@ -16,8 +16,8 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -38,7 +38,7 @@ import static org.mockito.Mockito.mock; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationResponseProcessorTests extends OpenSearchTestCase { +public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -130,27 +130,27 @@
public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( @@ -216,27 +216,27 @@ public void 
testParsingOfExplanations_whenFieldSortingAndExplanations_thenSucces SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( @@ -300,27 +300,27 @@ public void 
testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse(