Skip to content

Commit

Permalink
Uses mocking in the SQS Source test to simplify the unit tests and re…
Browse files Browse the repository at this point in the history
…duce build times. This knocks off close to a minute from the build. (#3303)

Signed-off-by: David Venable <[email protected]>
  • Loading branch information
dlvenable authored Sep 11, 2023
1 parent 39b67fe commit 00719cb
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@
import io.micrometer.core.instrument.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.aws.sqs.common.SqsService;
import org.opensearch.dataprepper.plugins.aws.sqs.common.handler.SqsMessageHandler;
import org.opensearch.dataprepper.plugins.aws.sqs.common.metrics.SqsMetrics;
import org.opensearch.dataprepper.plugins.aws.sqs.common.model.SqsOptions;
import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer;
import org.opensearch.dataprepper.plugins.source.sqssource.handler.RawSqsMessageHandler;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
Expand All @@ -30,31 +28,24 @@
import software.amazon.awssdk.services.sts.model.StsException;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.function.Consumer;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;


@ExtendWith(MockitoExtension.class)
class SqsSourceTaskTest {

private static final String TEST_PIPELINE_NAME = "pipeline";

private static final String MESSAGE = "message";

@Mock
private SqsService sqsService;

@Mock
private SqsOptions sqsOptions;

@Mock
Expand All @@ -63,11 +54,7 @@ class SqsSourceTaskTest {
@Mock
private AcknowledgementSetManager acknowledgementSetManager;

private final boolean endToEndAcknowledgementsEnabled = false;

static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(10);

static final int NO_OF_RECORDS_TO_ACCUMULATE = 100;
private boolean endToEndAcknowledgementsEnabled = false;

@Mock
private SqsMessageHandler sqsHandler;
Expand All @@ -91,74 +78,56 @@ class SqsSourceTaskTest {

@BeforeEach
public void setup(){
backoff = mock(Backoff.class);
sqsClient = mock(SqsClient.class);
sqsMetrics = mock(SqsMetrics.class);
messageReceivedCounter = mock(Counter.class);
messageDeletedCounter = mock(Counter.class);
sqsMessagesFailedCounter = mock(Counter.class);
buffer = getBuffer();
acknowledgementSetManager = mock(AcknowledgementSetManager.class);
acknowledgementSet = mock(AcknowledgementSet.class);
when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter);
when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter);
}

private BlockingBuffer<Record<Event>> getBuffer() {
final HashMap<String, Object> integerHashMap = new HashMap<>();
integerHashMap.put("buffer_size", 2);
integerHashMap.put("batch_size", 2);
final PluginSetting pluginSetting = new PluginSetting("blocking_buffer", integerHashMap);
pluginSetting.setPipelineName(TEST_PIPELINE_NAME);
return new BlockingBuffer<>(pluginSetting);
private SqsSourceTask createObjectUnderTest() {
sqsService = new SqsService(sqsMetrics,sqsClient,backoff);
sqsOptions = new SqsOptions.Builder()
.setSqsUrl("https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue")
.setVisibilityTimeout(Duration.ofSeconds(30))
.setWaitTime(Duration.ofSeconds(20)).build();
return new SqsSourceTask(buffer,1,Duration.ofSeconds(10),sqsService,
sqsOptions,
sqsMetrics,
acknowledgementSetManager,
endToEndAcknowledgementsEnabled,
sqsHandler);
}

@ParameterizedTest
@ValueSource(strings = {
"'{\"S.No\":\"1\",\"name\":\"data-prep\",\"country\":\"USA\"}'",
"Test Message",
"'2023-05-30T13:25:11,889 [main] INFO org.opensearch.dataprepper.pipeline.server.DataPrepperServer - Data Prepper server running at :4900'"})
void processSqsMessages_test_with_different_types_of_messages(final String message) {
void processSqsMessages_test_with_different_types_of_messages(final String message) throws Exception {
when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter);
when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter);

List<Message> messageList = List.of(Message.builder().body(message).messageId(UUID.randomUUID().toString()).receiptHandle(UUID.randomUUID().toString()).build());
when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(ReceiveMessageResponse.builder().messages(messageList).build());
when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter);
when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter);

when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).
thenReturn(DeleteMessageBatchResponse.builder().successful(builder -> builder.id(UUID.randomUUID().toString()).build()).build());
final SqsSourceTask sqsSourceTask = createObjectUnderTest(buffer,endToEndAcknowledgementsEnabled);
final SqsSourceTask sqsSourceTask = createObjectUnderTest();
sqsSourceTask.processSqsMessages();

final List<Record<Event>> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey());
final String bufferMessage = bufferEvents.get(0).getData().get(MESSAGE, String.class);
verify(sqsHandler).handleMessages(eq(messageList), any(), isNull());

