Skip to content

Commit

Permalink
ut
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Nov 15, 2024
1 parent 30babbb commit b8d8b7f
Showing 1 changed file with 104 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.when;

import static org.mockito.Mockito.verify;

import java.util.Arrays;
import java.util.Map;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -49,6 +51,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.util.pruning.PruneType;

public class SparseEncodingProcessorTests extends InferenceProcessorTestCase {
@Mock
Expand Down Expand Up @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) {
return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
}

/**
 * Builds a {@link SparseEncodingProcessor} via the processor factory with pruning settings applied.
 * The configuration uses the model id "mockModelId" and maps "key1" to "key1Mapped" and
 * "key2" to "key2Mapped".
 *
 * @param pruneType  prune type whose {@code getValue()} result is stored under "prune_type"
 * @param pruneRatio ratio stored under "prune_ratio"
 * @return the processor returned by the factory, cast to {@link SparseEncodingProcessor}
 */
@SneakyThrows
private SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) {
    Map<String, Object> processorConfig = new HashMap<>();
    processorConfig.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId");
    processorConfig.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
    // Pruning options are passed using their raw config keys.
    processorConfig.put("prune_type", pruneType.getValue());
    processorConfig.put("prune_ratio", pruneRatio);

    // No other processor factories are registered for this test instance.
    Map<String, Processor.Factory> emptyRegistry = new HashMap<>();
    return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(emptyRegistry, PROCESSOR_TAG, DESCRIPTION, processorConfig);
}

public void testExecute_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
Expand Down Expand Up @@ -260,9 +274,98 @@ public void test_batchExecute_exception() {
}
}

/**
 * Executes the processor on a single document with {@code PruneType.MAX_RATIO} and ratio 0.5,
 * using a mocked inference response, and verifies the values written to "key1Mapped" and
 * "key2Mapped": "world" is expected to be removed while "hello", "test" and "low" are kept.
 */
@SuppressWarnings("unchecked")
public void testExecute_withPruningConfig_successful() {
    Map<String, Object> source = new HashMap<>();
    source.put(IndexFieldMapper.NAME, "my_index");
    source.put("key1", "value1");
    source.put("key2", "value2");
    IngestDocument inputDocument = new IngestDocument(source, new HashMap<>());

    SparseEncodingProcessor pruningProcessor = createInstance(PruneType.MAX_RATIO, 0.5f);

    // Mocked model output; the assertions below expect the first entry under "key1Mapped"
    // and the second under "key2Mapped".
    Map<String, Float> key1Encoding = ImmutableMap.of("hello", 1.0f, "world", 0.1f);
    Map<String, Float> key2Encoding = ImmutableMap.of("test", 0.8f, "low", 0.4f);
    List<Map<String, ?>> inferenceResponse = Collections.singletonList(
        Map.of("response", Arrays.asList(key1Encoding, key2Encoding))
    );

    // Answer the inference call immediately by invoking the listener (third argument).
    doAnswer(call -> {
        ActionListener<List<Map<String, ?>>> responseListener = call.getArgument(2);
        responseListener.onResponse(inferenceResponse);
        return null;
    }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class));

    BiConsumer handler = mock(BiConsumer.class);
    pruningProcessor.execute(inputDocument, handler);

    // The handler must be called once with a document and a null exception.
    ArgumentCaptor<IngestDocument> documentCaptor = ArgumentCaptor.forClass(IngestDocument.class);
    verify(handler).accept(documentCaptor.capture(), isNull());

    IngestDocument outputDocument = documentCaptor.getValue();
    Map<String, Float> key1Tokens = (Map<String, Float>) outputDocument.getFieldValue("key1Mapped", Map.class);
    Map<String, Float> key2Tokens = (Map<String, Float>) outputDocument.getFieldValue("key2Mapped", Map.class);

    assertNotNull(key1Tokens);
    assertNotNull(key2Tokens);

    // First field: "hello" survives with its weight, "world" is pruned.
    assertTrue(key1Tokens.containsKey("hello"));
    assertFalse(key1Tokens.containsKey("world"));
    assertEquals(1.0f, key1Tokens.get("hello"), 0.001f);

    // Second field: both tokens survive with their original weights.
    assertTrue(key2Tokens.containsKey("test"));
    assertTrue(key2Tokens.containsKey("low"));
    assertEquals(0.8f, key2Tokens.get("test"), 0.001f);
    assertEquals(0.4f, key2Tokens.get("low"), 0.001f);
}

/**
 * Calls {@code doBatchExecute} on a processor created with {@code PruneType.MAX_RATIO} and
 * ratio 0.5, using a mocked inference response of two three-token maps. Verifies that the
 * result handler receives two maps of two tokens each ("token2" and "token5" removed) and
 * that the exception handler is never invoked.
 */
public void test_batchExecute_withPruning_successful() {
    SparseEncodingProcessor pruningProcessor = createInstance(PruneType.MAX_RATIO, 0.5f);

    // Mocked model output: two sparse vectors with three tokens each.
    Map<String, Float> firstEncoding = ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f);
    Map<String, Float> secondEncoding = ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f);
    List<Map<String, ?>> inferenceResponse = Collections.singletonList(
        Map.of("response", Arrays.asList(firstEncoding, secondEncoding))
    );

    // Answer the inference call immediately by invoking the listener (third argument).
    doAnswer(call -> {
        ActionListener<List<Map<String, ?>>> responseListener = call.getArgument(2);
        responseListener.onResponse(inferenceResponse);
        return null;
    }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class));

    Consumer<List<?>> onResult = mock(Consumer.class);
    Consumer<Exception> onFailure = mock(Consumer.class);

    List<String> sentences = Arrays.asList("test1", "test2");
    pruningProcessor.doBatchExecute(sentences, onResult, onFailure);

    // Exactly one result callback and no failure callback are expected.
    ArgumentCaptor<List<Map<String, Float>>> resultsCaptor = ArgumentCaptor.forClass(List.class);
    verify(onResult).accept(resultsCaptor.capture());
    verify(onFailure, never()).accept(any());

    List<Map<String, Float>> prunedResults = resultsCaptor.getValue();

    assertEquals(2, prunedResults.size());

    // First vector keeps token1 and token3; token2 is pruned.
    Map<String, Float> prunedFirst = prunedResults.get(0);
    assertEquals(2, prunedFirst.size());
    assertTrue(prunedFirst.containsKey("token1"));
    assertTrue(prunedFirst.containsKey("token3"));
    assertFalse(prunedFirst.containsKey("token2"));

    // Second vector keeps token4 and token6; token5 is pruned.
    Map<String, Float> prunedSecond = prunedResults.get(1);
    assertEquals(2, prunedSecond.size());
    assertTrue(prunedSecond.containsKey("token4"));
    assertTrue(prunedSecond.containsKey("token6"));
    assertFalse(prunedSecond.containsKey("token5"));
}

private List<Map<String, ?>> createMockMapResult(int number) {
List<Map<String, Float>> mockSparseEncodingResult = new ArrayList<>();
IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f)));
IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f)));

List<Map<String, ?>> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult));
return mockMapResult;
Expand Down

0 comments on commit b8d8b7f

Please sign in to comment.