Skip to content

Commit

Permalink
Added integ test, adjusted some calculations
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 18, 2025
1 parent 9540799 commit b63f34f
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,38 +215,6 @@ private float normalizeSingleScore(final float score, final float minScore, fina
return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore());
}

private boolean shouldIgnoreLowerBound(LowerBound lowerBound) {
    // Lower-bound handling is skipped when the feature is disabled or the mode is IGNORE.
    boolean disabled = !lowerBound.isEnabled();
    boolean ignoreMode = lowerBound.getMode() == Mode.IGNORE;
    return disabled || ignoreMode;
}

private float normalizeWithoutLowerBound(float score, float minScore, float maxScore) {
    // Plain min-max normalization into [0, 1].
    final float normalized = (score - minScore) / (maxScore - minScore);
    // An exact 0.0 result is replaced with MIN_SCORE so downstream combination never sees a zero.
    if (normalized == 0.0f) {
        return MIN_SCORE;
    }
    return normalized;
}

private float normalizeWithLowerBound(float score, float minScore, float maxScore, LowerBound lowerBound) {
    // Dispatch on the lower-bound mode; anything other than APPLY/CLIP falls back to
    // plain min-max normalization.
    switch (lowerBound.getMode()) {
        case APPLY:
            return normalizeWithApplyMode(score, maxScore, lowerBound);
        case CLIP:
            return normalizeWithClipMode(score, minScore, maxScore, lowerBound);
        default:
            return (score - minScore) / (maxScore - minScore);
    }
}

/**
 * Normalizes a score in APPLY mode: scores at or above the configured lower bound are
 * min-max normalized over [lowerBound.minScore, maxScore]; scores below the bound are
 * scaled by their distance from maxScore instead.
 */
private float normalizeWithApplyMode(float score, float maxScore, LowerBound lowerBound) {
if (score < lowerBound.getMinScore()) {
// NOTE(review): if score == maxScore the denominator is zero (float division yields
// Infinity, not an exception); that can happen when maxScore itself is below the
// lower bound — confirm callers guarantee maxScore >= lowerBound.getMinScore().
return score / (maxScore - score);
}
// denominator is negative when maxScore < lowerBound.getMinScore() — same caveat as above
return (score - lowerBound.getMinScore()) / (maxScore - lowerBound.getMinScore());
}

private float normalizeWithClipMode(float score, float minScore, float maxScore, LowerBound lowerBound) {
    // CLIP mode: scores below the observed minimum are pinned to the lower-bound score
    // before min-max normalizing over [lowerBound.minScore, maxScore].
    // NOTE(review): the clip threshold is the sub-query's observed minScore rather than
    // lowerBound.getMinScore() — verify this matches the intended CLIP semantics.
    final float lowerBoundScore = lowerBound.getMinScore();
    final float range = maxScore - lowerBoundScore;
    final boolean belowObservedMinimum = score < minScore;
    if (belowObservedMinimum) {
        return lowerBoundScore / range;
    }
    return (score - lowerBoundScore) / range;
}

