diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc16369c..180989f05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195)) ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) ### Bug Fixes diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index d9a009a87..6051f71e0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -11,6 +11,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -42,14 +43,14 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech public static final String TECHNIQUE_NAME = "min_max"; protected static final float MIN_SCORE = 0.001f; private static final float SINGLE_RESULT_SCORE = 1.0f; - private final List> lowerBounds; + private final Optional>> lowerBoundsOptional; public MinMaxScoreNormalizationTechnique() { this(Map.of()); } public MinMaxScoreNormalizationTechnique(final Map params) { - lowerBounds = getLowerBounds(params); + lowerBoundsOptional = getLowerBounds(params); } /** @@ -69,8 +70,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { continue; } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - if (Objects.nonNull(lowerBounds) && !lowerBounds.isEmpty() && lowerBounds.size() != topDocsPerSubQuery.size()) { - throw new IllegalArgumentException("lower bounds size should be same as number of sub queries"); + if (isLowerBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "expected lower bounds array to contain %d elements matching the number of sub-queries, but found a mismatch", + topDocsPerSubQuery.size() + ) + ); } for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); @@ -87,14 +94,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { } } - private LowerBound getLowerBound(int j) { - LowerBound lowerBound; - if (Objects.isNull(lowerBounds) || lowerBounds.isEmpty()) { - lowerBound = new LowerBound(); - } else { - lowerBound = new LowerBound(true, lowerBounds.get(j).getLeft(), lowerBounds.get(j).getRight()); - } - return lowerBound; + private boolean isLowerBoundsAndSubQueriesCountMismatched(List topDocsPerSubQuery) { + return lowerBoundsOptional.isPresent() && lowerBoundsOptional.get().size() != topDocsPerSubQuery.size(); + } + + private LowerBound getLowerBound(int subQueryIndex) { + return lowerBoundsOptional.map( + pairs -> new LowerBound(true, pairs.get(subQueryIndex).getLeft(), pairs.get(subQueryIndex).getRight()) + ).orElseGet(LowerBound::new); } private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { @@ -108,7 +115,12 @@ private MinMaxScores getMinMaxScoresResult(final List queryTopD @Override public String describe() { - return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME); + return lowerBoundsOptional.map(lb -> { + String lowerBounds = lb.stream() + .map(pair -> String.format(Locale.ROOT, "(%s, %s)", pair.getLeft(), pair.getRight())) + .collect(Collectors.joining(", ", "[", "]")); + return String.format(Locale.ROOT, "%s, lower bounds %s", TECHNIQUE_NAME, lowerBounds); + }).orElse(String.format(Locale.ROOT, "%s", TECHNIQUE_NAME)); } @Override @@ -226,14 +238,14 @@ private class MinMaxScores { float[] maxScoresPerSubquery; } - private List> getLowerBounds(final Map params) { - List> lowerBounds = new ArrayList<>(); - + private Optional>> getLowerBounds(final Map params) { // Early return if params is null or doesn't contain lower_bounds if (Objects.isNull(params) || !params.containsKey("lower_bounds")) { - return lowerBounds; + return Optional.empty(); } + List> lowerBounds = new ArrayList<>(); + Object lowerBoundsObj = params.get("lower_bounds"); if (!(lowerBoundsObj instanceof List lowerBoundsParams)) { throw new IllegalArgumentException("lower_bounds must be a List"); @@ -259,8 +271,13 @@ private List> getLowerBounds(final Map params) Map lowerBound = (Map) boundObj; try { - Mode mode = Mode.fromString(lowerBound.get("mode").toString()); - float minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score"))); + Mode mode = Mode.fromString(Objects.isNull(lowerBound.get("mode")) ? "" : lowerBound.get("mode").toString()); + float minScore; + if (Objects.isNull(lowerBound.get("min_score"))) { + minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE; + } else { + minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score"))); + } Validate.isTrue( minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, @@ -275,7 +292,7 @@ private List> getLowerBounds(final Map params) } } - return lowerBounds; + return Optional.of(lowerBounds); } /** @@ -337,10 +354,12 @@ public float normalize(float score, float minScore, float maxScore, float lowerB .collect(Collectors.joining(", ")); public static Mode fromString(String value) { - if (value == null || value.trim().isEmpty()) { + if (Objects.isNull(value)) { throw new IllegalArgumentException("mode value cannot be null or empty"); } - + if (value.trim().isEmpty()) { + return DEFAULT; + } try { return valueOf(value.toUpperCase(Locale.ROOT)); } catch (IllegalArgumentException e) { @@ -351,5 +370,10 @@ public static Mode fromString(String value) { } public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index 840c19394..1f86b84ef 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -28,6 +28,7 @@ import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE; import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; /** * Abstracts normalization of scores based on min-max method @@ -290,18 +291,12 @@ public void testMode_fromString_invalidValues() { assertEquals("invalid mode: invalid, valid values are: apply, clip, ignore", exception.getMessage()); } - public void testMode_fromString_nullOrEmpty() { + public void testLowerBoundsModeFromString_whenNullOrEmpty_thenFail() { IllegalArgumentException nullException = expectThrows( IllegalArgumentException.class, () -> MinMaxScoreNormalizationTechnique.Mode.fromString(null) ); assertEquals("mode value cannot be null or empty", nullException.getMessage()); - - IllegalArgumentException emptyException = expectThrows( - IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString("") - ); - assertEquals("mode value cannot be null or empty", emptyException.getMessage()); } public void testMode_normalize_apply() { @@ -349,11 +344,11 @@ public void testMode_normalize_ignore() { assertEquals(MIN_SCORE, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); } - public void testMode_defaultValue() { + public void testLowerBoundsMode_whenDefaultValue_thenSuccessful() { assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT); } - public void testLowerBoundsExceedsMaxSubQueries() { + public void testLowerBounds_whenExceedsMaxSubQueries_thenFail() { List> lowerBounds = new ArrayList<>(); for (int i = 0; i <= 100; i++) { @@ -389,6 +384,105 @@ public void testLowerBoundsExceedsMaxSubQueries() { ); } + public void testDescribe_whenLowerBoundsArePresent_thenSuccessful() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList( + Map.of("mode", "apply", "min_score", 0.2), + + Map.of("mode", "clip", "min_score", 0.1) + ); + parameters.put("lower_bounds", lowerBounds); + MinMaxScoreNormalizationTechnique techniqueWithBounds = new MinMaxScoreNormalizationTechnique(parameters); + assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueWithBounds.describe()); + + // Test case 2: without lower bounds + Map emptyParameters = new HashMap<>(); + MinMaxScoreNormalizationTechnique techniqueWithoutBounds = new MinMaxScoreNormalizationTechnique(emptyParameters); + assertEquals("min_max", techniqueWithoutBounds.describe()); + + Map parametersMissingMode = new HashMap<>(); + List> lowerBoundsMissingMode = Arrays.asList( + Map.of("min_score", 0.2), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersMissingMode.put("lower_bounds", lowerBoundsMissingMode); + MinMaxScoreNormalizationTechnique techniqueMissingMode = new MinMaxScoreNormalizationTechnique(parametersMissingMode); + assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueMissingMode.describe()); + + Map parametersMissingScore = new HashMap<>(); + List> lowerBoundsMissingScore = Arrays.asList( + Map.of("mode", "apply"), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersMissingScore.put("lower_bounds", lowerBoundsMissingScore); + MinMaxScoreNormalizationTechnique techniqueMissingScore = new MinMaxScoreNormalizationTechnique(parametersMissingScore); + assertEquals("min_max, lower bounds [(apply, 0.0), (clip, 0.1)]", techniqueMissingScore.describe()); + } + + public void testLowerBounds_whenInvalidInput_thenFail() { + // Test case 1: Invalid mode value + Map parametersInvalidMode = new HashMap<>(); + List> lowerBoundsInvalidMode = Arrays.asList( + Map.of("mode", "invalid_mode", "min_score", 0.2), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersInvalidMode.put("lower_bounds", lowerBoundsInvalidMode); + IllegalArgumentException invalidModeException = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parametersInvalidMode) + ); + assertEquals("invalid mode: invalid_mode, valid values are: apply, clip, ignore", invalidModeException.getMessage()); + + // Test case 4: Invalid min_score type + Map parametersInvalidScore = new HashMap<>(); + List> lowerBoundsInvalidScore = Arrays.asList( + Map.of("mode", "apply", "min_score", "not_a_number"), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersInvalidScore.put("lower_bounds", lowerBoundsInvalidScore); + IllegalArgumentException invalidScoreException = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parametersInvalidScore) + ); + assertEquals("Invalid format for min_score: must be a valid float value", invalidScoreException.getMessage()); + } + + public void testThrowsException_whenLowerBoundsAndSubQueriesCountMismatch_thenFail() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.1f) }) + ), + false, + SEARCH_SHARD + ) + ); + ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(minMaxTechnique) + .build(); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> minMaxTechnique.normalize(normalizeScoresDTO) + ); + + assertEquals( + "expected lower bounds array to contain 2 elements matching the number of sub-queries, but found a mismatch", + exception.getMessage() + ); + } + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.totalHits.value(), actual.totalHits.value()); assertEquals(expected.totalHits.relation(), actual.totalHits.relation());