assertThat(bufferMessage,equalTo(message));
verify(sqsMetrics.getSqsMessagesReceivedCounter()).increment();
verify(sqsMetrics.getSqsMessagesDeletedCounter()).increment(1);
}

private SqsSourceTask createObjectUnderTest(Buffer<Record<Event>> buffer,boolean endToEndAckFlag) {
final BufferAccumulator<Record<Event>> recordBufferAccumulator =
BufferAccumulator.create(buffer, NO_OF_RECORDS_TO_ACCUMULATE, BUFFER_TIMEOUT);
sqsService = new SqsService(sqsMetrics,sqsClient,backoff);
sqsHandler = new RawSqsMessageHandler(sqsService);
sqsOptions = new SqsOptions.Builder()
.setSqsUrl("https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue")
.setVisibilityTimeout(Duration.ofSeconds(30))
.setWaitTime(Duration.ofSeconds(20)).build();
return new SqsSourceTask(buffer,100,Duration.ofSeconds(10),sqsService,
sqsOptions,
sqsMetrics,
acknowledgementSetManager,
endToEndAckFlag,
sqsHandler);
}

@Test
void processSqsMessages_should_return_zero_messages_with_backoff() {
when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class);
when(sqsMetrics.getSqsReceiveMessagesFailedCounter()).thenReturn(sqsMessagesFailedCounter);
createObjectUnderTest(buffer,endToEndAcknowledgementsEnabled).processSqsMessages();
createObjectUnderTest().processSqsMessages();
verify(backoff).nextDelayMillis(1);
verify(sqsMessagesFailedCounter).increment();
}
Expand All @@ -171,26 +140,27 @@ void processSqsMessages_should_return_one_message_with_buffer_write_fail_with_ba
when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenThrow(mock(StsException.class));
when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter);
when(sqsMetrics.getSqsMessagesDeleteFailedCounter()).thenReturn(messageDeletedCounter);
createObjectUnderTest(mock(Buffer.class),endToEndAcknowledgementsEnabled).processSqsMessages();
createObjectUnderTest().processSqsMessages();
verify(backoff).nextDelayMillis(1);
verify(messageReceivedCounter).increment();
}

@Test
void processSqsMessages_test_with_different_types_of_messages_with_end_to_end_ack() {
void processSqsMessages_test_with_different_types_of_messages_with_end_to_end_ack() throws Exception {
when(sqsMetrics.getSqsMessagesReceivedCounter()).thenReturn(messageReceivedCounter);
when(sqsMetrics.getSqsMessagesDeletedCounter()).thenReturn(messageDeletedCounter);

endToEndAcknowledgementsEnabled = true;

String message = "'{\"S.No\":\"1\",\"name\":\"data-prep\",\"country\":\"USA\"}'";
when(acknowledgementSetManager.create(any( Consumer.class), any(Duration.class))).thenReturn(acknowledgementSet);
List<Message> messageList = List.of(Message.builder().body(message).messageId(UUID.randomUUID().toString()).receiptHandle(UUID.randomUUID().toString()).build());
when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(ReceiveMessageResponse.builder().messages(messageList).build());
when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).
thenReturn(DeleteMessageBatchResponse.builder().successful(builder -> builder.id(UUID.randomUUID().toString()).build()).build());

createObjectUnderTest(buffer,true).processSqsMessages();
createObjectUnderTest().processSqsMessages();

final List<Record<Event>> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey());
final String bufferMessage = bufferEvents.get(0).getData().get(MESSAGE, String.class);
verify(sqsHandler).handleMessages(eq(messageList), any(), eq(acknowledgementSet));

assertThat(bufferMessage,equalTo(message));
verify(sqsMetrics.getSqsMessagesReceivedCounter()).increment();
verify(acknowledgementSetManager).create(any(), any(Duration.class));
verifyNoInteractions(sqsMetrics.getSqsMessagesDeletedCounter());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,132 @@
package org.opensearch.dataprepper.plugins.source.sqssource.handler;

