Skip to content

Commit

Permalink
Address review comments: refactor code, use IllegalArgument exception
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jul 27, 2023
1 parent 96d0b77 commit e320ed3
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.opensearch.OpenSearchParseException;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
Expand All @@ -21,17 +20,23 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Double> weights;
private final List<Float> weights;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
weights = List.of();
return;
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
weights = (List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, new ArrayList<>());
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
}

/**
Expand All @@ -44,13 +49,11 @@ public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params)
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (int i = 0; i < scores.length; i++) {
float score = scores[i];
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
// apply weight for this sub-query if it's set for particular sub-query
if (i < weights.size()) {
score = (float) (score * weights.get(i));
}
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
count++;
}
Expand All @@ -66,20 +69,36 @@ private void validateParams(final Map<String, Object> params) {
return;
}
// check if only supported params are passed
Set<String> supportedParams = Set.of(PARAM_NAME_WEIGHTS);
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !supportedParams.contains(paramName))
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new OpenSearchParseException("provided parameter for combination technique is not supported");
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new OpenSearchParseException("parameter {} must be a collection of numbers", PARAM_NAME_WEIGHTS);
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery).floatValue() : 1.0f;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import java.util.Optional;
import java.util.function.Function;

import org.opensearch.OpenSearchParseException;

/**
* Abstracts creation of exact score combination method based on technique name
*/
Expand All @@ -35,11 +33,12 @@ public ScoreCombinationTechnique createCombination(final String technique) {
/**
* Get score combination method by technique name
* @param technique name of technique
* @param params parameters that combination technique may use
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"))
.orElseThrow(() -> new IllegalArgumentException("provided combination technique is not supported"))
.apply(params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import lombok.AllArgsConstructor;

import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
Expand Down Expand Up @@ -55,12 +56,15 @@ public SearchPhaseResultsProcessor create(
Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
Map<String, Object> combinationParams;
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE);
// check for optional combination params
combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
try {
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
} catch (IllegalArgumentException illegalArgumentException) {
throw new OpenSearchParseException(illegalArgumentException.getMessage(), illegalArgumentException);
}
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.factory;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.COMBINATION_CLAUSE;
import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE;
Expand Down Expand Up @@ -245,7 +246,7 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
boolean ignoreFailure = false;
Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

expectThrows(
OpenSearchParseException exceptionBadTechnique = expectThrows(
OpenSearchParseException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
Expand All @@ -270,8 +271,9 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
pipelineContext
)
);
assertThat(exceptionBadTechnique.getMessage(), containsString("provided combination technique is not supported"));

expectThrows(
OpenSearchParseException exceptionInvalidWeights = expectThrows(
OpenSearchParseException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
Expand All @@ -283,14 +285,22 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
NormalizationProcessorFactory.NORMALIZATION_CLAUSE,
new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)),
NormalizationProcessorFactory.COMBINATION_CLAUSE,
new HashMap(Map.of(TECHNIQUE, "", NormalizationProcessorFactory.PARAMETERS, new HashMap<>(Map.of("weights", 5.0))))
new HashMap(
Map.of(
TECHNIQUE,
COMBINATION_METHOD,
NormalizationProcessorFactory.PARAMETERS,
new HashMap<>(Map.of("weights", 5.0))
)
)
)
),
pipelineContext
)
);
assertThat(exceptionInvalidWeights.getMessage(), containsString("parameter [weights] must be a collection of numbers"));

expectThrows(
OpenSearchParseException exceptionInvalidWeights2 = expectThrows(
OpenSearchParseException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
Expand All @@ -305,7 +315,7 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
new HashMap(
Map.of(
TECHNIQUE,
"",
COMBINATION_METHOD,
NormalizationProcessorFactory.PARAMETERS,
new HashMap<>(Map.of("weights", new Boolean[] { true, false }))
)
Expand All @@ -315,8 +325,9 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
pipelineContext
)
);
assertThat(exceptionInvalidWeights2.getMessage(), containsString("parameter [weights] must be a collection of numbers"));

expectThrows(
OpenSearchParseException exceptionInvalidParam = expectThrows(
OpenSearchParseException.class,
() -> normalizationProcessorFactory.create(
processorFactories,
Expand All @@ -329,12 +340,21 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() {
new HashMap(Map.of(TECHNIQUE, NORMALIZATION_METHOD)),
NormalizationProcessorFactory.COMBINATION_CLAUSE,
new HashMap(
Map.of(TECHNIQUE, "", NormalizationProcessorFactory.PARAMETERS, new HashMap<>(Map.of("random_param", "value")))
Map.of(
TECHNIQUE,
COMBINATION_METHOD,
NormalizationProcessorFactory.PARAMETERS,
new HashMap<>(Map.of("random_param", "value"))
)
)
)
),
pipelineContext
)
);
assertThat(
exceptionInvalidParam.getMessage(),
containsString("provided parameter for combination technique is not supported. supported parameters are [weights]")
);
}
}

0 comments on commit e320ed3

Please sign in to comment.