diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/SqsSourceTaskTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/SqsSourceTaskTest.java index 1e57e2aba8..c4dac097a1 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/SqsSourceTaskTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/SqsSourceTaskTest.java @@ -4,22 +4,20 @@ import io.micrometer.core.instrument.Counter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; -import org.opensearch.dataprepper.buffer.common.BufferAccumulator; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; -import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.aws.sqs.common.SqsService; import org.opensearch.dataprepper.plugins.aws.sqs.common.handler.SqsMessageHandler; import org.opensearch.dataprepper.plugins.aws.sqs.common.metrics.SqsMetrics; import org.opensearch.dataprepper.plugins.aws.sqs.common.model.SqsOptions; -import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; -import org.opensearch.dataprepper.plugins.source.sqssource.handler.RawSqsMessageHandler; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; @@ -30,31 +28,24 @@ import software.amazon.awssdk.services.sts.model.StsException; import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.UUID; import java.util.function.Consumer; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +@ExtendWith(MockitoExtension.class) class SqsSourceTaskTest { - private static final String TEST_PIPELINE_NAME = "pipeline"; - - private static final String MESSAGE = "message"; - - @Mock private SqsService sqsService; - @Mock private SqsOptions sqsOptions; @Mock @@ -63,11 +54,7 @@ class SqsSourceTaskTest { @Mock private AcknowledgementSetManager acknowledgementSetManager; - private final boolean endToEndAcknowledgementsEnabled = false; - - static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(10); - - static final int NO_OF_RECORDS_TO_ACCUMULATE = 100; + private boolean endToEndAcknowledgementsEnabled = false; @Mock private SqsMessageHandler sqsHandler; @@ -91,26 +78,24 @@ class SqsSourceTaskTest { @BeforeEach public void setup(){ - backoff = mock(Backoff.class); - sqsClient = mock(SqsClient.class); - sqsMetrics = mock(SqsMetrics.class); messageReceivedCounter = mock(Counter.class); messageDeletedCounter = mock(Counter.class); sqsMessagesFailedCounter = mock(Counter.class); - buffer = getBuffer(); - acknowledgementSetManager = mock(AcknowledgementSetManager.class); acknowledgementSet = mock(AcknowledgementSet.class); - when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter); - when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter); } - private BlockingBuffer> getBuffer() { - final HashMap integerHashMap = new HashMap<>(); - integerHashMap.put("buffer_size", 2); - integerHashMap.put("batch_size", 2); - final PluginSetting pluginSetting = new PluginSetting("blocking_buffer", integerHashMap); - pluginSetting.setPipelineName(TEST_PIPELINE_NAME); - return new BlockingBuffer<>(pluginSetting); + private SqsSourceTask createObjectUnderTest() { + sqsService = new SqsService(sqsMetrics,sqsClient,backoff); + sqsOptions = new SqsOptions.Builder() + .setSqsUrl("https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue") + .setVisibilityTimeout(Duration.ofSeconds(30)) + .setWaitTime(Duration.ofSeconds(20)).build(); + return new SqsSourceTask(buffer,1,Duration.ofSeconds(10),sqsService, + sqsOptions, + sqsMetrics, + acknowledgementSetManager, + endToEndAcknowledgementsEnabled, + sqsHandler); } @ParameterizedTest @@ -118,7 +103,10 @@ private BlockingBuffer> getBuffer() { "'{\"S.No\":\"1\",\"name\":\"data-prep\",\"country\":\"USA\"}'", "Test Message", "'2023-05-30T13:25:11,889 [main] INFO org.opensearch.dataprepper.pipeline.server.DataPrepperServer - Data Prepper server running at :4900'"}) - void processSqsMessages_test_with_different_types_of_messages(final String message) { + void processSqsMessages_test_with_different_types_of_messages(final String message) throws Exception { + when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter); + when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter); + List messageList = List.of(Message.builder().body(message).messageId(UUID.randomUUID().toString()).receiptHandle(UUID.randomUUID().toString()).build()); when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(ReceiveMessageResponse.builder().messages(messageList).build()); when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter); @@ -126,39 +114,20 @@ void processSqsMessages_test_with_different_types_of_messages(final String messa when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))). thenReturn(DeleteMessageBatchResponse.builder().successful(builder -> builder.id(UUID.randomUUID().toString()).build()).build()); - final SqsSourceTask sqsSourceTask = createObjectUnderTest(buffer,endToEndAcknowledgementsEnabled); + final SqsSourceTask sqsSourceTask = createObjectUnderTest(); sqsSourceTask.processSqsMessages(); - final List> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey()); - final String bufferMessage = bufferEvents.get(0).getData().get(MESSAGE, String.class); + verify(sqsHandler).handleMessages(eq(messageList), any(), isNull()); - assertThat(bufferMessage,equalTo(message)); verify(sqsMetrics.getSqsMessagesReceivedCounter()).increment(); verify(sqsMetrics.getSqsMessagesDeletedCounter()).increment(1); } - private SqsSourceTask createObjectUnderTest(Buffer> buffer,boolean endToEndAckFlag) { - final BufferAccumulator> recordBufferAccumulator = - BufferAccumulator.create(buffer, NO_OF_RECORDS_TO_ACCUMULATE, BUFFER_TIMEOUT); - sqsService = new SqsService(sqsMetrics,sqsClient,backoff); - sqsHandler = new RawSqsMessageHandler(sqsService); - sqsOptions = new SqsOptions.Builder() - .setSqsUrl("https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue") - .setVisibilityTimeout(Duration.ofSeconds(30)) - .setWaitTime(Duration.ofSeconds(20)).build(); - return new SqsSourceTask(buffer,100,Duration.ofSeconds(10),sqsService, - sqsOptions, - sqsMetrics, - acknowledgementSetManager, - endToEndAckFlag, - sqsHandler); - } - @Test void processSqsMessages_should_return_zero_messages_with_backoff() { when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); when(sqsMetrics.getSqsReceiveMessagesFailedCounter()).thenReturn(sqsMessagesFailedCounter); - createObjectUnderTest(buffer,endToEndAcknowledgementsEnabled).processSqsMessages(); + createObjectUnderTest().processSqsMessages(); verify(backoff).nextDelayMillis(1); verify(sqsMessagesFailedCounter).increment(); } @@ -171,26 +140,27 @@ void processSqsMessages_should_return_one_message_with_buffer_write_fail_with_ba when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenThrow(mock(StsException.class)); when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter); when(sqsMetrics.getSqsMessagesDeleteFailedCounter()).thenReturn(messageDeletedCounter); - createObjectUnderTest(mock(Buffer.class),endToEndAcknowledgementsEnabled).processSqsMessages(); + createObjectUnderTest().processSqsMessages(); verify(backoff).nextDelayMillis(1); verify(messageReceivedCounter).increment(); } @Test - void processSqsMessages_test_with_different_types_of_messages_with_end_to_end_ack() { + void processSqsMessages_test_with_different_types_of_messages_with_end_to_end_ack() throws Exception { + when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter); + when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter); + + endToEndAcknowledgementsEnabled = true; + String message = "'{\"S.No\":\"1\",\"name\":\"data-prep\",\"country\":\"USA\"}'"; when(acknowledgementSetManager.create(any( Consumer.class), any(Duration.class))).thenReturn(acknowledgementSet); List messageList = List.of(Message.builder().body(message).messageId(UUID.randomUUID().toString()).receiptHandle(UUID.randomUUID().toString()).build()); when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(ReceiveMessageResponse.builder().messages(messageList).build()); - when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))). - thenReturn(DeleteMessageBatchResponse.builder().successful(builder -> builder.id(UUID.randomUUID().toString()).build()).build()); - createObjectUnderTest(buffer,true).processSqsMessages(); + createObjectUnderTest().processSqsMessages(); - final List> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey()); - final String bufferMessage = bufferEvents.get(0).getData().get(MESSAGE, String.class); + verify(sqsHandler).handleMessages(eq(messageList), any(), eq(acknowledgementSet)); - assertThat(bufferMessage,equalTo(message)); verify(sqsMetrics.getSqsMessagesReceivedCounter()).increment(); verify(acknowledgementSetManager).create(any(), any(Duration.class)); verifyNoInteractions(sqsMetrics.getSqsMessagesDeletedCounter()); diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/handler/RawSqsMessageHandlerTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/handler/RawSqsMessageHandlerTest.java index 3773de6ae5..c71ae04204 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/handler/RawSqsMessageHandlerTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqssource/handler/RawSqsMessageHandlerTest.java @@ -1,63 +1,132 @@ package org.opensearch.dataprepper.plugins.source.sqssource.handler; -import com.linecorp.armeria.client.retry.Backoff; -import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.buffer.common.BufferAccumulator; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; -import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.aws.sqs.common.SqsService; -import org.opensearch.dataprepper.plugins.aws.sqs.common.metrics.SqsMetrics; -import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; -import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.Message; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +@ExtendWith(MockitoExtension.class) class RawSqsMessageHandlerTest { + @Mock + private SqsService sqsService; - static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(10); + @Mock + private BufferAccumulator bufferAccumulator; - static final int NO_OF_RECORDS_TO_ACCUMULATE = 100; + private AcknowledgementSet acknowledgementSet; + private List messageBodies; + private List messages; - private BlockingBuffer> getBuffer() { - final HashMap integerHashMap = new HashMap<>(); - integerHashMap.put("buffer_size", 2); - integerHashMap.put("batch_size", 2); - final PluginSetting pluginSetting = new PluginSetting("blocking_buffer", integerHashMap); - pluginSetting.setPipelineName("pipeline"); - return new BlockingBuffer<>(pluginSetting); + @BeforeEach + void setUp() { + messageBodies = IntStream.range(0, 3).mapToObj(i -> UUID.randomUUID().toString()) + .collect(Collectors.toList()); + + messages = messageBodies.stream() + .map(body -> { + Message message = mock(Message.class); + when(message.body()).thenReturn(body); + return message; + }) + .collect(Collectors.toList()); + + + acknowledgementSet = null; + } + + private RawSqsMessageHandler createObjectUnderTest() { + return new RawSqsMessageHandler(sqsService); + } + + @Test + void handleMessages_writes_to_buffer_and_flushes() throws Exception { + createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet); + + InOrder inOrder = inOrder(bufferAccumulator); + + ArgumentCaptor> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class); + + inOrder.verify(bufferAccumulator, times(messages.size())).add(recordArgumentCaptor.capture()); + inOrder.verify(bufferAccumulator).flush(); + + List actualEventData = recordArgumentCaptor.getAllValues() + .stream() + .map(Record::getData) + .map(e -> e.get("message", Object.class)) + .collect(Collectors.toList()); + + assertThat(actualEventData.size(), equalTo(messages.size())); + + for (int i = 0; i < actualEventData.size(); i++){ + Object messageData = actualEventData.get(i); + assertThat(messageData, instanceOf(String.class)); + assertThat(messageData, equalTo(messageBodies.get(i))); + } } @Test - void sqs_messages_handler_will_read_sqs_message_and_push_to_buffer(){ - final BlockingBuffer> buffer = getBuffer(); - AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); - String message = UUID.randomUUID().toString(); - String messageId = UUID.randomUUID().toString(); - String receiptHandle = UUID.randomUUID().toString(); - List messageList = List.of(Message.builder().body(message).messageId(messageId).receiptHandle(receiptHandle).build()); - SqsService sqsService = new SqsService(mock(SqsMetrics.class),mock(SqsClient.class),mock(Backoff.class)); - final BufferAccumulator> recordBufferAccumulator = BufferAccumulator.create(buffer, NO_OF_RECORDS_TO_ACCUMULATE, BUFFER_TIMEOUT); - RawSqsMessageHandler rawSqsMessageHandler = new RawSqsMessageHandler(sqsService); - final List deleteMessageBatchRequestEntries = rawSqsMessageHandler.handleMessages(messageList,recordBufferAccumulator, acknowledgementSet); - final List> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey()); - final String bufferMessage = bufferEvents.get(0).getData().get("message", String.class); - - assertThat(bufferMessage, CoreMatchers.equalTo(message)); - - assertThat(deleteMessageBatchRequestEntries.get(0).receiptHandle(),equalTo(receiptHandle)); - assertThat(deleteMessageBatchRequestEntries.get(0).id(),equalTo(messageId)); + void handleMessages_returns_deleteList() throws Exception { + List stubbedDeleteList = List.of(mock(DeleteMessageBatchRequestEntry.class)); + when(sqsService.getDeleteMessageBatchRequestEntryList(messages)) + .thenReturn(stubbedDeleteList); + + List actualList = createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet); + + assertThat(actualList, equalTo(stubbedDeleteList)); + } + + @Nested + class WithAcknowledgementSet { + @BeforeEach + void setUp() { + acknowledgementSet = mock(AcknowledgementSet.class); + } + + @Test + void handleMessages_with_acknowledgementSet_adds_events() throws Exception { + createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet); + + ArgumentCaptor eventArgumentCaptor = ArgumentCaptor.forClass(Event.class); + + verify(acknowledgementSet, times(messages.size())).add(eventArgumentCaptor.capture()); + + List actualEventData = eventArgumentCaptor.getAllValues() + .stream() + .map(e -> e.get("message", Object.class)) + .collect(Collectors.toList()); + + assertThat(actualEventData.size(), equalTo(messages.size())); + + for (int i = 0; i < actualEventData.size(); i++) { + Object messageData = actualEventData.get(i); + assertThat(messageData, instanceOf(String.class)); + assertThat(messageData, equalTo(messageBodies.get(i))); + } + } } }