import com.linecorp.armeria.client.retry.Backoff;
import org.hamcrest.CoreMatchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.aws.sqs.common.SqsService;
import org.opensearch.dataprepper.plugins.aws.sqs.common.metrics.SqsMetrics;
import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.Message;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class RawSqsMessageHandlerTest {
@Mock
private SqsService sqsService;

static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(10);
@Mock
private BufferAccumulator bufferAccumulator;

static final int NO_OF_RECORDS_TO_ACCUMULATE = 100;
private AcknowledgementSet acknowledgementSet;
private List<String> messageBodies;
private List<Message> messages;

private BlockingBuffer<Record<Event>> getBuffer() {
final HashMap<String, Object> integerHashMap = new HashMap<>();
integerHashMap.put("buffer_size", 2);
integerHashMap.put("batch_size", 2);
final PluginSetting pluginSetting = new PluginSetting("blocking_buffer", integerHashMap);
pluginSetting.setPipelineName("pipeline");
return new BlockingBuffer<>(pluginSetting);
@BeforeEach
void setUp() {
messageBodies = IntStream.range(0, 3).mapToObj(i -> UUID.randomUUID().toString())
.collect(Collectors.toList());

messages = messageBodies.stream()
.map(body -> {
Message message = mock(Message.class);
when(message.body()).thenReturn(body);
return message;
})
.collect(Collectors.toList());


acknowledgementSet = null;
}

private RawSqsMessageHandler createObjectUnderTest() {
return new RawSqsMessageHandler(sqsService);
}

@Test
void handleMessages_writes_to_buffer_and_flushes() throws Exception {
createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet);

InOrder inOrder = inOrder(bufferAccumulator);

ArgumentCaptor<Record<Event>> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class);

inOrder.verify(bufferAccumulator, times(messages.size())).add(recordArgumentCaptor.capture());
inOrder.verify(bufferAccumulator).flush();

List<Object> actualEventData = recordArgumentCaptor.getAllValues()
.stream()
.map(Record::getData)
.map(e -> e.get("message", Object.class))
.collect(Collectors.toList());

assertThat(actualEventData.size(), equalTo(messages.size()));

for (int i = 0; i < actualEventData.size(); i++){
Object messageData = actualEventData.get(i);
assertThat(messageData, instanceOf(String.class));
assertThat(messageData, equalTo(messageBodies.get(i)));
}
}

@Test
void sqs_messages_handler_will_read_sqs_message_and_push_to_buffer(){
final BlockingBuffer<Record<Event>> buffer = getBuffer();
AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class);
String message = UUID.randomUUID().toString();
String messageId = UUID.randomUUID().toString();
String receiptHandle = UUID.randomUUID().toString();
List<Message> messageList = List.of(Message.builder().body(message).messageId(messageId).receiptHandle(receiptHandle).build());
SqsService sqsService = new SqsService(mock(SqsMetrics.class),mock(SqsClient.class),mock(Backoff.class));
final BufferAccumulator<Record<Event>> recordBufferAccumulator = BufferAccumulator.create(buffer, NO_OF_RECORDS_TO_ACCUMULATE, BUFFER_TIMEOUT);
RawSqsMessageHandler rawSqsMessageHandler = new RawSqsMessageHandler(sqsService);
final List<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntries = rawSqsMessageHandler.handleMessages(messageList,recordBufferAccumulator, acknowledgementSet);
final List<Record<Event>> bufferEvents = new ArrayList<>(buffer.read((int) Duration.ofSeconds(10).toMillis()).getKey());
final String bufferMessage = bufferEvents.get(0).getData().get("message", String.class);

assertThat(bufferMessage, CoreMatchers.equalTo(message));

assertThat(deleteMessageBatchRequestEntries.get(0).receiptHandle(),equalTo(receiptHandle));
assertThat(deleteMessageBatchRequestEntries.get(0).id(),equalTo(messageId));
void handleMessages_returns_deleteList() throws Exception {
List<DeleteMessageBatchRequestEntry> stubbedDeleteList = List.of(mock(DeleteMessageBatchRequestEntry.class));
when(sqsService.getDeleteMessageBatchRequestEntryList(messages))
.thenReturn(stubbedDeleteList);

List<DeleteMessageBatchRequestEntry> actualList = createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet);

assertThat(actualList, equalTo(stubbedDeleteList));
}

@Nested
class WithAcknowledgementSet {
@BeforeEach
void setUp() {
acknowledgementSet = mock(AcknowledgementSet.class);
}

@Test
void handleMessages_with_acknowledgementSet_adds_events() throws Exception {
createObjectUnderTest().handleMessages(messages, bufferAccumulator, acknowledgementSet);

ArgumentCaptor<Event> eventArgumentCaptor = ArgumentCaptor.forClass(Event.class);

verify(acknowledgementSet, times(messages.size())).add(eventArgumentCaptor.capture());

List<Object> actualEventData = eventArgumentCaptor.getAllValues()
.stream()
.map(e -> e.get("message", Object.class))
.collect(Collectors.toList());

assertThat(actualEventData.size(), equalTo(messages.size()));

for (int i = 0; i < actualEventData.size(); i++) {
Object messageData = actualEventData.get(i);
assertThat(messageData, instanceOf(String.class));
assertThat(messageData, equalTo(messageBodies.get(i)));
}
}
}
}

0 comments on commit 00719cb

Please sign in to comment.