From de773ff2fafc047ca3cf85dbe7de76f440ee3957 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 26 Jul 2023 17:20:34 -0700 Subject: [PATCH] Adding weights param for combination technique Signed-off-by: Martin Gaievski --- ...ithmeticMeanScoreCombinationTechnique.java | 51 +++++- .../combination/ScoreCombinationFactory.java | 19 ++- .../NormalizationProcessorFactory.java | 9 +- .../common/BaseNeuralSearchIT.java | 43 ++--- .../ScoreNormalizationCombinationIT.java | 114 +++++++++++++ .../NormalizationProcessorFactoryTests.java | 160 ++++++++++++++++-- 6 files changed, 356 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 10c5533f5..4ff3670b2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -5,16 +5,34 @@ package org.opensearch.neuralsearch.processor.combination; -import lombok.NoArgsConstructor; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import org.opensearch.OpenSearchParseException; /** * Abstracts combination of scores based on arithmetic mean method */ -@NoArgsConstructor public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { public static final String TECHNIQUE_NAME = "arithmetic_mean"; + public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Float ZERO_SCORE = 0.0f; + private final List weights; + + public ArithmeticMeanScoreCombinationTechnique(final Map params) { + validateParams(params); + if (Objects.isNull(params) || params.isEmpty()) { + weights = List.of(); + return; + } + // get weights, we don't need to check for instance as it's done during validation + weights = (List) params.getOrDefault(PARAM_NAME_WEIGHTS, new ArrayList<>()); + } /** * Arithmetic mean method for combining scores. @@ -26,8 +44,13 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination public float combine(final float[] scores) { float combinedScore = 0.0f; int count = 0; - for (float score : scores) { + for (int i = 0; i < scores.length; i++) { + float score = scores[i]; if (score >= 0.0) { + // apply weight for this sub-query if it's set for particular sub-query + if (i < weights.size()) { + score = (float) (score * weights.get(i)); + } combinedScore += score; count++; } @@ -37,4 +60,26 @@ public float combine(final float[] scores) { } return combinedScore / count; } + + private void validateParams(final Map params) { + if (Objects.isNull(params) || params.isEmpty()) { + return; + } + // check if only supported params are passed + Set supportedParams = Set.of(PARAM_NAME_WEIGHTS); + Optional optionalNotSupportedParam = params.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new OpenSearchParseException("provided parameter for combination technique is not supported"); + } + + // check param types + if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new OpenSearchParseException("parameter {} must be a collection of numbers", PARAM_NAME_WEIGHTS); + } + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index bf55a8cc5..db4721ddd 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -7,6 +7,7 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Function; import org.opensearch.OpenSearchParseException; @@ -15,11 +16,11 @@ */ public class ScoreCombinationFactory { - public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(); + public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of()); - private final Map scoreCombinationMethodsMap = Map.of( + private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, - new ArithmeticMeanScoreCombinationTechnique() + ArithmeticMeanScoreCombinationTechnique::new ); /** @@ -28,7 +29,17 @@ public class ScoreCombinationFactory { * @return instance of ScoreCombinationTechnique for technique name */ public ScoreCombinationTechnique createCombination(final String technique) { + return createCombination(technique, Map.of()); + } + + /** + * Get score combination method by technique name + * @param technique name of technique + * @return instance of ScoreCombinationTechnique for technique name + */ + public ScoreCombinationTechnique createCombination(final String technique, final Map params) { return Optional.ofNullable(scoreCombinationMethodsMap.get(technique)) - .orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported")); + .orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported")) + .apply(params); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java index f31a5c6bc..524c1b105 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.processor.factory; import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; import java.util.Map; import java.util.Objects; @@ -29,6 +30,7 @@ public class NormalizationProcessorFactory implements Processor.Factory combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE); ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD; + Map combinationParams; if (Objects.nonNull(combinationClause)) { - String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, ""); - scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE); + // check for optional combination params + combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams); } return new NormalizationProcessor( diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index bf56ab92d..48c694fb0 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -57,8 +58,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; private static final String DEFAULT_USER_AGENT = "Kibana"; - private static final String NORMALIZATION_METHOD = "min_max"; - private static final String COMBINATION_METHOD = "arithmetic_mean"; + protected static final String NORMALIZATION_METHOD = "min_max"; + protected static final String COMBINATION_METHOD = "arithmetic_mean"; + protected static final String PARAM_NAME_WEIGHTS = "weights"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -566,30 +568,31 @@ protected void createSearchPipeline( final String pipelineId, final String normalizationMethod, String combinationMethod, - final Map combinationParams + final Map combinationParams ) { + StringBuilder stringBuilderForContentBody = new StringBuilder(); + stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",") + .append("\"phase_results_processors\": [{ ") + .append("\"normalization-processor\": {") + .append("\"normalization\": {") + .append("\"technique\": \"%s\"") + .append("},") + .append("\"combination\": {") + .append("\"technique\": \"%s\""); + if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) { + stringBuilderForContentBody.append(", \"parameters\": {"); + if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) { + stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS)); + } + stringBuilderForContentBody.append(" }"); + } + stringBuilderForContentBody.append("}").append("}}]}"); makeRequest( client(), "PUT", String.format(LOCALE, "/_search/pipeline/%s", pipelineId), null, - toHttpEntity( - String.format( - LOCALE, - "{\"description\": \"Post processor pipeline\"," - + "\"phase_results_processors\": [{ " - + "\"normalization-processor\": {" - + "\"normalization\": {" - + "\"technique\": \"%s\"" - + "}," - + "\"combination\": {" - + "\"technique\": \"%s\"" - + "}" - + "}}]}", - normalizationMethod, - combinationMethod - ) - ), + toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java index d72ee6f7f..795ed1e5f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -250,6 +251,92 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf assertQueryResults(searchResponseAsMap, 4, true); } + @SneakyThrows + public void testCombinationParams_whenWeightsParamSet_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + // check case when number of weights and sub-queries are same + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_METHOD, + COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f })) + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseWithWeights1AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001); + + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_METHOD, + COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + ); + + Map searchResponseWithWeights2AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001); + + // check case when number of weights is less than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_METHOD, + COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + ); + + Map searchResponseWithWeights3AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001); + + // check case when number of weights is more than number of sub-queries + // delete existing pipeline and create a new one with another set of weights + deleteSearchPipeline(SEARCH_PIPELINE); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_METHOD, + COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + ); + + Map searchResponseWithWeights4AsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { prepareKnnIndex( @@ -402,4 +489,31 @@ private void assertQueryResults(Map searchResponseAsMap, int tot // verify that all ids are unique assertEquals(Set.copyOf(ids).size(), ids.size()); } + + private void assertWeightedScores( + Map searchResponseWithWeightsAsMap, + double expectedMaxScore, + double expectedMaxMinusOneScore, + double expectedMinScore + ) { + assertNotNull(searchResponseWithWeightsAsMap); + Map totalWeights = getTotalHits(searchResponseWithWeightsAsMap); + assertNotNull(totalWeights.get("value")); + assertEquals(4, totalWeights.get("value")); + assertNotNull(totalWeights.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation")); + assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent()); + assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f); + + List scoresWeights = new ArrayList<>(); + for (Map oneHit : getNestedHits(searchResponseWithWeightsAsMap)) { + scoresWeights.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1))); + // verify the scores are normalized with inclusion of weights + assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001); + assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001); + assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index babeed214..0e1aa4cf1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -6,7 +6,12 @@ package org.opensearch.neuralsearch.processor.factory; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.TECHNIQUE; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -24,6 +29,8 @@ import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.test.OpenSearchTestCase; +import com.carrotsearch.randomizedtesting.RandomizedTest; + public class NormalizationProcessorFactoryTests extends OpenSearchTestCase { private static final String NORMALIZATION_METHOD = "min_max"; @@ -68,8 +75,47 @@ public void testNormalizationProcessor_whenWithParams_thenSuccessful() { String description = "description"; boolean ignoreFailure = false; Map config = new HashMap<>(); - config.put("normalization", Map.of("technique", "min_max")); - config.put("combination", Map.of("technique", "arithmetic_mean")); + config.put("normalization", new HashMap<>(Map.of("technique", "min_max"))); + config.put("combination", new HashMap<>(Map.of("technique", "arithmetic_mean"))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); + config.put( + COMBINATION_CLAUSE, + new HashMap<>( + Map.of( + TECHNIQUE, + "arithmetic_mean", + PARAMETERS, + new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomFloat(), RandomizedTest.randomFloat()))) + ) + ) + ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( processorFactories, @@ -85,7 +131,58 @@ public void testNormalizationProcessor_whenWithParams_thenSuccessful() { assertEquals("normalization-processor", normalizationProcessor.getType()); } - public void testInputValidation_whenInvalidParameters_thenFail() { + public void testInputValidation_whenInvalidNormalizationClause_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + Map.of(TECHNIQUE, ""), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + ) + ), + pipelineContext + ) + ); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "random_name_for_normalization")), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)) + ) + ), + pipelineContext + ) + ); + } + + public void testInputValidation_whenInvalidCombinationClause_thenFail() { NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), new ScoreNormalizationFactory(), @@ -107,14 +204,46 @@ public void testInputValidation_whenInvalidParameters_thenFail() { new HashMap<>( Map.of( NormalizationProcessorFactory.NORMALIZATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, ""), + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, "")) + ) + ), + pipelineContext + ) + ); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), NormalizationProcessorFactory.COMBINATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + new HashMap(Map.of(TECHNIQUE, "random_name_for_combination")) ) ), pipelineContext ) ); + } + + public void testInputValidation_whenInvalidCombinationParams_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); expectThrows( OpenSearchParseException.class, @@ -126,9 +255,16 @@ public void testInputValidation_whenInvalidParameters_thenFail() { new HashMap<>( Map.of( NormalizationProcessorFactory.NORMALIZATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, NORMALIZATION_METHOD), + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), NormalizationProcessorFactory.COMBINATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, "") + new HashMap( + Map.of( + TECHNIQUE, + "", + NormalizationProcessorFactory.PARAMETERS, + new HashMap<>(Map.of("weights", "random_string")) + ) + ) ) ), pipelineContext @@ -145,9 +281,9 @@ public void testInputValidation_whenInvalidParameters_thenFail() { new HashMap<>( Map.of( NormalizationProcessorFactory.NORMALIZATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, "random_name_for_normalization"), + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), NormalizationProcessorFactory.COMBINATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + new HashMap(Map.of(TECHNIQUE, "", NormalizationProcessorFactory.PARAMETERS, new HashMap<>(Map.of("weights", 5.0)))) ) ), pipelineContext @@ -164,9 +300,11 @@ public void testInputValidation_whenInvalidParameters_thenFail() { new HashMap<>( Map.of( NormalizationProcessorFactory.NORMALIZATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, NORMALIZATION_METHOD), + new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)), NormalizationProcessorFactory.COMBINATION_CLAUSE, - Map.of(NormalizationProcessorFactory.TECHNIQUE, "random_name_for_combination") + new HashMap( + Map.of(TECHNIQUE, "", NormalizationProcessorFactory.PARAMETERS, new HashMap<>(Map.of("random_param", "value"))) + ) ) ), pipelineContext