Skip to content

Commit

Permalink
Added more validations and unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 25, 2025
1 parent a983cfd commit 89999f9
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -42,14 +43,14 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech
public static final String TECHNIQUE_NAME = "min_max";
protected static final float MIN_SCORE = 0.001f;
private static final float SINGLE_RESULT_SCORE = 1.0f;
private final List<Pair<Mode, Float>> lowerBounds;
private final Optional<List<Pair<Mode, Float>>> lowerBoundsOptional;

public MinMaxScoreNormalizationTechnique() {
this(Map.of());
}

public MinMaxScoreNormalizationTechnique(final Map<String, Object> params) {
lowerBounds = getLowerBounds(params);
lowerBoundsOptional = getLowerBounds(params);
}

/**
Expand All @@ -69,8 +70,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
if (Objects.nonNull(lowerBounds) && !lowerBounds.isEmpty() && lowerBounds.size() != topDocsPerSubQuery.size()) {
throw new IllegalArgumentException("lower bounds size should be same as number of sub queries");
if (isLowerBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"expected lower bounds array to contain %d elements matching the number of sub-queries, but found a mismatch",
topDocsPerSubQuery.size()
)
);
}
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
Expand All @@ -87,14 +94,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
}
}

private LowerBound getLowerBound(int j) {
LowerBound lowerBound;
if (Objects.isNull(lowerBounds) || lowerBounds.isEmpty()) {
lowerBound = new LowerBound();
} else {
lowerBound = new LowerBound(true, lowerBounds.get(j).getLeft(), lowerBounds.get(j).getRight());
}
return lowerBound;
private boolean isLowerBoundsAndSubQueriesCountMismatched(List<TopDocs> topDocsPerSubQuery) {
return lowerBoundsOptional.isPresent() && lowerBoundsOptional.get().size() != topDocsPerSubQuery.size();
}

private LowerBound getLowerBound(int subQueryIndex) {
return lowerBoundsOptional.map(
pairs -> new LowerBound(true, pairs.get(subQueryIndex).getLeft(), pairs.get(subQueryIndex).getRight())
).orElseGet(LowerBound::new);
}

private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopDocs) {
Expand All @@ -108,7 +115,12 @@ private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopD

@Override
public String describe() {
return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
return lowerBoundsOptional.map(lb -> {
String lowerBounds = lb.stream()
.map(pair -> String.format(Locale.ROOT, "(%s, %s)", pair.getLeft(), pair.getRight()))
.collect(Collectors.joining(", ", "[", "]"));
return String.format(Locale.ROOT, "%s, lower bounds %s", TECHNIQUE_NAME, lowerBounds);
}).orElse(String.format(Locale.ROOT, "%s", TECHNIQUE_NAME));
}

@Override
Expand Down Expand Up @@ -226,14 +238,14 @@ private class MinMaxScores {
float[] maxScoresPerSubquery;
}

private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params) {
List<Pair<Mode, Float>> lowerBounds = new ArrayList<>();

private Optional<List<Pair<Mode, Float>>> getLowerBounds(final Map<String, Object> params) {
// Early return if params is null or doesn't contain lower_bounds
if (Objects.isNull(params) || !params.containsKey("lower_bounds")) {
return lowerBounds;
return Optional.empty();
}

List<Pair<Mode, Float>> lowerBounds = new ArrayList<>();

Object lowerBoundsObj = params.get("lower_bounds");
if (!(lowerBoundsObj instanceof List<?> lowerBoundsParams)) {
throw new IllegalArgumentException("lower_bounds must be a List");
Expand All @@ -259,8 +271,13 @@ private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params)
Map<String, Object> lowerBound = (Map<String, Object>) boundObj;

try {
Mode mode = Mode.fromString(lowerBound.get("mode").toString());
float minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score")));
Mode mode = Mode.fromString(Objects.isNull(lowerBound.get("mode")) ? "" : lowerBound.get("mode").toString());
float minScore;
if (Objects.isNull(lowerBound.get("min_score"))) {
minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE;
} else {
minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score")));
}

Validate.isTrue(
minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE,
Expand All @@ -275,7 +292,7 @@ private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params)
}
}

return lowerBounds;
return Optional.of(lowerBounds);
}

/**
Expand Down Expand Up @@ -337,10 +354,12 @@ public float normalize(float score, float minScore, float maxScore, float lowerB
.collect(Collectors.joining(", "));

public static Mode fromString(String value) {
if (value == null || value.trim().isEmpty()) {
if (Objects.isNull(value)) {
throw new IllegalArgumentException("mode value cannot be null or empty");
}

if (value.trim().isEmpty()) {
return DEFAULT;
}
try {
return valueOf(value.toUpperCase(Locale.ROOT));
} catch (IllegalArgumentException e) {
Expand All @@ -351,5 +370,10 @@ public static Mode fromString(String value) {
}

public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore);

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE;
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS;

/**
* Abstracts normalization of scores based on min-max method
Expand Down Expand Up @@ -290,18 +291,12 @@ public void testMode_fromString_invalidValues() {
assertEquals("invalid mode: invalid, valid values are: apply, clip, ignore", exception.getMessage());
}

public void testMode_fromString_nullOrEmpty() {
public void testLowerBoundsModeFromString_whenNullOrEmpty_thenFail() {
IllegalArgumentException nullException = expectThrows(
IllegalArgumentException.class,
() -> MinMaxScoreNormalizationTechnique.Mode.fromString(null)
);
assertEquals("mode value cannot be null or empty", nullException.getMessage());

IllegalArgumentException emptyException = expectThrows(
IllegalArgumentException.class,
() -> MinMaxScoreNormalizationTechnique.Mode.fromString("")
);
assertEquals("mode value cannot be null or empty", emptyException.getMessage());
}

public void testMode_normalize_apply() {
Expand Down Expand Up @@ -349,11 +344,11 @@ public void testMode_normalize_ignore() {
assertEquals(MIN_SCORE, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION);
}

public void testMode_defaultValue() {
public void testLowerBoundsMode_whenDefaultValue_thenSuccessful() {
assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT);
}

public void testLowerBoundsExceedsMaxSubQueries() {
public void testLowerBounds_whenExceedsMaxSubQueries_thenFail() {
List<Map<String, Object>> lowerBounds = new ArrayList<>();

for (int i = 0; i <= 100; i++) {
Expand Down Expand Up @@ -389,6 +384,105 @@ public void testLowerBoundsExceedsMaxSubQueries() {
);
}

public void testDescribe_whenLowerBoundsArePresent_thenSuccessful() {
Map<String, Object> parameters = new HashMap<>();
List<Map<String, Object>> lowerBounds = Arrays.asList(
Map.of("mode", "apply", "min_score", 0.2),

Map.of("mode", "clip", "min_score", 0.1)
);
parameters.put("lower_bounds", lowerBounds);
MinMaxScoreNormalizationTechnique techniqueWithBounds = new MinMaxScoreNormalizationTechnique(parameters);
assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueWithBounds.describe());

// Test case 2: without lower bounds
Map<String, Object> emptyParameters = new HashMap<>();
MinMaxScoreNormalizationTechnique techniqueWithoutBounds = new MinMaxScoreNormalizationTechnique(emptyParameters);
assertEquals("min_max", techniqueWithoutBounds.describe());

Map<String, Object> parametersMissingMode = new HashMap<>();
List<Map<String, Object>> lowerBoundsMissingMode = Arrays.asList(
Map.of("min_score", 0.2),
Map.of("mode", "clip", "min_score", 0.1)
);
parametersMissingMode.put("lower_bounds", lowerBoundsMissingMode);
MinMaxScoreNormalizationTechnique techniqueMissingMode = new MinMaxScoreNormalizationTechnique(parametersMissingMode);
assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueMissingMode.describe());

Map<String, Object> parametersMissingScore = new HashMap<>();
List<Map<String, Object>> lowerBoundsMissingScore = Arrays.asList(
Map.of("mode", "apply"),
Map.of("mode", "clip", "min_score", 0.1)
);
parametersMissingScore.put("lower_bounds", lowerBoundsMissingScore);
MinMaxScoreNormalizationTechnique techniqueMissingScore = new MinMaxScoreNormalizationTechnique(parametersMissingScore);
assertEquals("min_max, lower bounds [(apply, 0.0), (clip, 0.1)]", techniqueMissingScore.describe());
}

public void testLowerBounds_whenInvalidInput_thenFail() {
// Test case 1: Invalid mode value
Map<String, Object> parametersInvalidMode = new HashMap<>();
List<Map<String, Object>> lowerBoundsInvalidMode = Arrays.asList(
Map.of("mode", "invalid_mode", "min_score", 0.2),
Map.of("mode", "clip", "min_score", 0.1)
);
parametersInvalidMode.put("lower_bounds", lowerBoundsInvalidMode);
IllegalArgumentException invalidModeException = expectThrows(
IllegalArgumentException.class,
() -> new MinMaxScoreNormalizationTechnique(parametersInvalidMode)
);
assertEquals("invalid mode: invalid_mode, valid values are: apply, clip, ignore", invalidModeException.getMessage());

// Test case 4: Invalid min_score type
Map<String, Object> parametersInvalidScore = new HashMap<>();
List<Map<String, Object>> lowerBoundsInvalidScore = Arrays.asList(
Map.of("mode", "apply", "min_score", "not_a_number"),
Map.of("mode", "clip", "min_score", 0.1)
);
parametersInvalidScore.put("lower_bounds", lowerBoundsInvalidScore);
IllegalArgumentException invalidScoreException = expectThrows(
IllegalArgumentException.class,
() -> new MinMaxScoreNormalizationTechnique(parametersInvalidScore)
);
assertEquals("Invalid format for min_score: must be a valid float value", invalidScoreException.getMessage());
}

public void testThrowsException_whenLowerBoundsAndSubQueriesCountMismatch_thenFail() {
Map<String, Object> parameters = new HashMap<>();
List<Map<String, Object>> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1));
parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds);

List<CompoundTopDocs> compoundTopDocs = List.of(
new CompoundTopDocs(
new TotalHits(2, TotalHits.Relation.EQUAL_TO),
List.of(
new TopDocs(
new TotalHits(2, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
),
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.1f) })
),
false,
SEARCH_SHARD
)
);
ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters);
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(compoundTopDocs)
.normalizationTechnique(minMaxTechnique)
.build();

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> minMaxTechnique.normalize(normalizeScoresDTO)
);

assertEquals(
"expected lower bounds array to contain 2 elements matching the number of sub-queries, but found a mismatch",
exception.getMessage()
);
}

private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
assertEquals(expected.totalHits.value(), actual.totalHits.value());
assertEquals(expected.totalHits.relation(), actual.totalHits.relation());
Expand Down

0 comments on commit 89999f9

Please sign in to comment.