Skip to content

Commit

Permalink
Support different embedding types of model response (#1007)
Browse files Browse the repository at this point in the history
* Support different embedding types of model response
---------
Signed-off-by: zane-neo <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
zane-neo authored Feb 18, 2025
1 parent c36ca15 commit 628cb64
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
### Bug Fixes
### Infrastructure
- [3.0] Update neural-search for OpenSearch 3.0 compatibility ([#1141](https://github.com/opensearch-project/neural-search/pull/1141))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
loadModel(sparseModelId);
MLModelState oldModelState = getModelState(sparseModelId);
logger.info("Model state in OLD phase: {}", oldModelState);
if (oldModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
if (oldModelState != MLModelState.LOADED && oldModelState != MLModelState.DEPLOYED) {
logger.error(
"Model {} is not in LOADED or DEPLOYED state in OLD phase. Current state: {}",
sparseModelId,
oldModelState
);
waitForModelToLoad(sparseModelId);
}
createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
Expand All @@ -52,8 +56,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
loadModel(sparseModelId);
MLModelState mixedModelState = getModelState(sparseModelId);
logger.info("Model state in MIXED phase: {}", mixedModelState);
if (mixedModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
if (mixedModelState != MLModelState.LOADED && mixedModelState != MLModelState.DEPLOYED) {
logger.error(
"Model {} is not in LOADED or DEPLOYED state in MIXED phase. Current state: {}",
sparseModelId,
mixedModelState
);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public class VectorUtil {
* @param vectorAsList {@link List} of {@link Number} values representing the vector
* @return array of floats produced from input list (empty array for an empty list)
*/
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
    float[] vector = new float[vectorAsList.size()];
    for (int i = 0; i < vectorAsList.size(); i++) {
        // floatValue() narrows any boxed numeric type (Float, Double, Integer, ...) to float,
        // which is what lets this helper accept embeddings of differing numeric element types.
        vector[i] = vectorAsList.get(i).floatValue();
    }
    return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class MLCommonsClientAccessor {
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {

inferenceSentences(
Expand Down Expand Up @@ -87,7 +87,7 @@ public void inferenceSentence(
*/
public void inferenceSentences(
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
}
Expand All @@ -107,7 +107,7 @@ public void inferenceSentencesWithMapResult(
* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
}

Expand Down Expand Up @@ -148,11 +148,11 @@ private void retryableInferenceSentencesWithMapResult(
private void retryableInferenceSentencesWithVectorResult(
final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<List<Float>>> listener
final ActionListener<List<List<Number>>> listener
) {
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
Expand All @@ -171,7 +171,9 @@ private void retryableInferenceSimilarityWithVectorResult(
) {
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
.map(v -> v.getFirst().floatValue())
.collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
Expand All @@ -194,14 +196,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<T>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
}
}
return vector;
Expand All @@ -225,19 +227,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return resultMaps;
}

private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

private void retryableInferenceSentencesWithSingleVectorResult(
final MapInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
final ActionListener<List<Number>> listener
) {
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest

}

private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
Expand Down Expand Up @@ -167,7 +167,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
Map<String, Object> result = new LinkedHashMap<>();
result.put(knnKey, modelTensorList);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
public class VectorUtilTests extends OpenSearchTestCase {

public void testVectorAsListToArray() {
    // Happy path: each Number element is narrowed to its float value in order.
    List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
    float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

    assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
    for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
        // Compare via floatValue() since the list is typed List<Number>.
        assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
    }

    // Edge case: an empty input list must yield an empty array, not null.
    List<Number> vectorAsList_withNoElements = Collections.emptyList();
    float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
    assertEquals(0, vectorAsArray_withNoElements.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
private ActionListener<List<List<Float>>> resultListener;
private ActionListener<List<List<Number>>> resultListener;

@Mock
private ActionListener<List<Float>> singleSentenceResultListener;
private ActionListener<List<Number>> singleSentenceResultListener;

@Mock
private ActionListener<List<Float>> similarityResultListener;

@Mock
private MachineLearningNodeClient client;
Expand All @@ -53,7 +56,7 @@ public void setup() {
}

public void testInferenceSentence_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand All @@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand All @@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
}

public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand Down Expand Up @@ -127,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

// Verify failure is propagated to the listener after all retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);

// Ensure no additional interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
Expand Down Expand Up @@ -288,7 +291,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
}

public void testInferenceMultimodal_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand Down Expand Up @@ -353,12 +356,12 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Expand All @@ -369,12 +372,12 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
.modelId(MODEL_ID)
.k(K)
.build();
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(1);
ActionListener<List<Number>> listener = invocation.getArgument(1);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor)
Expand Down Expand Up @@ -810,10 +810,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
.modelId(MODEL_ID)
.k(K)
.build();
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(1);
ActionListener<List<Number>> listener = invocation.getArgument(1);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ protected float[] runInference(final String modelId, final String queryText) {
List<Object> output = (List<Object>) result.get("output");
assertEquals(1, output.size());
Map<String, Object> map = (Map<String, Object>) output.get(0);
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
return vectorAsListToArray(data);
}

Expand Down

0 comments on commit 628cb64

Please sign in to comment.