From b15052c5bd3b8398b2be95ba303d775edb50ddd5 Mon Sep 17 00:00:00 2001 From: wdongyu <23725216+wdongyu@users.noreply.github.com> Date: Fri, 11 Oct 2024 23:25:09 +0800 Subject: [PATCH] Fix nested field missing sub embedding field (#913) * Adding non empty check before filling in result Signed-off-by: wangdongyu.danny --- CHANGELOG.md | 1 + .../processor/InferenceProcessor.java | 9 ++- .../processor/InferenceProcessorTests.java | 25 +++++++- .../processor/TextEmbeddingProcessorIT.java | 9 +++ .../TextEmbeddingProcessorTests.java | 57 ++++++++++++++++++- .../resources/processor/IndexMappings.json | 3 + src/test/resources/processor/ingest_doc1.json | 3 + src/test/resources/processor/ingest_doc2.json | 3 + .../neuralsearch/BaseNeuralSearchIT.java | 1 + 9 files changed, 106 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 675ea5983..329cf0d1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-2 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) ### Bug Fixes +- Fix for nested field missing sub embedding field in text embedding processor ([#913](https://github.com/opensearch-project/neural-search/pull/913)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 30780a3f5..ae996251d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -285,7 +286,7 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List if (sourceValue instanceof Map) { ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); + ((List) sourceValue).stream().filter(Objects::nonNull).forEach(texts::add); } else { if (sourceValue == null) return; texts.add(sourceValue.toString()); @@ -419,8 +420,12 @@ private void putNLPResultToSourceMapForMapType( for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { if (sourceAndMetadataMap.get(processorKey) instanceof List) { // build nlp output for list of nested objects + Iterator inputNestedMapValueIt = ((List) inputNestedMapEntry.getValue()).iterator(); for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { - nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + // Only fill in when value is not null + if (inputNestedMapValueIt.hasNext() && inputNestedMapValueIt.next() != null) { + nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + } } } else { Pair processedNestedKey = processNestedKey(inputNestedMapEntry); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index cd2d0816a..dc86975bd 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -66,7 +66,7 @@ public void test_batchExecute_emptyInput() { verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); } - public void test_batchExecute_allFailedValidation() { + public void test_batchExecuteWithEmpty_allFailedValidation() { final int docCount = 2; TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); List wrapperList = createIngestDocumentWrappers(docCount); @@ -79,6 +79,29 @@ public void test_batchExecute_allFailedValidation() { assertEquals(docCount, captor.getValue().size()); for (int i = 0; i < docCount; ++i) { assertNotNull(captor.getValue().get(i).getException()); + assertEquals( + "list type field [key1] has empty string, cannot process it", + captor.getValue().get(i).getException().getMessage() + ); + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + } + + public void test_batchExecuteWithNull_allFailedValidation() { + final int docCount = 2; + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList(null, "value1")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList(null, "value1")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertNotNull(captor.getValue().get(i).getException()); + assertEquals("list type field [key1] has null, cannot process it", captor.getValue().get(i).getException().getMessage()); assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); } verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 8fd87f091..4afa4031d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -310,5 +310,14 @@ private void ingestBatchDocumentWithBulk(String idPrefix, int docCount, Set> itemMap = (Map>) item; + if (itemMap.get("index").get("error") != null) { + failedDocCount++; + } + } + assertEquals(failedIds.size(), failedDocCount); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 82b24324c..97e85e46e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -730,7 +730,7 @@ public void testBuildVectorOutput_withNestedList_successful() { IngestDocument ingestDocument = createNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createMockVectorResult(); + List> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); assertTrue(nestedObj.get(0).containsKey("vectorField")); @@ -739,12 +739,27 @@ public void testBuildVectorOutput_withNestedList_successful() { assertNotNull(nestedObj.get(1).get("vectorField")); } + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_successful() { + Map config = createNestedListConfiguration(); + IngestDocument ingestDocument = createNestedListWithNotEmbeddingFieldIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(1, 2, 0.0f, 1.0f); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); + assertFalse(nestedObj.get(0).containsKey("vectorField")); + assertTrue(nestedObj.get(0).containsKey("textFieldNotForEmbedding")); + assertTrue(nestedObj.get(1).containsKey("vectorField")); + assertNotNull(nestedObj.get(1).get("vectorField")); + } + public void testBuildVectorOutput_withNestedList_Level2_successful() { Map config = createNestedList2LevelConfiguration(); IngestDocument ingestDocument = create2LevelNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createMockVectorResult(); + List> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); List> nestedObj = (List>) nestedLevel1.get("nestedField"); @@ -754,6 +769,22 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() { assertNotNull(nestedObj.get(1).get("vectorField")); } + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_successful() { + Map config = createNestedList2LevelConfiguration(); + IngestDocument ingestDocument = create2LevelNestedListWithNotEmbeddingFieldIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(1, 2, 0.0f, 1.0f); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); + List> nestedObj = (List>) nestedLevel1.get("nestedField"); + assertFalse(nestedObj.get(0).containsKey("vectorField")); + assertTrue(nestedObj.get(0).containsKey("textFieldNotForEmbedding")); + assertTrue(nestedObj.get(1).containsKey("vectorField")); + assertNotNull(nestedObj.get(1).get("vectorField")); + } + public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); @@ -1039,6 +1070,16 @@ private IngestDocument createNestedListIngestDocument() { return new IngestDocument(nestedList, new HashMap<>()); } + private IngestDocument createNestedListWithNotEmbeddingFieldIngestDocument() { + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textFieldNotForEmbedding", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + return new IngestDocument(nestedList, new HashMap<>()); + } + private IngestDocument create2LevelNestedListIngestDocument() { HashMap nestedObj1 = new HashMap<>(); nestedObj1.put("textField", "This is a text field"); @@ -1050,4 +1091,16 @@ private IngestDocument create2LevelNestedListIngestDocument() { nestedList1.put("nestedField", nestedList); return new IngestDocument(nestedList1, new HashMap<>()); } + + private IngestDocument create2LevelNestedListWithNotEmbeddingFieldIngestDocument() { + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textFieldNotForEmbedding", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + HashMap nestedList1 = new HashMap<>(); + nestedList1.put("nestedField", nestedList); + return new IngestDocument(nestedList1, new HashMap<>()); + } } diff --git a/src/test/resources/processor/IndexMappings.json b/src/test/resources/processor/IndexMappings.json index 7afbaa92e..afe613117 100644 --- a/src/test/resources/processor/IndexMappings.json +++ b/src/test/resources/processor/IndexMappings.json @@ -90,6 +90,9 @@ "text": { "type": "text" }, + "text_not_for_embedding": { + "type": "text" + }, "embedding": { "type": "knn_vector", "dimension": 768, diff --git a/src/test/resources/processor/ingest_doc1.json b/src/test/resources/processor/ingest_doc1.json index b1cc5392b..d952d07d8 100644 --- a/src/test/resources/processor/ingest_doc1.json +++ b/src/test/resources/processor/ingest_doc1.json @@ -12,6 +12,9 @@ "movie": null }, "nested_passages": [ + { + "text_not_for_embedding": "test" + }, { "text": "hello" }, diff --git a/src/test/resources/processor/ingest_doc2.json b/src/test/resources/processor/ingest_doc2.json index cce93d4a1..5ab1f7525 100644 --- a/src/test/resources/processor/ingest_doc2.json +++ b/src/test/resources/processor/ingest_doc2.json @@ -10,6 +10,9 @@ "movie": null }, "nested_passages": [ + { + "text_not_for_embedding": "test" + }, { "text": "apple" }, diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 7ad0e63f8..e6fb45d2a 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -203,6 +203,7 @@ protected void loadModel(final String modelId) throws Exception { isComplete = checkComplete(taskQueryResult); Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); } + assertTrue(isComplete); } /**