Skip to content

Commit

Permalink
Use defaults in case technique namer not provided
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 27, 2023
1 parent e320ed3 commit 376af08
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,6 @@ private void validateParams(final Map<String, Object> params) {
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery).floatValue() : 1.0f;
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;

import java.util.Map;
import java.util.Objects;

import lombok.AllArgsConstructor;

import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.pipeline.Processor;
Expand Down Expand Up @@ -49,22 +50,30 @@ public SearchPhaseResultsProcessor create(
Map<String, Object> normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE);
ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD;
if (Objects.nonNull(normalizationClause)) {
String normalizationTechniqueName = (String) normalizationClause.getOrDefault(TECHNIQUE, "");
String normalizationTechniqueName = readStringProperty(
NormalizationProcessor.TYPE,
tag,
normalizationClause,
TECHNIQUE,
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME
);
normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName);
}

Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE);
String combinationTechnique = readStringProperty(
NormalizationProcessor.TYPE,
tag,
combinationClause,
TECHNIQUE,
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME
);
// check for optional combination params
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
try {
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
} catch (IllegalArgumentException illegalArgumentException) {
throw new OpenSearchParseException(illegalArgumentException.getMessage(), illegalArgumentException);
}
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
*/
public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique {

protected static final String TECHNIQUE_NAME = "min_max";
public static final String TECHNIQUE_NAME = "min_max";
private static final float MIN_SCORE = 0.001f;
private static final float SINGLE_RESULT_SCORE = 1.0f;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts creation of exact score normalization method based on technique name
*/
Expand All @@ -29,6 +27,6 @@ public class ScoreNormalizationFactory {
*/
public ScoreNormalizationTechnique createNormalization(final String technique) {
return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided normalization technique is not supported"));
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import lombok.SneakyThrows;

import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
Expand Down Expand Up @@ -64,6 +63,35 @@ public void testNormalizationProcessor_whenNoParams_thenSuccessful() {
assertEquals("normalization-processor", normalizationProcessor.getType());
}

@SneakyThrows
public void testNormalizationProcessor_whenTechniqueNamesNotSet_thenSuccessful() {
NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()),
new ScoreNormalizationFactory(),
new ScoreCombinationFactory()
);
final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>();
String tag = "tag";
String description = "description";
boolean ignoreFailure = false;
Map<String, Object> config = new HashMap<>();
config.put("normalization", new HashMap<>(Map.of()));
config.put("combination", new HashMap<>(Map.of()));
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create(
processorFactories,
tag,
description,
ignoreFailure,
config,
pipelineContext
);
assertNotNull(searchPhaseResultsProcessor);
assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor);
NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor;
assertEquals("normalization-processor", normalizationProcessor.getType());
}

@SneakyThrows
public void testNormalizationProcessor_whenWithParams_thenSuccessful() {
NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory(
Expand Down Expand Up @@ -113,7 +141,7 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful(
TECHNIQUE,
"arithmetic_mean",
PARAMETERS,
new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomFloat(), RandomizedTest.randomFloat())))
new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble())))
)
)
);
Expand Down Expand Up @@ -145,7 +173,7 @@ public void testInputValidation_whenInvalidNormalizationClause_thenFail() {
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

expectThrows(
OpenSearchParseException.class,
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand All @@ -154,17 +182,17 @@ public void testInputValidation_whenInvalidNormalizationClause_thenFail() {
new HashMap<>(
Map.of(
NormalizationProcessorFactory.NORMALIZATION_CLAUSE,
Map.of(TECHNIQUE, ""),
new HashMap(Map.of(TECHNIQUE, "")),
NormalizationProcessorFactory.COMBINATION_CLAUSE,
Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)
new HashMap(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME))
)
),
pipelineContext
)
);

expectThrows(
OpenSearchParseException.class,
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand Down Expand Up @@ -196,7 +224,7 @@ public void testInputValidation_whenInvalidCombinationClause_thenFail() {
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

expectThrows(
OpenSearchParseException.class,
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand All @@ -215,7 +243,7 @@ public void testInputValidation_whenInvalidCombinationClause_thenFail() {
);

expectThrows(
OpenSearchParseException.class,
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand Down Expand Up @@ -246,8 +274,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
boolean ignoreFailure = false;
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

OpenSearchParseException exceptionBadTechnique = expectThrows(
OpenSearchParseException.class,
IllegalArgumentException exceptionBadTechnique = expectThrows(
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand All @@ -273,8 +301,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
);
assertThat(exceptionBadTechnique.getMessage(), containsString("provided combination technique is not supported"));

OpenSearchParseException exceptionInvalidWeights = expectThrows(
OpenSearchParseException.class,
IllegalArgumentException exceptionInvalidWeights = expectThrows(
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand All @@ -300,8 +328,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
);
assertThat(exceptionInvalidWeights.getMessage(), containsString("parameter [weights] must be a collection of numbers"));

OpenSearchParseException exceptionInvalidWeights2 = expectThrows(
OpenSearchParseException.class,
IllegalArgumentException exceptionInvalidWeights2 = expectThrows(
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand All @@ -327,8 +355,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
);
assertThat(exceptionInvalidWeights2.getMessage(), containsString("parameter [weights] must be a collection of numbers"));

OpenSearchParseException exceptionInvalidParam = expectThrows(
OpenSearchParseException.class,
IllegalArgumentException exceptionInvalidParam = expectThrows(
IllegalArgumentException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
tag,
Expand Down

0 comments on commit 376af08

Please sign in to comment.