diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index a1a7a1601..2daa06c1e 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,12 +11,12 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import org.opensearch.common.CheckedConsumer; -import org.opensearch.common.Nullable; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -46,46 +46,181 @@ @RequiredArgsConstructor @Log4j2 public class MLCommonsClientAccessor { - private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); - private final MachineLearningNodeClient mlClient; - private final Map modelAsymmetryCache = new ConcurrentHashMap<>(); /** - * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating - * point vector as a response. - * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * Inference parameters for calls to the MLCommons client. 
*/ - public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @NonNull final ActionListener> listener - ) { - inferenceSentence(modelId, inputText, null, listener); + public static class InferenceRequest { + + private static final List DEFAULT_TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); + + private final String modelId; + private final List inputTexts; + private final MLAlgoParams mlAlgoParams; + private final List targetResponseFilters; + private final Map inputObjects; + private final String queryText; + + public InferenceRequest( + @NonNull String modelId, + List inputTexts, + MLAlgoParams mlAlgoParams, + List targetResponseFilters, + Map inputObjects, + String queryText + ) { + this.modelId = modelId; + this.inputTexts = inputTexts; + this.mlAlgoParams = mlAlgoParams; + this.targetResponseFilters = targetResponseFilters == null ? DEFAULT_TARGET_RESPONSE_FILTERS : targetResponseFilters; + this.inputObjects = inputObjects; + this.queryText = queryText; + } + + public String getModelId() { + return modelId; + } + + public List getInputTexts() { + return inputTexts; + } + + public MLAlgoParams getMlAlgoParams() { + return mlAlgoParams; + } + + public List getTargetResponseFilters() { + return targetResponseFilters; + } + + public Map getInputObjects() { + return inputObjects; + } + + public String getQueryText() { + return queryText; + } + + /** + * Builder for {@link InferenceRequest}. Supports fluent construction of the request object. 
+ */ + public static class Builder { + + private final String modelId; + private List inputTexts; + private MLAlgoParams mlAlgoParams; + private List targetResponseFilters; + private Map inputObjects; + private String queryText; + + /** + * @param modelId the model id to use for inference + */ + public Builder(String modelId) { + this.modelId = modelId; + } + + /** + * @param inputTexts a {@link List} of input texts to use for inference + * @return this builder + */ + public Builder inputTexts(List inputTexts) { + this.inputTexts = inputTexts; + return this; + } + + /** + * @param inputText an input text to add to the list of input texts. Repeated calls will add + * more input texts. + * @return this builder + */ + public Builder inputText(String inputText) { + if (this.inputTexts != null) { + this.inputTexts.add(inputText); + } else { + this.inputTexts = new ArrayList<>(); + this.inputTexts.add(inputText); + } + return this; + } + + /** + * @param mlAlgoParams the {@link MLAlgoParams} to use for inference. 
+ * @return this builder + */ + public Builder mlAlgoParams(MLAlgoParams mlAlgoParams) { + this.mlAlgoParams = mlAlgoParams; + return this; + } + + /** + * @param targetResponseFilters a {@link List} of target response filters to use for + * inference + * @return this builder + */ + public Builder targetResponseFilters(List targetResponseFilters) { + this.targetResponseFilters = targetResponseFilters; + return this; + } + + /** + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs + * to happen + * @return this builder + */ + public Builder inputObjects(Map inputObjects) { + this.inputObjects = inputObjects; + return this; + } + + /** + * @param queryText the query text to use for similarity inference + * @return this builder + */ + public Builder queryText(String queryText) { + this.queryText = queryText; + return this; + } + + /** + * @return a new {@link InferenceRequest} object with the parameters set in this builder + */ + public InferenceRequest build() { + return new InferenceRequest(modelId, inputTexts, mlAlgoParams, targetResponseFilters, inputObjects, queryText); + } + + } } + private final MachineLearningNodeClient mlClient; + private final Cache modelAsymmetryCache = CacheBuilder.builder().setMaximumWeight(10_000).build(); + /** - * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating - * point vector as a response. Supports passing {@link MLAlgoParams} to the inference. If the model is - * asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} is - * mandatory. This method will check whether the model being used is asymmetric and correctly handle the - * parameter, so it's okay to always pass the parameter (even if the model is symmetric). + * Wrapper around {@link #inferenceSentencesMap} that expects a single input text and produces a + * single floating point vector as a response. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ - public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), mlAlgoParams, ActionListener.wrap(response -> { + public void inferenceSentence(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.inputTexts.size() != 1) { + listener.onFailure( + new IllegalArgumentException( + "Unexpected number of input texts. 
Expected 1 input text, but got [" + inferenceRequest.inputTexts.size() + "]" + ) + ); + return; + } + + inferenceSentences(inferenceRequest, ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( new IllegalStateException( @@ -95,111 +230,57 @@ public void inferenceSentence( return; } - listener.onResponse(response.get(0)); + listener.onResponse(response.getFirst()); }, listener::onFailure)); } /** * Abstraction to call predict function of api of MLClient with default targetResponse filters. It - * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * uses the custom model provided as modelId and runs the {@link FunctionName#TEXT_EMBEDDING}. The * return will be sent using the actionListener which will have a {@link List} of {@link List} of * {@link Float} in the order of inputText. We are not making this function generic enough to take * any function or TaskType as currently we need to run only TextEmbedding tasks only. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out */ public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with default targetResponse filters. It - * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The - * return will be sent using the actionListener which will have a {@link List} of {@link List} of - * {@link Float} in the order of inputText. We are not making this function generic enough to take - * any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports - * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. 
This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). - * - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or - */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener>> listener - ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, mlAlgoParams, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a {@link List} of {@link List} - * of {@link Float} in the order of inputText. We are not making this function generic enough to - * take any function or TaskType as currently we need to run only TextEmbedding tasks only. - * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is - * completed or errored out. 
- */ - public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener - ) { - inferenceSentences(targetResponseFilters, modelId, inputText, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a {@link List} of {@link List} - * of {@link Float} in the order of inputText. We are not making this function generic enough to - * take any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports - * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). - * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is - * completed or errored out. 
- */ - public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener>> listener - ) { - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, mlAlgoParams, 0, listener); + if (inferenceRequest.inputTexts.isEmpty()) { + listener.onFailure(new IllegalArgumentException("inputTexts must be provided")); + return; + } + retryableInferenceSentencesWithVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } public void inferenceSentencesWithMapResult( - @NonNull final String modelId, - @NonNull final List inputText, + @NonNull InferenceRequest inferenceRequest, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithMapResult(modelId, inputText, null, 0, listener); + retryableInferenceSentencesWithMapResult( + inferenceRequest.modelId, + inferenceRequest.inputTexts, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** @@ -207,45 +288,35 @@ public void inferenceSentencesWithMapResult( * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. * The return will be sent using the actionListener which will have a list of floats in the order * of inputText. + *

+ * If the model is asymmetric, the {@link InferenceRequest} must contain an + * {@link + * org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} as + * {@link MLAlgoParams}. This method will check whether the model being used is asymmetric and + * correctly handle the parameter, so it's okay to always pass the parameter (even if the model is + * symmetric). * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to - * happen - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out. + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain inputObjects. + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final Map inputObjects, + public void inferenceSentencesMap( + @NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener ) { - inferenceSentences(modelId, inputObjects, null, listener); - } - - /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. - * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. - * The return will be sent using the actionListener which will have a list of floats in the order - * of inputText. Supports passing {@link MLAlgoParams} to the inference. If the model is asymmetric, - * passing a - * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} - * is mandatory. This method will check whether the model being used is asymmetric and correctly - * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). 
- * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to - * happen - * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference - * @param listener {@link ActionListener} which will be called when prediction is completed or - * errored out. - */ - public void inferenceSentences( - @NonNull final String modelId, - @NonNull final Map inputObjects, - @Nullable final MLAlgoParams mlAlgoParams, - @NonNull final ActionListener> listener - ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, mlAlgoParams, 0, listener); + if (inferenceRequest.inputObjects == null) { + listener.onFailure(new IllegalArgumentException("inputObjects must be provided")); + return; + } + retryableInferenceSentencesWithSingleVectorResult( + inferenceRequest.targetResponseFilters, + inferenceRequest.modelId, + inferenceRequest.inputObjects, + inferenceRequest.mlAlgoParams, + 0, + listener + ); } /** @@ -254,18 +325,22 @@ public void inferenceSentences( * actionListener as a list of floats representing the similarity scores of the texts w.r.t. the * query text, in the order of the input texts. * - * @param modelId {@link String} ML-Commons Model Id - * @param queryText {@link String} The query to compare all the inputText to - * @param inputText {@link List} of {@link String} The texts to compare to the query - * @param listener {@link ActionListener} receives the result of the inference + * @param inferenceRequest {@link InferenceRequest} containing the modelId and other parameters. + * Must contain queryText. 
+ * @param listener {@link ActionListener} receives the result of the inference */ - public void inferenceSimilarity( - @NonNull final String modelId, - @NonNull final String queryText, - @NonNull final List inputText, - @NonNull final ActionListener> listener - ) { - retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + public void inferenceSimilarity(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener> listener) { + if (inferenceRequest.queryText == null) { + listener.onFailure(new IllegalArgumentException("queryText must be provided")); + return; + } + retryableInferenceSimilarityWithVectorResult( + inferenceRequest.modelId, + inferenceRequest.queryText, + inferenceRequest.inputTexts, + 0, + listener + ); } private void retryableInferenceSentencesWithMapResult( @@ -329,30 +404,29 @@ private void retryableInferenceSentencesWithVectorResult( } /** - * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept - * that is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then - * this check is not applicable. - * + * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept that + * is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then this + * check is not applicable. + *

* The asymmetry of a model is static for a given model. To avoid repeated checks for the same * model, we cache the model asymmetry status. Non-TextEmbeddingModels are cached as false. * - * @param modelId The model id to check - * @param onFailure The action to take if the model cannot be retrieved + * @param modelId The model id to check + * @param onFailure The action to take if the model cannot be retrieved * @param runPrediction The action to take if the model is successfully retrieved */ private void checkModelAsymmetryAndThenPredict(String modelId, Consumer onFailure, Consumer runPrediction) { CheckedConsumer checkModelAsymmetryListener = model -> { MLModelConfig modelConfig = model.getModelConfig(); - if (!(modelConfig instanceof TextEmbeddingModelConfig)) { - modelAsymmetryCache.putIfAbsent(modelId, false); + if (!(modelConfig instanceof TextEmbeddingModelConfig textEmbeddingModelConfig)) { + modelAsymmetryCache.computeIfAbsent(modelId, k -> false); return; } - final TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig; final boolean isAsymmetricModel = textEmbeddingModelConfig.getPassagePrefix() != null || textEmbeddingModelConfig.getQueryPrefix() != null; - modelAsymmetryCache.putIfAbsent(modelId, isAsymmetricModel); + modelAsymmetryCache.computeIfAbsent(modelId, k -> isAsymmetricModel); }; - if (modelAsymmetryCache.containsKey(modelId)) { + if (modelAsymmetryCache.get(modelId) != null) { runPrediction.accept(modelAsymmetryCache.get(modelId)); } else { mlClient.getModel(modelId, ActionListener.wrap(mlModel -> { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..682a54126 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,6 +14,7 @@ import 
org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -48,17 +49,19 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), + ActionListener.wrap(resultMaps -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult( - this.modelId, - inferenceList, + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 23dc6af49..16862c27c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -18,10 +18,12 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import 
org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, - * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. + * This processor is used for user input data text embedding processing, model_id can be used to + * indicate which model user use, and field_map can be used to indicate which fields needs text + * embedding and the corresponding keys for the text embedding results. */ @Log4j2 public final class TextEmbeddingProcessor extends InferenceProcessor { @@ -29,6 +31,10 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .build(); + public TextEmbeddingProcessor( String tag, String description, @@ -50,9 +56,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences( - this.modelId, - inferenceList, - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); @@ -62,6 +66,9 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException)); + mlCommonsClientAccessor.inferenceSentences( + new 
InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(), + ActionListener.wrap(handler::accept, onException) + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index e808869f9..f71110cdc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** @@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer { - setVectorFieldsToDocument(ingestDocument, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentencesMap( + new InferenceRequest.Builder(this.modelId).inputObjects(inferenceMap).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } } catch (Exception e) { handler.accept(null, e); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index d8d9e8ec3..29508bd40 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -12,6 +12,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import 
org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; @@ -73,9 +74,9 @@ public void rescoreSearchResponse( List ctxList = (List) ctxObj; List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); mlCommonsClientAccessor.inferenceSimilarity( - modelId, - (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), - contexts, + new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD)) + .inputTexts(contexts) + .build(), listener ); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index c8b2a1d4a..2fccfafb0 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -57,11 +57,12 @@ import lombok.Setter; import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; /** - * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a - * k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as - * the query vector for the k-NN search. + * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a + * wrapper around a k-NN vector query. It uses a ML language model to produce a dense vector from a + * query string that is then used as the query vector for the k-NN search. 
*/ @Log4j2 @@ -86,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder static final ParseField K_FIELD = new ParseField("k"); private static final int DEFAULT_K = 10; + private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .build(); private static MLCommonsClientAccessor ML_CLIENT; @@ -335,10 +339,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { inferenceInput.put(INPUT_IMAGE, queryImage()); } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentences( - modelId(), - inferenceInput, - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(), + ((client, actionListener) -> ML_CLIENT.inferenceSentencesMap( + new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(), ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); @@ -368,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) { @Override protected boolean doEquals(NeuralQueryBuilder obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queryText, obj.queryText); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index f46997d5e..ef9657759 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -37,6 +37,7 @@ 
import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; @@ -341,8 +342,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce ML_CLIENT.inferenceSentencesWithMapResult( - modelId(), - List.of(queryText), + new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(), ActionListener.wrap(mapResultList -> { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3c0376909..d45c6f15c 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -34,6 +34,7 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.neuralsearch.constants.TestCommonConstants; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.NodeNotConnectedException; @@ -66,9 +67,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)).build(), singleSentenceResultListener ); @@ -101,7 +100,10 @@ public void testInferenceSentences_whenValidInputThenSuccess() { }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentences( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -119,7 +121,10 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentences( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -137,9 +142,9 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + 
.targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .build(), resultListener ); @@ -163,9 +168,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST) + .targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .build(), resultListener ); @@ -185,9 +190,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSentences( - TestCommonConstants.TARGET_RESPONSE_FILTERS, - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).targetResponseFilters(TestCommonConstants.TARGET_RESPONSE_FILTERS) + .inputTexts(TestCommonConstants.SENTENCES_LIST) + .build(), resultListener ); @@ -206,9 +211,9 @@ public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(true); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), singleSentenceResultListener ); @@ -233,9 +238,9 @@ public void testInferenceSentences_whenGetModelException_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(exception); accessor.inferenceSentence( - TestCommonConstants.MODEL_ID, - 
TestCommonConstants.SENTENCES_LIST.get(0), - AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputText(TestCommonConstants.SENTENCES_LIST.get(0)) + .mlAlgoParams(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()) + .build(), singleSentenceResultListener ); @@ -263,7 +268,10 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -282,7 +290,10 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -309,7 +320,10 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new 
InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -339,7 +353,10 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -361,7 +378,10 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -379,7 +399,10 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + accessor.inferenceSentencesWithMapResult( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputTexts(TestCommonConstants.SENTENCES_LIST).build(), + resultListener + ); Mockito.verify(client, times(1)) 
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -396,7 +419,10 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -414,7 +440,10 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -422,7 +451,7 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } - public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { + public void testInferenceSentencesMapMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" @@ -435,7 +464,10 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR setupMocksForTextEmbeddingModelAsymmetryCheck(false); - 
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + accessor.inferenceSentencesMap( + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).inputObjects(TestCommonConstants.SENTENCES_MAP).build(), + singleSentenceResultListener + ); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -453,9 +485,9 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -476,9 +508,9 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); @@ -502,9 +534,9 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi setupMocksForTextEmbeddingModelAsymmetryCheck(false); accessor.inferenceSimilarity( - TestCommonConstants.MODEL_ID, - "is it sunny", - List.of("it is sunny today", "roses are red"), + new InferenceRequest.Builder(TestCommonConstants.MODEL_ID).queryText("is it sunny") + .inputTexts(List.of("it is sunny today", "roses are red")) + .build(), singleSentenceResultListener ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index dc86975bd..52ee1009e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -23,10 +23,10 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -63,7 +63,7 @@ public void test_batchExecute_emptyInput() { ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); verify(resultHandler).accept(captor.capture()); assertTrue(captor.getValue().isEmpty()); - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithEmpty_allFailedValidation() { @@ -85,7 +85,7 @@ public void test_batchExecuteWithEmpty_allFailedValidation() { ); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecuteWithNull_allFailedValidation() { @@ -104,7 +104,7 @@ public void test_batchExecuteWithNull_allFailedValidation() { assertEquals("list type field [key1] has null, cannot process it", captor.getValue().get(i).getException().getMessage()); 
assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor, never()).inferenceSentencesMap(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_partialFailedValidation() { @@ -123,9 +123,9 @@ public void test_batchExecute_partialFailedValidation() { for (int i = 0; i < docCount; ++i) { assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(2, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(2, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_happyCase() { @@ -144,9 +144,9 @@ public void test_batchExecute_happyCase() { assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); } public void test_batchExecute_sort() { @@ -165,10 +165,10 @@ public void test_batchExecute_sort() { 
assertNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); - verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); - assertEquals(4, inferenceTextCaptor.getValue().size()); - assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceTextCaptor.getValue()); + ArgumentCaptor inferenceRequestArgumentCaptor = ArgumentCaptor.forClass(InferenceRequest.class); + verify(clientAccessor).inferenceSentences(inferenceRequestArgumentCaptor.capture(), any()); + assertEquals(4, inferenceRequestArgumentCaptor.getValue().getInputTexts().size()); + assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceRequestArgumentCaptor.getValue().getInputTexts()); List doc1Embeddings = (List) (captor.getValue().get(0).getIngestDocument().getFieldValue("embedding_key1", List.class)); List doc2Embeddings = (List) (captor.getValue().get(1).getIngestDocument().getFieldValue("embedding_key1", List.class)); @@ -197,7 +197,7 @@ public void test_doBatchExecute_exception() { assertNotNull(captor.getValue().get(i).getException()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } - verify(clientAccessor).inferenceSentences(anyString(), anyList(), any()); + verify(clientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), any()); } public void test_batchExecute_subBatches() { @@ -245,7 +245,10 @@ public void doExecute( @Override void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { // use to verify if doBatchExecute is called from InferenceProcessor - clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {})); + clientAccessor.inferenceSentences( + new InferenceRequest.Builder(MODEL_ID).inputTexts(inferenceList).build(), + 
ActionListener.wrap(results -> {}, ex -> {}) + ); allInferenceInputs.add(inferenceList); if (this.exception != null) { onException.accept(this.exception); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..ac2a1f0d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; @@ -100,10 +100,11 @@ public void testExecute_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -132,7 +133,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputTexts() 
!= null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -150,10 +152,11 @@ public void testExecute_withListTypeInput_successful() { List> dataAsMapList = createMockMapResult(6); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -169,10 +172,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { SparseEncodingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -206,10 +210,11 @@ public void testExecute_withMapTypeInput_successful() { List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), 
isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -223,10 +228,11 @@ public void test_batchExecute_successful() { SparseEncodingProcessor processor = createInstance(docCount); List> dataAsMapList = createMockMapResult(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(dataAsMapList); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -244,10 +250,11 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); SparseEncodingProcessor processor = createInstance(docCount); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesWithMapResult(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 1d83c8c95..ea009e777 
100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -44,7 +44,6 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; @@ -152,10 +151,14 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -186,7 +189,10 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep config ); doThrow(new RuntimeException()).when(accessor) - .inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + 
isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -214,7 +220,8 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(any(IngestDocument.class), isNull()); @@ -232,10 +239,14 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -308,10 +319,14 @@ public void testExecute_withMapTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> 
request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -349,10 +364,14 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -409,10 +428,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -467,10 +490,14 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = 
invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -518,10 +545,14 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + isA(ActionListener.class) + ); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -587,10 +618,14 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(3); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentences( + argThat(request -> request.getMlAlgoParams() != null && request.getInputTexts() != null), + 
isA(ActionListener.class) + ); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -630,7 +665,7 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE ActionListener>> listener = invocation.getArgument(2); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); @@ -830,10 +865,10 @@ public void test_batchExecute_successful() { List> modelTensorList = createMockVectorWithLength(10); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); @@ -851,10 +886,10 @@ public void test_batchExecute_exception() { List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(docCount); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != 
null), isA(ActionListener.class)); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(ingestDocumentWrappers, resultHandler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 8f0018f52..a22a7ef76 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.processor; -import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -192,10 +192,11 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -229,7 +230,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); 
processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -245,10 +247,11 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -277,10 +280,11 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { TextImageEmbeddingProcessor processor = createInstance(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -336,10 +340,11 @@ public void testExecute_whenInferencesAreEmpty_thenSuccessful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor) + .inferenceSentencesMap(argThat(request -> 
request.getInputObjects() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index dbd1c2bd6..3f745b9c3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -5,8 +5,7 @@ package org.opensearch.neuralsearch.processor.rerank; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -114,11 +113,12 @@ private void setupParams(Map params) { private void setupSimilarityRescoring() { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = List.of(1f, 2f, 3f); listener.onResponse(scores); return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); } private void setupSearchResults() throws IOException { @@ -345,11 +345,12 @@ public void testRerank_HappyPath() throws IOException { public void testRerank_whenScoresAndHitsHaveDiffLengths_thenFail() throws IOException { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); List scores = List.of(1f, 2f); listener.onResponse(scores); 
return null; - }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + }).when(mlCommonsClientAccessor) + .inferenceSimilarity(argThat(request -> request.getQueryText() != null && request.getInputTexts() != null), any()); setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 5efbf3869..cc3dc5fbb 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -649,10 +649,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -685,10 +685,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe List expectedVector = 
Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesMap(argThat(request -> request.getInputObjects() != null), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 7509efd42..eb350a9ea 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -623,10 +623,10 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() Map expectedMap = Map.of("1", 1f, "2", 2f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(1); listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); return null; - }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any()); NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1);