Refactor classes and methods
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Nov 4, 2024
1 parent a19de09 commit 4f50c18
Showing 16 changed files with 148 additions and 124 deletions.
---------- changed file ----------
@@ -10,7 +10,7 @@
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
-import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
+import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchShardTarget;
@@ -24,7 +24,7 @@
 import java.util.Objects;
 
 import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
-import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;
 
 /**
  * Processor to add explanation details to search response
@@ -40,19 +40,21 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor {
     private final boolean ignoreFailure;
 
     @Override
-    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
+    public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
         return processResponse(request, response, null);
     }
 
     @Override
     public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
-        if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
+        if (Objects.isNull(requestContext)
+            || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))
+            || requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) {
             return response;
         }
-        ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
-        Map<ExplanationResponse.ExplanationType, Object> explainPayload = explanationResponse.getExplainPayload();
+        ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
+        Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
         if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
-            Explanation processorExplanation = explanationResponse.getExplanation();
+            Explanation processorExplanation = explanationPayload.getExplanation();
             if (Objects.isNull(processorExplanation)) {
                 return response;
             }
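Note: the expanded guard above makes the processor a silent no-op when the context attribute is missing or of an unexpected type, instead of risking a ClassCastException at the cast. A minimal self-contained sketch of the same pattern, with a plain Map standing in for PipelineProcessingContext and a simplified payload record (both are stand-ins, not the plugin's actual types):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.Objects;

    public class GuardSketch {
        static final String EXPLAIN_RESPONSE_KEY = "explain_response";

        record ExplanationPayload(String explanation) {}

        static String process(Map<String, Object> requestContext, String response) {
            // Null context, missing attribute, or unexpected attribute type all
            // return the response untouched; only a real payload is processed.
            if (Objects.isNull(requestContext)
                || Objects.isNull(requestContext.get(EXPLAIN_RESPONSE_KEY))
                || requestContext.get(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) {
                return response;
            }
            ExplanationPayload payload = (ExplanationPayload) requestContext.get(EXPLAIN_RESPONSE_KEY);
            return response + " [explained: " + payload.explanation() + "]";
        }

        public static void main(String[] args) {
            Map<String, Object> ctx = new HashMap<>();
            ctx.put(EXPLAIN_RESPONSE_KEY, "not a payload");
            System.out.println(process(ctx, "response"));  // guard rejects the wrong type
            ctx.put(EXPLAIN_RESPONSE_KEY, new ExplanationPayload("normalized scores"));
            System.out.println(process(ctx, "response"));  // response [explained: normalized scores]
        }
    }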
---------- changed file ----------
@@ -66,9 +66,9 @@ public <Result extends SearchPhaseResult> void process(
      */
     @Override
     public <Result extends SearchPhaseResult> void process(
-        final SearchPhaseResults<Result> searchPhaseResult,
-        final SearchPhaseContext searchPhaseContext,
-        final PipelineProcessingContext requestContext
+        final SearchPhaseResults<Result> searchPhaseResult,
+        final SearchPhaseContext searchPhaseContext,
+        final PipelineProcessingContext requestContext
     ) {
         prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
    }
---------- changed file ----------
@@ -26,9 +26,9 @@
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
 import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
-import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
-import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
+import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.search.SearchHit;
@@ -42,7 +42,7 @@
 
 import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore;
 import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;
 
 /**
@@ -123,36 +123,36 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
 
         Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);
 
-        Map<DocIdAtSearchShard, ExplainDetails> normalizationExplain = scoreNormalizer.explain(
+        Map<DocIdAtSearchShard, ExplanationDetails> normalizationExplain = scoreNormalizer.explain(
             queryTopDocs,
             (ExplainableTechnique) request.getNormalizationTechnique()
         );
-        Map<SearchShard, List<ExplainDetails>> combinationExplain = scoreCombiner.explain(
+        Map<SearchShard, List<ExplanationDetails>> combinationExplain = scoreCombiner.explain(
             queryTopDocs,
             request.getCombinationTechnique(),
             sortForQuery
         );
         Map<SearchShard, List<CombinedExplainDetails>> combinedExplain = new HashMap<>();
 
         combinationExplain.forEach((searchShard, explainDetails) -> {
-            for (ExplainDetails explainDetail : explainDetails) {
+            for (ExplanationDetails explainDetail : explainDetails) {
                 DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard);
-                ExplainDetails normalizedExplainDetails = normalizationExplain.get(docIdAtSearchShard);
+                ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard);
                 CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder()
-                    .normalizationExplain(normalizedExplainDetails)
+                    .normalizationExplain(normalizedExplanationDetails)
                     .combinationExplain(explainDetail)
                     .build();
                 combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails);
             }
         });
 
-        ExplanationResponse explanationResponse = ExplanationResponse.builder()
+        ExplanationPayload explanationPayload = ExplanationPayload.builder()
             .explanation(topLevelExplanationForTechniques)
-            .explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
+            .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain))
             .build();
         // store explain object to pipeline context
         PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
-        pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse);
+        pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload);
     }
 
 }
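Note: the forEach block above groups per-document details by shard through Map.computeIfAbsent. A JDK-only sketch of that grouping idiom, with String keys standing in for SearchShard and String values for CombinedExplainDetails (names invented for the demo):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class GroupingSketch {
        public static void main(String[] args) {
            Map<String, List<String>> combinedExplain = new HashMap<>();
            for (String detail : List.of("shard_0:doc1", "shard_0:doc2", "shard_1:doc7")) {
                String shard = detail.split(":")[0];
                // computeIfAbsent creates the per-shard list on first access,
                // so no explicit containsKey check is needed before add().
                combinedExplain.computeIfAbsent(shard, k -> new ArrayList<>()).add(detail);
            }
            // e.g. {shard_1=[shard_1:doc7], shard_0=[shard_0:doc1, shard_0:doc2]}; map order may vary
            System.out.println(combinedExplain);
        }
    }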
---------- changed file ----------
@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on arithmetic mean method
---------- changed file ----------
@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on geometrical mean method
---------- changed file ----------
@@ -12,7 +12,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
 
 /**
  * Abstracts combination of scores based on harmonic mean method
---------- changed file ----------
@@ -27,9 +27,9 @@
 
 import lombok.extern.log4j.Log4j2;
 import org.opensearch.neuralsearch.processor.SearchShard;
-import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument;
 
 /**
  * Abstracts combination of scores in query search results.
@@ -318,40 +318,47 @@ private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final lon
         return new TotalHits(maxHits, totalHits);
     }
 
-    public Map<SearchShard, List<ExplainDetails>> explain(
+    /**
+     * Explain the score combination technique for each document in the given queryTopDocs.
+     * @param queryTopDocs
+     * @param combinationTechnique
+     * @param sort
+     * @return a map of SearchShard and List of ExplainationDetails for each document
+     */
+    public Map<SearchShard, List<ExplanationDetails>> explain(
         final List<CompoundTopDocs> queryTopDocs,
         final ScoreCombinationTechnique combinationTechnique,
         final Sort sort
     ) {
         // In case of duplicate keys, keep the first value
-        HashMap<SearchShard, List<ExplainDetails>> explanations = new HashMap<>();
+        Map<SearchShard, List<ExplanationDetails>> explanations = new HashMap<>();
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
-            for (Map.Entry<SearchShard, List<ExplainDetails>> docIdAtSearchShardExplainDetailsEntry : explainByShard(
-                combinationTechnique,
-                compoundQueryTopDocs,
-                sort
-            ).entrySet()) {
-                explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue());
-            }
+            explanations.putIfAbsent(
+                compoundQueryTopDocs.getSearchShard(),
+                explainByShard(combinationTechnique, compoundQueryTopDocs, sort)
+            );
         }
         return explanations;
     }
 
-    private Map<SearchShard, List<ExplainDetails>> explainByShard(
+    private List<ExplanationDetails> explainByShard(
         final ScoreCombinationTechnique scoreCombinationTechnique,
         final CompoundTopDocs compoundQueryTopDocs,
-        Sort sort
+        final Sort sort
     ) {
         if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
-            return Map.of();
+            return List.of();
         }
-        // - create map of normalized scores results returned from the single shard
+        // create map of normalized scores results returned from the single shard
         Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs());
         // combine scores
         Map<Integer, Float> combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet()
             .stream()
             .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
         // sort combined scores as per sorting criteria - either score desc or field sorting
         Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
-        List<ExplainDetails> listOfExplainsForShard = sortedDocsIds.stream()
+
+        List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
            .map(
                docId -> getScoreCombinationExplainDetailsForDocument(
                    docId,
@@ -360,13 +367,13 @@ private Map<SearchShard, List<ExplainDetails>> explainByShard(
                 )
             )
             .toList();
-        return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard);
+        return listOfExplanations;
     }
 
     private Collection<Integer> getSortedDocsIds(
-        CompoundTopDocs compoundQueryTopDocs,
-        Sort sort,
-        Map<Integer, Float> combinedNormalizedScoresByDocId
+        final CompoundTopDocs compoundQueryTopDocs,
+        final Sort sort,
+        final Map<Integer, Float> combinedNormalizedScoresByDocId
     ) {
         Collection<Integer> sortedDocsIds;
         if (sort != null) {
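Note: explainByShard now returns a plain per-shard List, and the caller keys it once by the shard, which is why the old entrySet loop disappears. The putIfAbsent call is what the "keep the first value" comment refers to; a JDK-only sketch of that behavior:

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class PutIfAbsentSketch {
        public static void main(String[] args) {
            Map<String, List<String>> explanations = new HashMap<>();
            // First write wins; the second call is a no-op because the key is already mapped.
            explanations.putIfAbsent("shard_0", List.of("first explanation"));
            explanations.putIfAbsent("shard_0", List.of("second explanation"));
            System.out.println(explanations.get("shard_0")); // [first explanation]
        }
    }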
---------- changed file ----------
@@ -15,6 +15,6 @@
 @Builder
 @Getter
 public class CombinedExplainDetails {
-    private ExplainDetails normalizationExplain;
-    private ExplainDetails combinationExplain;
+    private ExplanationDetails normalizationExplain;
+    private ExplanationDetails combinationExplain;
 }
---------- changed file ----------
@@ -28,7 +28,7 @@ default String describe() {
      * @param queryTopDocs collection of CompoundTopDocs for each shard result
      * @return map of document per shard and corresponding explanation object
      */
-    default Map<DocIdAtSearchShard, ExplainDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
+    default Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
         return Map.of();
     }
 }
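Note: only the default method's signature changes here; the interface pattern itself is that techniques inherit a safe empty explain() and override it only when they support explainability. A simplified, hypothetical sketch of that pattern (String and Float stand in for the real key and detail types, and the implementing class is invented for the demo):

    import java.util.List;
    import java.util.Map;

    public class TechniqueSketch {
        interface ExplainableTechnique {
            default String describe() {
                return "generic technique";
            }

            default Map<String, Float> explain(List<String> queryTopDocs) {
                return Map.of(); // safe empty default for techniques without explainability
            }
        }

        // A technique only has to override what it actually supports.
        static class MinMaxLikeTechnique implements ExplainableTechnique {
            @Override
            public String describe() {
                return "hypothetical min-max style technique";
            }
        }

        public static void main(String[] args) {
            ExplainableTechnique technique = new MinMaxLikeTechnique();
            System.out.println(technique.describe());              // overridden
            System.out.println(technique.explain(List.of("doc"))); // {} from the default
        }
    }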
---------- changed file ----------
@@ -17,24 +17,24 @@
 /**
  * Utility class for explain functionality
  */
-public class ExplainUtils {
+public class ExplainationUtils {
 
     /**
      * Creates map of DocIdAtQueryPhase to String containing source and normalized scores
      * @param normalizedScores map of DocIdAtQueryPhase to normalized scores
      * @param sourceScores map of DocIdAtQueryPhase to source scores
      * @return map of DocIdAtQueryPhase to String containing source and normalized scores
      */
-    public static Map<DocIdAtSearchShard, ExplainDetails> getDocIdAtQueryForNormalization(
+    public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNormalization(
         final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
         final Map<DocIdAtSearchShard, List<Float>> sourceScores
     ) {
-        Map<DocIdAtSearchShard, ExplainDetails> explain = sourceScores.entrySet()
+        Map<DocIdAtSearchShard, ExplanationDetails> explain = sourceScores.entrySet()
             .stream()
             .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
                 List<Float> srcScores = entry.getValue();
                 List<Float> normScores = normalizedScores.get(entry.getKey());
-                return new ExplainDetails(
+                return new ExplanationDetails(
                     normScores.stream().reduce(0.0f, Float::max),
                     String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores)
                 );
@@ -49,13 +49,13 @@ public static Map<DocIdAtSearchShard, ExplainDetails> getDocIdAtQueryForNormaliz
      * @param normalizedScoresPerDoc
      * @return
      */
-    public static ExplainDetails getScoreCombinationExplainDetailsForDocument(
+    public static ExplanationDetails getScoreCombinationExplainDetailsForDocument(
         final Integer docId,
         final Map<Integer, Float> combinedNormalizedScoresByDocId,
         final float[] normalizedScoresPerDoc
     ) {
         float combinedScore = combinedNormalizedScoresByDocId.get(docId);
-        return new ExplainDetails(
+        return new ExplanationDetails(
             docId,
             combinedScore,
             String.format(
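Note: the normalization detail keeps the largest normalized score as its value and records both score lists in the description. A small sketch of the two JDK idioms used above, with invented sample numbers (the 0.0f identity for reduce assumes scores are non-negative):

    import java.util.List;
    import java.util.Locale;

    public class ReduceSketch {
        public static void main(String[] args) {
            List<Float> srcScores = List.of(12.4f, 7.9f);
            List<Float> normScores = List.of(1.0f, 0.64f);
            // reduce(0.0f, Float::max) picks the largest normalized score.
            float best = normScores.stream().reduce(0.0f, Float::max);
            System.out.println(best); // 1.0
            System.out.println(
                String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores)
            );
        }
    }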
---------- changed file ----------
@@ -10,8 +10,8 @@
  * @param value
  * @param description
  */
-public record ExplainDetails(int docId, float value, String description) {
-    public ExplainDetails(float value, String description) {
+public record ExplanationDetails(int docId, float value, String description) {
+    public ExplanationDetails(float value, String description) {
         this(-1, value, description);
     }
 }
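Note: the record keeps its shape under the new name. Rebuilt in isolation below, the two-argument convenience constructor delegates to the canonical one with -1 as a "no document id" sentinel, which is why combination details carry a real docId while normalization details do not:

    public class RecordSketch {
        public record ExplanationDetails(int docId, float value, String description) {
            public ExplanationDetails(float value, String description) {
                this(-1, value, description);
            }
        }

        public static void main(String[] args) {
            ExplanationDetails normalization = new ExplanationDetails(0.75f, "normalized score");
            System.out.println(normalization.docId()); // -1, i.e. not tied to a document
        }
    }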
---------- changed file ----------
@@ -17,11 +17,11 @@
 @AllArgsConstructor
 @Builder
 @Getter
-public class ExplanationResponse {
-    Explanation explanation;
-    Map<ExplanationType, Object> explainPayload;
+public class ExplanationPayload {
+    private final Explanation explanation;
+    private final Map<PayloadType, Object> explainPayload;
 
-    public enum ExplanationType {
+    public enum PayloadType {
         NORMALIZATION_PROCESSOR
     }
 }
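Note: the fields are now private final, so the class is immutable once built. Hand-expanding roughly what Lombok's @Builder and @Getter generate makes the call sites in the workflow hunk easier to follow; this is a sketch of standard Lombok semantics, not the actual generated source, and Lucene's Explanation is simplified to a String:

    import java.util.Map;

    public class BuilderSketch {
        public enum PayloadType { NORMALIZATION_PROCESSOR }

        public static final class ExplanationPayload {
            private final String explanation;
            private final Map<PayloadType, Object> explainPayload;

            private ExplanationPayload(String explanation, Map<PayloadType, Object> explainPayload) {
                this.explanation = explanation;
                this.explainPayload = explainPayload;
            }

            public static Builder builder() { return new Builder(); }

            public String getExplanation() { return explanation; }
            public Map<PayloadType, Object> getExplainPayload() { return explainPayload; }

            public static final class Builder {
                private String explanation;
                private Map<PayloadType, Object> explainPayload;

                public Builder explanation(String explanation) { this.explanation = explanation; return this; }
                public Builder explainPayload(Map<PayloadType, Object> payload) { this.explainPayload = payload; return this; }
                public ExplanationPayload build() { return new ExplanationPayload(explanation, explainPayload); }
            }
        }

        public static void main(String[] args) {
            ExplanationPayload p = ExplanationPayload.builder()
                .explanation("combined score explanation")
                .explainPayload(Map.of(PayloadType.NORMALIZATION_PROCESSOR, "details"))
                .build();
            System.out.println(p.getExplainPayload().containsKey(PayloadType.NORMALIZATION_PROCESSOR)); // true
        }
    }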
---------- changed file ----------
@@ -17,10 +17,10 @@
 
 import lombok.ToString;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
-import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 
-import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization;
+import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization;
 
 /**
  * Abstracts normalization of scores based on L2 method
@@ -64,7 +64,7 @@ public String describe() {
     }
 
     @Override
-    public Map<DocIdAtSearchShard, ExplainDetails> explain(List<CompoundTopDocs> queryTopDocs) {
+    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
         Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
         Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();
         List<Float> normsPerSubquery = getL2Norm(queryTopDocs);
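Note: per the class comment, scores are normalized by the L2 method; presumably getL2Norm computes, per sub-query, norm = sqrt(sum of squared scores), and each source score is then divided by that norm. A sketch of that formula under this assumption, not the plugin's exact code:

    public class L2Sketch {
        public static void main(String[] args) {
            float[] scores = { 3.0f, 4.0f };
            double sumOfSquares = 0.0;
            for (float s : scores) {
                sumOfSquares += s * s;
            }
            double l2Norm = Math.sqrt(sumOfSquares); // 5.0 for {3, 4}
            for (float s : scores) {
                // prints 3.0 -> 0.60 and 4.0 -> 0.80; the normalized vector has unit length
                System.out.printf("source score: %.1f normalized to: %.2f%n", s, s / l2Norm);
            }
        }
    }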
(diffs for the remaining changed files were not loaded in this view)
