From b8d8b7f3ce3627646ef397f4aed6597151e9bcaf Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 17:44:35 +0800 Subject: [PATCH] ut Signed-off-by: zhichao-aws --- .../SparseEncodingProcessorTests.java | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..d705616a9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -14,10 +14,12 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.Map; import java.util.ArrayList; import java.util.Collections; @@ -49,6 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.util.pruning.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) { return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put("prune_type", pruneType.getValue()); + config.put("prune_ratio", pruneRatio); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); @@ -260,9 +274,98 @@ public void test_batchExecute_exception() { } } + @SuppressWarnings("unchecked") + public void testExecute_withPruningConfig_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> dataAsMapList = Collections.singletonList( + Map.of("response", Arrays.asList(ImmutableMap.of("hello", 1.0f, "world", 0.1f), ImmutableMap.of("test", 0.8f, "low", 0.4f))) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + + ArgumentCaptor docCaptor = ArgumentCaptor.forClass(IngestDocument.class); + verify(handler).accept(docCaptor.capture(), isNull()); + + IngestDocument processedDoc = docCaptor.getValue(); + Map first = (Map) processedDoc.getFieldValue("key1Mapped", Map.class); + Map second = (Map) processedDoc.getFieldValue("key2Mapped", Map.class); + + assertNotNull(first); + assertNotNull(second); + + assertTrue(first.containsKey("hello")); + assertFalse(first.containsKey("world")); + assertEquals(1.0f, first.get("hello"), 0.001f); + + assertTrue(second.containsKey("test")); + assertTrue(second.containsKey("low")); + assertEquals(0.8f, second.get("test"), 0.001f); + assertEquals(0.4f, second.get("low"), 0.001f); + } + + public void test_batchExecute_withPruning_successful() { + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> mockMLResponse = Collections.singletonList( + Map.of( + "response", + Arrays.asList( + ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f), + ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f) + ) + ) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(mockMLResponse); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer> resultHandler = mock(Consumer.class); + Consumer exceptionHandler = mock(Consumer.class); + + List inferenceList = Arrays.asList("test1", "test2"); + processor.doBatchExecute(inferenceList, resultHandler, exceptionHandler); + + ArgumentCaptor>> resultCaptor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCaptor.capture()); + verify(exceptionHandler, never()).accept(any()); + + List> processedResults = resultCaptor.getValue(); + + assertEquals(2, processedResults.size()); + + Map firstResult = processedResults.get(0); + assertEquals(2, firstResult.size()); + assertTrue(firstResult.containsKey("token1")); + assertTrue(firstResult.containsKey("token3")); + assertFalse(firstResult.containsKey("token2")); + + Map secondResult = processedResults.get(1); + assertEquals(2, secondResult.size()); + assertTrue(secondResult.containsKey("token4")); + assertTrue(secondResult.containsKey("token6")); + assertFalse(secondResult.containsKey("token5")); + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult;