-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add harmonic mean combination (#238)
* Add harmonic mean combination Signed-off-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
6ad641a
commit 1f67b94
Showing
12 changed files
with
412 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.processor.combination; | ||
|
||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
import java.util.Set; | ||
import java.util.stream.Collectors; | ||
|
||
/** | ||
* Collection of utility methods for score combination technique classes | ||
*/ | ||
class ScoreCombinationUtil { | ||
private static final String PARAM_NAME_WEIGHTS = "weights"; | ||
|
||
/** | ||
* Get collection of weights based on user provided config | ||
* @param params map of named parameters and their values | ||
* @return collection of weights | ||
*/ | ||
public List<Float> getWeights(final Map<String, Object> params) { | ||
if (Objects.isNull(params) || params.isEmpty()) { | ||
return List.of(); | ||
} | ||
// get weights, we don't need to check for instance as it's done during validation | ||
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream() | ||
.map(Double::floatValue) | ||
.collect(Collectors.toUnmodifiableList()); | ||
} | ||
|
||
/** | ||
* Validate config parameters for this technique | ||
* @param actualParams map of parameters in form of name-value | ||
* @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique | ||
*/ | ||
public void validateParams(final Map<String, Object> actualParams, final Set<String> supportedParams) { | ||
if (Objects.isNull(actualParams) || actualParams.isEmpty()) { | ||
return; | ||
} | ||
// check if only supported params are passed | ||
Optional<String> optionalNotSupportedParam = actualParams.keySet() | ||
.stream() | ||
.filter(paramName -> !supportedParams.contains(paramName)) | ||
.findFirst(); | ||
if (optionalNotSupportedParam.isPresent()) { | ||
throw new IllegalArgumentException( | ||
String.format( | ||
Locale.ROOT, | ||
"provided parameter for combination technique is not supported. supported parameters are [%s]", | ||
supportedParams.stream().collect(Collectors.joining(",")) | ||
) | ||
); | ||
} | ||
|
||
// check param types | ||
if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { | ||
if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { | ||
throw new IllegalArgumentException( | ||
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) | ||
); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise | ||
* @param weights collection of weights for sub-queries | ||
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query | ||
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default | ||
*/ | ||
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) { | ||
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.