/**
* Result class to hold min and max scores for each sub query
*/
Expand Down Expand Up @@ -302,7 +270,7 @@ private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params)
* Result class to hold lower bound for each sub query
*/
@Getter
private class LowerBound {
private static class LowerBound {
static final float MIN_LOWER_BOUND_SCORE = -10_000f;
static final float MAX_LOWER_BOUND_SCORE = 10_000f;
static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f;
Expand All @@ -326,7 +294,9 @@ protected enum Mode {
APPLY {
@Override
public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
if (score < lowerBoundScore) {
if (maxScore < lowerBoundScore) {
return (score - minScore) / (maxScore - minScore);
} else if (score < lowerBoundScore) {
return score / (maxScore - score);
}
return (score - lowerBoundScore) / (maxScore - lowerBoundScore);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;
Expand Down Expand Up @@ -48,6 +50,8 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT {
private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2";
private static final String SEARCH_PIPELINE = "phase-results-normalization-processor-pipeline";
private static final String SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES = "normalization-processor-with-lower-bounds-two-queries";
private static final String SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES = "normalization-processor-with-lower-bounds-three-queries";
private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
Expand Down Expand Up @@ -239,6 +243,123 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
assertQueryResults(searchResponseAsMapNoMatches, 0, true);
}

/**
 * Integration test for min_max normalization with lower bounds on a three-shard index.
 * Covers three scenarios: a two-sub-query hybrid search with "apply" and "clip" lower
 * bounds, a three-sub-query hybrid search with a partial match, and a hybrid search
 * with no matches at all.
 */
@SneakyThrows
public void testMinMaxLowerBounds_whenMultipleShards_thenSuccessful() {
String modelId = null;
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
modelId = prepareModel();
// pipeline with one lower-bound entry per sub-query: "apply" for the neural query,
// "clip" for the term query
createSearchPipeline(
SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES,
DEFAULT_NORMALIZATION_METHOD,
Map.of(
"lower_bounds",
List.of(
Map.of("mode", "apply", "min_score", Float.toString(0.01f)),
Map.of("mode", "clip", "min_score", Float.toString(0.0f))
)
),
DEFAULT_COMBINATION_METHOD,
Map.of(),
false
);
int totalExpectedDocQty = 6;

NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
.queryText(TEST_DOC_TEXT1)
.modelId(modelId)
.k(6)
.build();

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(neuralQueryBuilder);
hybridQueryBuilder.add(termQueryBuilder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
6,
Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES)
);

assertNotNull(searchResponseAsMap);
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(totalExpectedDocQty, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(Range.between(.5f, 1.0f).contains(getMaxScore(searchResponseAsMap).get()));
List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}
// verify scores are in non-increasing (descending) order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));

// verify the scores are normalized. We need range-based assertions rather than exact
// values because the neural sub-query is based on random vectors and returns results for
// every doc; in some cases that may pull the term query's 1.0 score lower after combination.
float highestScore = scores.stream().max(Double::compare).get().floatValue();
assertTrue(Range.between(.5f, 1.0f).contains(highestScore));
float lowestScore = scores.stream().min(Double::compare).get().floatValue();
assertTrue(Range.between(.0f, .5f).contains(lowestScore));

// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());

// second pipeline: three lower-bound entries, one per sub-query of the next hybrid query
createSearchPipeline(
SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES,
DEFAULT_NORMALIZATION_METHOD,
Map.of(
"lower_bounds",
List.of(
Map.of("mode", "apply", "min_score", Float.toString(0.01f)),
Map.of("mode", "clip", "min_score", Float.toString(0.0f)),
Map.of("mode", "ignore", "min_score", Float.toString(0.0f))
)
),
DEFAULT_COMBINATION_METHOD,
Map.of(),
false
);

// verify the case when there is a partial match
HybridQueryBuilder hybridQueryBuilderPartialMatch = new HybridQueryBuilder();
hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));

Map<String, Object> searchResponseAsMapPartialMatch = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilderPartialMatch,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES)
);
assertQueryResults(searchResponseAsMapPartialMatch, 4, false, Range.between(0.33f, 1.0f));

// verify the case when the query doesn't have a match; reuses the two-entry pipeline
// since this hybrid query has two sub-queries
HybridQueryBuilder hybridQueryBuilderNoMatches = new HybridQueryBuilder();
hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6));
hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));

Map<String, Object> searchResponseAsMapNoMatches = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilderNoMatches,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES)
);
assertQueryResults(searchResponseAsMapNoMatches, 0, true);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ protected boolean preserveClusterUponCompletion() {
public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() {
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
// create search pipeline with both normalization processor and explain response processor
createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
createSearchPipeline(
NORMALIZATION_SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
Map.of(),
DEFAULT_COMBINATION_METHOD,
Map.of(),
true
);

TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
Expand Down Expand Up @@ -195,6 +202,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()
createSearchPipeline(
NORMALIZATION_SEARCH_PIPELINE,
NORMALIZATION_TECHNIQUE_L2,
Map.of(),
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })),
true
Expand Down Expand Up @@ -324,7 +332,14 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()
public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenResponseHasQueryExplanations() {
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
// create search pipeline with normalization processor, no explanation response processor
createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false);
createSearchPipeline(
NORMALIZATION_SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
Map.of(),
DEFAULT_COMBINATION_METHOD,
Map.of(),
false
);

TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
Expand Down Expand Up @@ -472,7 +487,14 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe
public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
// create search pipeline with both normalization processor and explain response processor
createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
createSearchPipeline(
NORMALIZATION_SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
Map.of(),
DEFAULT_COMBINATION_METHOD,
Map.of(),
true
);

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand Down Expand Up @@ -526,7 +548,14 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() {
initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
// create search pipeline with both normalization processor and explain response processor
createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
createSearchPipeline(
NORMALIZATION_SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
Map.of(),
DEFAULT_COMBINATION_METHOD,
Map.of(),
true
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() {
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);

initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SHARDS_COUNT_IN_MULTI_NODE_CLUSTER);
createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, Map.of(), DEFAULT_COMBINATION_METHOD, Map.of(), true);
// Assert
// scores for search hits
HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import static org.opensearch.neuralsearch.util.TestUtils.ML_PLUGIN_SYSTEM_INDEX_PREFIX;
import static org.opensearch.neuralsearch.util.TestUtils.OPENDISTRO_SECURITY;
import static org.opensearch.neuralsearch.util.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.util.TestUtils.MAX_RETRY;
import static org.opensearch.neuralsearch.util.TestUtils.MAX_TIME_OUT_INTERVAL;
Expand Down Expand Up @@ -1217,13 +1218,14 @@ protected void createSearchPipeline(
String combinationMethod,
final Map<String, String> combinationParams
) {
createSearchPipeline(pipelineId, normalizationMethod, combinationMethod, combinationParams, false);
createSearchPipeline(pipelineId, normalizationMethod, Map.of(), combinationMethod, combinationParams, false);
}

@SneakyThrows
protected void createSearchPipeline(
final String pipelineId,
final String normalizationMethod,
final Map<String, Object> normalizationParams,
final String combinationMethod,
final Map<String, String> combinationParams,
boolean addExplainResponseProcessor
Expand All @@ -1235,10 +1237,32 @@ protected void createSearchPipeline(
.append(NormalizationProcessor.TYPE)
.append("\": {")
.append("\"normalization\": {")
.append("\"technique\": \"%s\"")
.append("},")
.append("\"combination\": {")
.append("\"technique\": \"%s\"");
if (Objects.nonNull(normalizationParams) && !normalizationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (normalizationParams.containsKey(PARAM_NAME_LOWER_BOUNDS)) {
stringBuilderForContentBody.append("\"lower_bounds\": [");
List<Map> lowerBounds = (List) normalizationParams.get(PARAM_NAME_LOWER_BOUNDS);
for (int i = 0; i < lowerBounds.size(); i++) {
Map<String, String> lowerBound = lowerBounds.get(i);
stringBuilderForContentBody.append("{ ")
.append("\"mode\"")
.append(": \"")
.append(lowerBound.get("mode"))
.append("\",")
.append("\"min_score\"")
.append(": ")
.append(lowerBound.get("min_score"))
.append(" }");
if (i < lowerBounds.size() - 1) {
stringBuilderForContentBody.append(", ");
}
}
stringBuilderForContentBody.append("]");
}
stringBuilderForContentBody.append(" }");
}
stringBuilderForContentBody.append("},").append("\"combination\": {").append("\"technique\": \"%s\"");
if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public class TestUtils {
public static final String DEFAULT_NORMALIZATION_METHOD = "min_max";
public static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
public static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds";
public static final String SPARSE_ENCODING_PROCESSOR = "sparse_encoding";
public static final int MAX_TIME_OUT_INTERVAL = 3000;
public static final int MAX_RETRY = 5;
Expand Down

0 comments on commit b63f34f

Please sign in to comment.