From 0a7f3f411ad79a5ac622c353c83179caff0be378 Mon Sep 17 00:00:00 2001 From: David Venable Date: Fri, 19 Jul 2024 13:40:57 -0500 Subject: [PATCH] Improve the SQS shutdown process such that it does not prevent the pipeline from shutting down and no longer results in failures. Resolves #4575 (#4748) The previous approach to shutting down the SQS thread closed the SqsClient. However, with acknowledgments enabled, asynchronous callbacks would result in further attempts to either ChangeVisibilityTimeout or DeleteMessages. These were failing because the client was closed. Also, the threads would remain and prevent Data Prepper from correctly shutting down. With this change, we correctly stop each processing thread. Then we close the client. Additionally, the SqsWorker now checks that it is not stopped before attempting to change the message visibility or delete messages. Additionally, I found some missing test cases. Also, modifying this code and especially unit testing it is becoming more difficult, so I performed some refactoring to move message parsing out of the SqsWorker. Signed-off-by: David Venable Signed-off-by: Krishna Kondaka --- .../plugins/source/s3/SqsService.java | 22 +- .../plugins/source/s3/SqsWorker.java | 66 ++-- .../source/s3/parser/ParsedMessage.java | 17 +- .../source/s3/parser/SqsMessageParser.java | 44 +++ .../plugins/source/s3/SqsWorkerTest.java | 320 +++++++++++------- .../source/s3/parser/ParsedMessageTest.java | 222 ++++++++---- .../S3EventBridgeNotificationParserTest.java | 2 +- .../parser/S3EventNotificationParserTest.java | 6 +- .../s3/parser/SqsMessageParserTest.java | 96 ++++++ 9 files changed, 559 insertions(+), 236 deletions(-) create mode 100644 data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParser.java create mode 100644 data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParserTest.java diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java index b05d2806d4..c674be5f68 100644 --- a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsService.java @@ -17,9 +17,12 @@ import software.amazon.awssdk.services.sqs.SqsClient; import java.time.Duration; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class SqsService { private static final Logger LOG = LoggerFactory.getLogger(SqsService.class); @@ -34,6 +37,7 @@ public class SqsService { private final PluginMetrics pluginMetrics; private final AcknowledgementSetManager acknowledgementSetManager; private final ExecutorService executorService; + private final List sqsWorkers; public SqsService(final AcknowledgementSetManager acknowledgementSetManager, final S3SourceConfig s3SourceConfig, @@ -46,18 +50,20 @@ public SqsService(final AcknowledgementSetManager acknowledgementSetManager, this.acknowledgementSetManager = acknowledgementSetManager; this.sqsClient = createSqsClient(credentialsProvider); executorService = Executors.newFixedThreadPool(s3SourceConfig.getNumWorkers(), BackgroundThreadFactory.defaultExecutorThreadFactory("s3-source-sqs")); - } - public void start() { final Backoff backoff = Backoff.exponential(INITIAL_DELAY, MAXIMUM_DELAY).withJitter(JITTER_RATE) .withMaxAttempts(Integer.MAX_VALUE); - for (int i = 0; i < s3SourceConfig.getNumWorkers(); i++) { - executorService.submit(new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff)); - } + sqsWorkers = IntStream.range(0, s3SourceConfig.getNumWorkers()) + .mapToObj(i -> new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff)) + .collect(Collectors.toList()); + } + + public void start() { + sqsWorkers.forEach(executorService::submit); } SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { - LOG.info("Creating SQS client"); + LOG.debug("Creating SQS client"); return SqsClient.builder() .region(s3SourceConfig.getAwsAuthenticationOptions().getAwsRegion()) .credentialsProvider(credentialsProvider) @@ -68,8 +74,8 @@ SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { } public void stop() { - sqsClient.close(); executorService.shutdown(); + sqsWorkers.forEach(SqsWorker::stop); try { if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)) { LOG.warn("Failed to terminate SqsWorkers"); @@ -82,5 +88,7 @@ public void stop() { Thread.currentThread().interrupt(); } } + + sqsClient.close(); } } diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorker.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorker.java index b3404cebf6..3c5fba0701 100644 --- a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorker.java +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorker.java @@ -5,7 +5,6 @@ package org.opensearch.dataprepper.plugins.source.s3; -import com.fasterxml.jackson.databind.ObjectMapper; import com.linecorp.armeria.client.retry.Backoff; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; @@ -20,8 +19,7 @@ import org.opensearch.dataprepper.plugins.source.s3.filter.S3EventFilter; import org.opensearch.dataprepper.plugins.source.s3.filter.S3ObjectCreatedFilter; import org.opensearch.dataprepper.plugins.source.s3.parser.ParsedMessage; -import org.opensearch.dataprepper.plugins.source.s3.parser.S3EventBridgeNotificationParser; -import org.opensearch.dataprepper.plugins.source.s3.parser.S3EventNotificationParser; +import org.opensearch.dataprepper.plugins.source.s3.parser.SqsMessageParser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.exception.SdkException; @@ -75,11 +73,10 @@ public class SqsWorker implements Runnable { private final Counter sqsVisibilityTimeoutChangeFailedCount; private final Timer sqsMessageDelayTimer; private final Backoff standardBackoff; + private final SqsMessageParser sqsMessageParser; private int failedAttemptCount; private final boolean endToEndAcknowledgementsEnabled; private final AcknowledgementSetManager acknowledgementSetManager; - - private final ObjectMapper objectMapper = new ObjectMapper(); private volatile boolean isStopped = false; private Map parsedMessageVisibilityTimesMap; @@ -98,6 +95,7 @@ public SqsWorker(final AcknowledgementSetManager acknowledgementSetManager, sqsOptions = s3SourceConfig.getSqsOptions(); objectCreatedFilter = new S3ObjectCreatedFilter(); evenBridgeObjectCreatedFilter = new EventBridgeObjectCreatedFilter(); + sqsMessageParser = new SqsMessageParser(s3SourceConfig); failedAttemptCount = 0; parsedMessageVisibilityTimesMap = new HashMap<>(); @@ -139,7 +137,7 @@ int processSqsMessages() { if (!sqsMessages.isEmpty()) { sqsMessagesReceivedCounter.increment(sqsMessages.size()); - final Collection s3MessageEventNotificationRecords = getS3MessageEventNotificationRecords(sqsMessages); + final Collection s3MessageEventNotificationRecords = sqsMessageParser.parseSqsMessages(sqsMessages); // build s3ObjectReference from S3EventNotificationRecord if event name starts with ObjectCreated final List deleteMessageBatchRequestEntries = processS3EventNotificationRecords(s3MessageEventNotificationRecords); @@ -191,22 +189,6 @@ private ReceiveMessageRequest createReceiveMessageRequest() { .build(); } - private Collection getS3MessageEventNotificationRecords(final List sqsMessages) { - return sqsMessages.stream() - .map(this::convertS3EventMessages) - .collect(Collectors.toList()); - } - - private ParsedMessage convertS3EventMessages(final Message message) { - if (s3SourceConfig.getNotificationSource().equals(NotificationSourceOption.S3)) { - return new S3EventNotificationParser().parseMessage(message, objectMapper); - } - else if (s3SourceConfig.getNotificationSource().equals(NotificationSourceOption.EVENTBRIDGE)) { - return new S3EventBridgeNotificationParser().parseMessage(message, objectMapper); - } - return new ParsedMessage(message, true); - } - private List processS3EventNotificationRecords(final Collection s3EventNotificationRecords) { final List deleteMessageBatchRequestEntryCollection = new ArrayList<>(); final List parsedMessagesToRead = new ArrayList<>(); @@ -276,21 +258,7 @@ && isEventBridgeEventTypeCreated(parsedMessage)) { return; } parsedMessageVisibilityTimesMap.put(parsedMessage, newValue); - final ChangeMessageVisibilityRequest changeMessageVisibilityRequest = ChangeMessageVisibilityRequest.builder() - .visibilityTimeout(newVisibilityTimeoutSeconds) - .queueUrl(sqsOptions.getSqsUrl()) - .receiptHandle(parsedMessage.getMessage().receiptHandle()) - .build(); - - try { - sqsClient.changeMessageVisibility(changeMessageVisibilityRequest); - sqsVisibilityTimeoutChangedCount.increment(); - LOG.debug("Set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds); - } catch (Exception e) { - LOG.error("Failed to set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds, e); - sqsVisibilityTimeoutChangeFailedCount.increment(); - } - + increaseVisibilityTimeout(parsedMessage, newVisibilityTimeoutSeconds); }, Duration.ofSeconds(progressCheckInterval)); } @@ -308,6 +276,27 @@ && isEventBridgeEventTypeCreated(parsedMessage)) { return deleteMessageBatchRequestEntryCollection; } + private void increaseVisibilityTimeout(final ParsedMessage parsedMessage, final int newVisibilityTimeoutSeconds) { + if(isStopped) { + LOG.info("Some messages are pending completion of acknowledgments. Data Prepper will not increase the visibility timeout because it is shutting down. {}", parsedMessage); + return; + } + final ChangeMessageVisibilityRequest changeMessageVisibilityRequest = ChangeMessageVisibilityRequest.builder() + .visibilityTimeout(newVisibilityTimeoutSeconds) + .queueUrl(sqsOptions.getSqsUrl()) + .receiptHandle(parsedMessage.getMessage().receiptHandle()) + .build(); + + try { + sqsClient.changeMessageVisibility(changeMessageVisibilityRequest); + sqsVisibilityTimeoutChangedCount.increment(); + LOG.debug("Set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds); + } catch (Exception e) { + LOG.error("Failed to set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds, e); + sqsVisibilityTimeoutChangeFailedCount.increment(); + } + } + private Optional processS3Object( final ParsedMessage parsedMessage, final S3ObjectReference s3ObjectReference, @@ -328,6 +317,8 @@ private Optional processS3Object( } private void deleteSqsMessages(final List deleteMessageBatchRequestEntryCollection) { + if(isStopped) + return; if (deleteMessageBatchRequestEntryCollection.size() == 0) { return; } @@ -396,6 +387,5 @@ private S3ObjectReference populateS3Reference(final String bucketName, final Str void stop() { isStopped = true; - Thread.currentThread().interrupt(); } } diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessage.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessage.java index 18bbc58499..ed68dff063 100644 --- a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessage.java +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessage.java @@ -11,6 +11,7 @@ import software.amazon.awssdk.services.sqs.model.Message; import java.util.List; +import java.util.Objects; public class ParsedMessage { private final Message message; @@ -24,14 +25,14 @@ public class ParsedMessage { private String detailType; public ParsedMessage(final Message message, final boolean failedParsing) { - this.message = message; + this.message = Objects.requireNonNull(message); this.failedParsing = failedParsing; this.emptyNotification = true; } - // S3EventNotification contains only one S3EventNotificationRecord ParsedMessage(final Message message, final List notificationRecords) { - this.message = message; + this.message = Objects.requireNonNull(message); + // S3EventNotification contains only one S3EventNotificationRecord this.bucketName = notificationRecords.get(0).getS3().getBucket().getName(); this.objectKey = notificationRecords.get(0).getS3().getObject().getUrlDecodedKey(); this.objectSize = notificationRecords.get(0).getS3().getObject().getSizeAsLong(); @@ -42,7 +43,7 @@ public ParsedMessage(final Message message, final boolean failedParsing) { } ParsedMessage(final Message message, final S3EventBridgeNotification eventBridgeNotification) { - this.message = message; + this.message = Objects.requireNonNull(message); this.bucketName = eventBridgeNotification.getDetail().getBucket().getName(); this.objectKey = eventBridgeNotification.getDetail().getObject().getUrlDecodedKey(); this.objectSize = eventBridgeNotification.getDetail().getObject().getSize(); @@ -85,4 +86,12 @@ public boolean isEmptyNotification() { public String getDetailType() { return detailType; } + + @Override + public String toString() { + return "Message{" + + "messageId=" + message.messageId() + + ", objectKey=" + objectKey + + '}'; + } } diff --git a/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParser.java b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParser.java new file mode 100644 index 0000000000..ea40e3f041 --- /dev/null +++ b/data-prepper-plugins/s3-source/src/main/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParser.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.s3.parser; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.dataprepper.plugins.source.s3.S3SourceConfig; +import software.amazon.awssdk.services.sqs.model.Message; + +import java.util.Collection; +import java.util.stream.Collectors; + +public class SqsMessageParser { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private final S3SourceConfig s3SourceConfig; + private final S3NotificationParser s3NotificationParser; + + public SqsMessageParser(final S3SourceConfig s3SourceConfig) { + this.s3SourceConfig = s3SourceConfig; + s3NotificationParser = createNotificationParser(s3SourceConfig); + } + + public Collection parseSqsMessages(final Collection sqsMessages) { + return sqsMessages.stream() + .map(this::convertS3EventMessages) + .collect(Collectors.toList()); + } + + private ParsedMessage convertS3EventMessages(final Message message) { + return s3NotificationParser.parseMessage(message, OBJECT_MAPPER); + } + + private static S3NotificationParser createNotificationParser(final S3SourceConfig s3SourceConfig) { + switch (s3SourceConfig.getNotificationSource()) { + case EVENTBRIDGE: + return new S3EventBridgeNotificationParser(); + case S3: + default: + return new S3EventNotificationParser(); + } + } +} diff --git a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerTest.java b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerTest.java index 50ed879f4a..ada789cea6 100644 --- a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerTest.java +++ b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/SqsWorkerTest.java @@ -12,6 +12,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -19,19 +20,21 @@ import org.junit.jupiter.params.provider.ArgumentsSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; -import org.opensearch.dataprepper.plugins.source.s3.configuration.AwsAuthenticationOptions; +import org.opensearch.dataprepper.model.acknowledgements.ProgressCheck; import org.opensearch.dataprepper.plugins.source.s3.configuration.NotificationSourceOption; import org.opensearch.dataprepper.plugins.source.s3.configuration.OnErrorOption; import org.opensearch.dataprepper.plugins.source.s3.configuration.SqsOptions; import org.opensearch.dataprepper.plugins.source.s3.exception.SqsRetriesExhaustedException; import org.opensearch.dataprepper.plugins.source.s3.filter.S3EventFilter; import org.opensearch.dataprepper.plugins.source.s3.filter.S3ObjectCreatedFilter; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResultEntry; @@ -50,6 +53,7 @@ import java.util.Collections; import java.util.List; import java.util.UUID; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -65,20 +69,23 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME; +import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.S3_OBJECTS_EMPTY_METRIC_NAME; import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME; import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_MESSAGES_DELETE_FAILED_METRIC_NAME; import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_MESSAGES_FAILED_METRIC_NAME; import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME; import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_MESSAGE_DELAY_METRIC_NAME; -import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.S3_OBJECTS_EMPTY_METRIC_NAME; +import static org.opensearch.dataprepper.plugins.source.s3.SqsWorker.SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME; +@ExtendWith(MockitoExtension.class) class SqsWorkerTest { - private SqsWorker sqsWorker; private SqsClient sqsClient; private S3Service s3Service; private S3SourceConfig s3SourceConfig; @@ -90,10 +97,13 @@ class SqsWorkerTest { private Counter sqsMessagesFailedCounter; private Counter sqsMessagesDeleteFailedCounter; private Counter s3ObjectsEmptyCounter; + @Mock + private Counter sqsVisibilityTimeoutChangedCount; private Timer sqsMessageDelayTimer; private AcknowledgementSetManager acknowledgementSetManager; private AcknowledgementSet acknowledgementSet; private SqsOptions sqsOptions; + private String queueUrl; @BeforeEach void setUp() { @@ -105,15 +115,11 @@ void setUp() { objectCreatedFilter = new S3ObjectCreatedFilter(); backoff = mock(Backoff.class); - AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); - sqsOptions = mock(SqsOptions.class); - when(sqsOptions.getSqsUrl()).thenReturn("https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue"); + queueUrl = "https://sqs.us-east-2.amazonaws.com/123456789012/" + UUID.randomUUID(); + when(sqsOptions.getSqsUrl()).thenReturn(queueUrl); - when(s3SourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions); - when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.RETAIN_MESSAGES); when(s3SourceConfig.getAcknowledgements()).thenReturn(false); when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3); @@ -130,8 +136,12 @@ void setUp() { when(pluginMetrics.counter(SQS_MESSAGES_DELETE_FAILED_METRIC_NAME)).thenReturn(sqsMessagesDeleteFailedCounter); when(pluginMetrics.counter(S3_OBJECTS_EMPTY_METRIC_NAME)).thenReturn(s3ObjectsEmptyCounter); when(pluginMetrics.timer(SQS_MESSAGE_DELAY_METRIC_NAME)).thenReturn(sqsMessageDelayTimer); + when(pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(mock(Counter.class)); + when(pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME)).thenReturn(sqsVisibilityTimeoutChangedCount); + } - sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff); + private SqsWorker createObjectUnderTest() { + return new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff); } @AfterEach @@ -167,7 +177,7 @@ void processSqsMessages_should_return_number_of_messages_processed(final String when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); final DeleteMessageBatchRequest actualDeleteMessageBatchRequest = deleteMessageBatchRequestArgumentCaptor.getValue(); @@ -190,93 +200,6 @@ void processSqsMessages_should_return_number_of_messages_processed(final String assertThat(actualDelay, greaterThanOrEqualTo(Duration.ofHours(1).minus(Duration.ofSeconds(5)))); } - @ParameterizedTest - @ValueSource(strings = {"ObjectCreated:Put", "ObjectCreated:Post", "ObjectCreated:Copy", "ObjectCreated:CompleteMultipartUpload"}) - void processSqsMessages_should_return_number_of_messages_processed_with_acknowledgements(final String eventName) throws IOException { - when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); - when(s3SourceConfig.getAcknowledgements()).thenReturn(true); - sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff); - Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); - final Message message = mock(Message.class); - when(message.body()).thenReturn(createEventNotification(eventName, startTime)); - final String testReceiptHandle = UUID.randomUUID().toString(); - when(message.messageId()).thenReturn(testReceiptHandle); - when(message.receiptHandle()).thenReturn(testReceiptHandle); - - final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); - when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - - final int messagesProcessed = sqsWorker.processSqsMessages(); - final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); - - final ArgumentCaptor durationArgumentCaptor = ArgumentCaptor.forClass(Duration.class); - verify(sqsMessageDelayTimer).record(durationArgumentCaptor.capture()); - Duration actualDelay = durationArgumentCaptor.getValue(); - - assertThat(messagesProcessed, equalTo(1)); - verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); - verify(acknowledgementSetManager).create(any(), any(Duration.class)); - verify(sqsMessagesReceivedCounter).increment(1); - verifyNoInteractions(sqsMessagesDeletedCounter); - assertThat(actualDelay, lessThanOrEqualTo(Duration.ofHours(1).plus(Duration.ofSeconds(5)))); - assertThat(actualDelay, greaterThanOrEqualTo(Duration.ofHours(1).minus(Duration.ofSeconds(5)))); - } - - @ParameterizedTest - @ValueSource(strings = {"ObjectCreated:Put", "ObjectCreated:Post", "ObjectCreated:Copy", "ObjectCreated:CompleteMultipartUpload"}) - void processSqsMessages_should_return_number_of_messages_processed_with_acknowledgements_and_progress_check(final String eventName) throws IOException { - when(sqsOptions.getVisibilityDuplicateProtection()).thenReturn(true); - when(sqsOptions.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(6)); - when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); - when(s3SourceConfig.getAcknowledgements()).thenReturn(true); - sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff); - Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); - final Message message = mock(Message.class); - when(message.body()).thenReturn(createEventNotification(eventName, startTime)); - final String testReceiptHandle = UUID.randomUUID().toString(); - when(message.messageId()).thenReturn(testReceiptHandle); - when(message.receiptHandle()).thenReturn(testReceiptHandle); - - final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); - when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - - final int messagesProcessed = sqsWorker.processSqsMessages(); - final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); - - final ArgumentCaptor durationArgumentCaptor = ArgumentCaptor.forClass(Duration.class); - verify(sqsMessageDelayTimer).record(durationArgumentCaptor.capture()); - Duration actualDelay = durationArgumentCaptor.getValue(); - - assertThat(messagesProcessed, equalTo(1)); - verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); - verify(acknowledgementSetManager).create(any(), any(Duration.class)); - verify(acknowledgementSet).addProgressCheck(any(), any(Duration.class)); - verify(sqsMessagesReceivedCounter).increment(1); - verifyNoInteractions(sqsMessagesDeletedCounter); - assertThat(actualDelay, lessThanOrEqualTo(Duration.ofHours(1).plus(Duration.ofSeconds(5)))); - assertThat(actualDelay, greaterThanOrEqualTo(Duration.ofHours(1).minus(Duration.ofSeconds(5)))); - } - - @ParameterizedTest - @ValueSource(strings = {"", "{\"foo\": \"bar\""}) - void processSqsMessages_should_not_interact_with_S3Service_if_input_is_not_valid_JSON(String inputString) { - final Message message = mock(Message.class); - when(message.body()).thenReturn(inputString); - - final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); - when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - - final int messagesProcessed = sqsWorker.processSqsMessages(); - assertThat(messagesProcessed, equalTo(1)); - verifyNoInteractions(s3Service); - verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); - verify(sqsMessagesReceivedCounter).increment(1); - verify(sqsMessagesFailedCounter).increment(); - } - @Test void processSqsMessages_should_not_interact_with_S3Service_and_delete_message_if_TestEvent() { final String messageId = UUID.randomUUID().toString(); @@ -291,7 +214,7 @@ void processSqsMessages_should_not_interact_with_S3Service_and_delete_message_if when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); verifyNoInteractions(s3Service); @@ -324,7 +247,7 @@ void processSqsMessages_should_not_interact_with_S3Service_and_delete_message_if when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); verifyNoInteractions(s3Service); @@ -354,7 +277,7 @@ void processSqsMessages_with_irrelevant_eventName_should_return_number_of_messag when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); verifyNoInteractions(s3Service); @@ -378,7 +301,7 @@ void processSqsMessages_should_invoke_delete_if_input_is_not_valid_JSON_and_dele when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); final DeleteMessageBatchRequest actualDeleteMessageBatchRequest = deleteMessageBatchRequestArgumentCaptor.getValue(); @@ -410,7 +333,7 @@ void processSqsMessages_should_return_number_of_messages_processed_when_using_Ev when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); final DeleteMessageBatchRequest actualDeleteMessageBatchRequest = deleteMessageBatchRequestArgumentCaptor.getValue(); @@ -447,7 +370,7 @@ void processSqsMessages_should_return_number_of_messages_processed_when_using_Se when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); final DeleteMessageBatchRequest actualDeleteMessageBatchRequest = deleteMessageBatchRequestArgumentCaptor.getValue(); @@ -502,7 +425,7 @@ void processSqsMessages_should_report_correct_metrics_for_DeleteMessages_when_so when(deleteMessageBatchResponse.failed()).thenReturn(failedDeletes); when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteMessageBatchResponse); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); @@ -542,7 +465,7 @@ void processSqsMessages_should_report_correct_metrics_for_DeleteMessages_when_re when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenThrow(exClass); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); verify(sqsClient).deleteMessageBatch(deleteMessageBatchRequestArgumentCaptor.capture()); @@ -565,7 +488,7 @@ void processSqsMessages_should_report_correct_metrics_for_DeleteMessages_when_re @Test void processSqsMessages_should_return_zero_messages_when_a_SqsException_is_thrown() { when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(0)); verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); } @@ -573,7 +496,7 @@ void processSqsMessages_should_return_zero_messages_when_a_SqsException_is_throw @Test void processSqsMessages_should_return_zero_messages_with_backoff_when_a_SqsException_is_thrown() { when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); verify(backoff).nextDelayMillis(1); assertThat(messagesProcessed, equalTo(0)); } @@ -582,7 +505,8 @@ void processSqsMessages_should_return_zero_messages_with_backoff_when_a_SqsExcep void processSqsMessages_should_throw_when_a_SqsException_is_thrown_with_max_retries() { when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); when(backoff.nextDelayMillis(anyInt())).thenReturn((long) -1); - assertThrows(SqsRetriesExhaustedException.class, () -> sqsWorker.processSqsMessages()); + SqsWorker objectUnderTest = createObjectUnderTest(); + assertThrows(SqsRetriesExhaustedException.class, () -> objectUnderTest.processSqsMessages()); } @ParameterizedTest @@ -591,11 +515,13 @@ void processSqsMessages_should_return_zero_messages_when_messages_are_not_S3Even final Message message = mock(Message.class); when(message.body()).thenReturn(inputString); + when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.RETAIN_MESSAGES); + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - final int messagesProcessed = sqsWorker.processSqsMessages(); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); verifyNoInteractions(s3Service); verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); @@ -605,6 +531,7 @@ void processSqsMessages_should_return_zero_messages_when_messages_are_not_S3Even @Test void populateS3Reference_should_interact_with_getUrlDecodedKey() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + reset(sqsOptions); // Using reflection to unit test a private method as part of bug fix. Class params[] = new Class[2]; params[0] = String.class; @@ -617,21 +544,176 @@ void populateS3Reference_should_interact_with_getUrlDecodedKey() throws NoSuchMe final S3EventNotification.S3ObjectEntity s3ObjectEntity = mock(S3EventNotification.S3ObjectEntity.class); final S3EventNotification.S3BucketEntity s3BucketEntity = mock(S3EventNotification.S3BucketEntity.class); - when(s3EventNotificationRecord.getS3()).thenReturn(s3Entity); - when(s3Entity.getBucket()).thenReturn(s3BucketEntity); - when(s3Entity.getObject()).thenReturn(s3ObjectEntity); - when(s3BucketEntity.getName()).thenReturn("test-bucket-name"); - when(s3ObjectEntity.getUrlDecodedKey()).thenReturn("test-key"); - - final S3ObjectReference s3ObjectReference = (S3ObjectReference) method.invoke(sqsWorker, "test-bucket-name", "test-key"); + final S3ObjectReference s3ObjectReference = (S3ObjectReference) method.invoke(createObjectUnderTest(), "test-bucket-name", "test-key"); assertThat(s3ObjectReference, notNullValue()); assertThat(s3ObjectReference.getBucketName(), equalTo("test-bucket-name")); assertThat(s3ObjectReference.getKey(), equalTo("test-key")); -// verify(s3ObjectEntity).getUrlDecodedKey(); verifyNoMoreInteractions(s3ObjectEntity); } + + @ParameterizedTest + @ValueSource(strings = {"ObjectCreated:Put", "ObjectCreated:Post", "ObjectCreated:Copy", "ObjectCreated:CompleteMultipartUpload"}) + void processSqsMessages_should_return_number_of_messages_processed_with_acknowledgements(final String eventName) throws IOException { + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); + final Message message = mock(Message.class); + when(message.body()).thenReturn(createEventNotification(eventName, startTime)); + final String testReceiptHandle = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(testReceiptHandle); + when(message.receiptHandle()).thenReturn(testReceiptHandle); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + + final ArgumentCaptor durationArgumentCaptor = ArgumentCaptor.forClass(Duration.class); + verify(sqsMessageDelayTimer).record(durationArgumentCaptor.capture()); + Duration actualDelay = durationArgumentCaptor.getValue(); + + assertThat(messagesProcessed, equalTo(1)); + verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); + verify(acknowledgementSetManager).create(any(), any(Duration.class)); + verify(sqsMessagesReceivedCounter).increment(1); + verifyNoInteractions(sqsMessagesDeletedCounter); + assertThat(actualDelay, lessThanOrEqualTo(Duration.ofHours(1).plus(Duration.ofSeconds(5)))); + assertThat(actualDelay, greaterThanOrEqualTo(Duration.ofHours(1).minus(Duration.ofSeconds(5)))); + } + + @ParameterizedTest + @ValueSource(strings = {"ObjectCreated:Put", "ObjectCreated:Post", "ObjectCreated:Copy", "ObjectCreated:CompleteMultipartUpload"}) + void processSqsMessages_should_return_number_of_messages_processed_with_acknowledgements_and_progress_check(final String eventName) throws IOException { + when(sqsOptions.getVisibilityDuplicateProtection()).thenReturn(true); + when(sqsOptions.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(6)); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); + final Message message = mock(Message.class); + when(message.body()).thenReturn(createEventNotification(eventName, startTime)); + final String testReceiptHandle = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(testReceiptHandle); + when(message.receiptHandle()).thenReturn(testReceiptHandle); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + final ArgumentCaptor deleteMessageBatchRequestArgumentCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + + final ArgumentCaptor durationArgumentCaptor = ArgumentCaptor.forClass(Duration.class); + verify(sqsMessageDelayTimer).record(durationArgumentCaptor.capture()); + Duration actualDelay = durationArgumentCaptor.getValue(); + + assertThat(messagesProcessed, equalTo(1)); + verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); + verify(acknowledgementSetManager).create(any(), any(Duration.class)); + verify(acknowledgementSet).addProgressCheck(any(), any(Duration.class)); + verify(sqsMessagesReceivedCounter).increment(1); + verifyNoInteractions(sqsMessagesDeletedCounter); + assertThat(actualDelay, lessThanOrEqualTo(Duration.ofHours(1).plus(Duration.ofSeconds(5)))); + assertThat(actualDelay, greaterThanOrEqualTo(Duration.ofHours(1).minus(Duration.ofSeconds(5)))); + } + + @ParameterizedTest + @ValueSource(strings = {"", "{\"foo\": \"bar\""}) + void processSqsMessages_should_not_interact_with_S3Service_if_input_is_not_valid_JSON(String inputString) { + final Message message = mock(Message.class); + when(message.body()).thenReturn(inputString); + + when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.RETAIN_MESSAGES); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + assertThat(messagesProcessed, equalTo(1)); + verifyNoInteractions(s3Service); + verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsMessagesFailedCounter).increment(); + } + + @Test + void processSqsMessages_should_update_visibility_timeout_when_progress_changes() throws IOException { + when(sqsOptions.getVisibilityDuplicateProtection()).thenReturn(true); + when(sqsOptions.getVisibilityTimeout()).thenReturn(Duration.ofMillis(1)); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); + final Message message = mock(Message.class); + when(message.body()).thenReturn(createEventNotification("ObjectCreated:Put", startTime)); + final String testReceiptHandle = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(testReceiptHandle); + when(message.receiptHandle()).thenReturn(testReceiptHandle); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + + assertThat(messagesProcessed, equalTo(1)); + verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); + verify(acknowledgementSetManager).create(any(), any(Duration.class)); + + ArgumentCaptor> progressConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(acknowledgementSet).addProgressCheck(progressConsumerArgumentCaptor.capture(), any(Duration.class)); + final Consumer actualConsumer = progressConsumerArgumentCaptor.getValue(); + final ProgressCheck progressCheck = mock(ProgressCheck.class); + actualConsumer.accept(progressCheck); + + ArgumentCaptor changeMessageVisibilityRequestArgumentCaptor = ArgumentCaptor.forClass(ChangeMessageVisibilityRequest.class); + verify(sqsClient).changeMessageVisibility(changeMessageVisibilityRequestArgumentCaptor.capture()); + ChangeMessageVisibilityRequest actualChangeVisibilityRequest = changeMessageVisibilityRequestArgumentCaptor.getValue(); + assertThat(actualChangeVisibilityRequest.queueUrl(), equalTo(queueUrl)); + assertThat(actualChangeVisibilityRequest.receiptHandle(), equalTo(testReceiptHandle)); + verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsMessageDelayTimer).record(any(Duration.class)); + } + + @Test + void processSqsMessages_should_stop_updating_visibility_timeout_after_stop() throws IOException { + when(sqsOptions.getVisibilityDuplicateProtection()).thenReturn(true); + when(sqsOptions.getVisibilityTimeout()).thenReturn(Duration.ofMillis(1)); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + when(s3SourceConfig.getAcknowledgements()).thenReturn(true); + Instant startTime = Instant.now().minus(1, ChronoUnit.HOURS); + final Message message = mock(Message.class); + when(message.body()).thenReturn(createEventNotification("ObjectCreated:Put", startTime)); + final String testReceiptHandle = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(testReceiptHandle); + when(message.receiptHandle()).thenReturn(testReceiptHandle); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + SqsWorker objectUnderTest = createObjectUnderTest(); + final int messagesProcessed = objectUnderTest.processSqsMessages(); + objectUnderTest.stop(); + + assertThat(messagesProcessed, equalTo(1)); + verify(s3Service).addS3Object(any(S3ObjectReference.class), any()); + verify(acknowledgementSetManager).create(any(), any(Duration.class)); + + ArgumentCaptor> progressConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(acknowledgementSet).addProgressCheck(progressConsumerArgumentCaptor.capture(), any(Duration.class)); + final Consumer actualConsumer = progressConsumerArgumentCaptor.getValue(); + final ProgressCheck progressCheck = mock(ProgressCheck.class); + actualConsumer.accept(progressCheck); + + verify(sqsClient, never()).changeMessageVisibility(any(ChangeMessageVisibilityRequest.class)); + verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsMessageDelayTimer).record(any(Duration.class)); + } + private static String createPutNotification(final Instant startTime) { return createEventNotification("ObjectCreated:Put", startTime); } diff --git a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessageTest.java b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessageTest.java index 3acec973e1..51f3abad06 100644 --- a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessageTest.java +++ b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/ParsedMessageTest.java @@ -2,6 +2,7 @@ import org.joda.time.DateTime; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.opensearch.dataprepper.plugins.source.s3.S3EventBridgeNotification; import org.opensearch.dataprepper.plugins.source.s3.S3EventNotification; @@ -12,33 +13,31 @@ import java.util.UUID; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; class ParsedMessageTest { private static final Random RANDOM = new Random(); private Message message; - private S3EventNotification.S3Entity s3Entity; - private S3EventNotification.S3BucketEntity s3BucketEntity; - private S3EventNotification.S3ObjectEntity s3ObjectEntity; - private S3EventNotification.S3EventNotificationRecord s3EventNotificationRecord; - private S3EventBridgeNotification s3EventBridgeNotification; - private S3EventBridgeNotification.Detail detail; - private S3EventBridgeNotification.Bucket bucket; - private S3EventBridgeNotification.Object object; + private String testBucketName; + private String testDecodedObjectKey; + private long testSize; @BeforeEach void setUp() { message = mock(Message.class); - s3Entity = mock(S3EventNotification.S3Entity.class); - s3BucketEntity = mock(S3EventNotification.S3BucketEntity.class); - s3ObjectEntity = mock(S3EventNotification.S3ObjectEntity.class); - s3EventNotificationRecord = mock(S3EventNotification.S3EventNotificationRecord.class); - s3EventBridgeNotification = mock(S3EventBridgeNotification.class); - detail = mock(S3EventBridgeNotification.Detail.class); - bucket = mock(S3EventBridgeNotification.Bucket.class); - object = mock(S3EventBridgeNotification.Object.class); + testBucketName = UUID.randomUUID().toString(); + testDecodedObjectKey = UUID.randomUUID().toString(); + testSize = RANDOM.nextInt(1_000_000_000) + 1; + } + + @Test + void constructor_with_failed_parsing_throws_if_Message_is_null() { + assertThrows(NullPointerException.class, () -> new ParsedMessage(null, true)); } @Test @@ -50,61 +49,156 @@ void test_parsed_message_with_failed_parsing() { } @Test - void test_parsed_message_with_S3EventNotificationRecord() { - final String testBucketName = UUID.randomUUID().toString(); - final String testDecodedObjectKey = UUID.randomUUID().toString(); - final String testEventName = UUID.randomUUID().toString(); - final DateTime testEventTime = DateTime.now(); - final long testSize = RANDOM.nextLong(); - - when(s3EventNotificationRecord.getS3()).thenReturn(s3Entity); - when(s3Entity.getBucket()).thenReturn(s3BucketEntity); - when(s3Entity.getObject()).thenReturn(s3ObjectEntity); - when(s3ObjectEntity.getSizeAsLong()).thenReturn(testSize); - when(s3BucketEntity.getName()).thenReturn(testBucketName); - when(s3ObjectEntity.getUrlDecodedKey()).thenReturn(testDecodedObjectKey); - when(s3EventNotificationRecord.getEventName()).thenReturn(testEventName); - when(s3EventNotificationRecord.getEventTime()).thenReturn(testEventTime); - - final ParsedMessage parsedMessage = new ParsedMessage(message, List.of(s3EventNotificationRecord)); + void toString_with_failed_parsing_and_messageId() { + final String messageId = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(messageId); - assertThat(parsedMessage.getMessage(), equalTo(message)); - assertThat(parsedMessage.getBucketName(), equalTo(testBucketName)); - assertThat(parsedMessage.getObjectKey(), equalTo(testDecodedObjectKey)); - assertThat(parsedMessage.getObjectSize(), equalTo(testSize)); - assertThat(parsedMessage.getEventName(), equalTo(testEventName)); - assertThat(parsedMessage.getEventTime(), equalTo(testEventTime)); - assertThat(parsedMessage.isFailedParsing(), equalTo(false)); - assertThat(parsedMessage.isEmptyNotification(), equalTo(false)); + final ParsedMessage parsedMessage = new ParsedMessage(message, true); + final String actualString = parsedMessage.toString(); + assertThat(actualString, notNullValue()); + assertThat(actualString, containsString(messageId)); } @Test - void test_parsed_message_with_S3EventBridgeNotification() { - final String testBucketName = UUID.randomUUID().toString(); - final String testDecodedObjectKey = UUID.randomUUID().toString(); - final String testDetailType = UUID.randomUUID().toString(); - final DateTime testEventTime = DateTime.now(); - final int testSize = RANDOM.nextInt(); + void toString_with_failed_parsing_and_no_messageId() { + final ParsedMessage parsedMessage = new ParsedMessage(message, true); + final String actualString = parsedMessage.toString(); + assertThat(actualString, notNullValue()); + } - when(s3EventBridgeNotification.getDetail()).thenReturn(detail); - when(s3EventBridgeNotification.getDetail().getBucket()).thenReturn(bucket); - when(s3EventBridgeNotification.getDetail().getObject()).thenReturn(object); + @Nested + class WithS3EventNotificationRecord { + private S3EventNotification.S3Entity s3Entity; + private S3EventNotification.S3BucketEntity s3BucketEntity; + private S3EventNotification.S3ObjectEntity s3ObjectEntity; + private S3EventNotification.S3EventNotificationRecord s3EventNotificationRecord; + private List s3EventNotificationRecords; + private String testEventName; + private DateTime testEventTime; - when(bucket.getName()).thenReturn(testBucketName); - when(object.getUrlDecodedKey()).thenReturn(testDecodedObjectKey); - when(object.getSize()).thenReturn(testSize); - when(s3EventBridgeNotification.getDetailType()).thenReturn(testDetailType); - when(s3EventBridgeNotification.getTime()).thenReturn(testEventTime); + @BeforeEach + void setUp() { + testEventName = UUID.randomUUID().toString(); + testEventTime = DateTime.now(); - final ParsedMessage parsedMessage = new ParsedMessage(message, s3EventBridgeNotification); + s3Entity = mock(S3EventNotification.S3Entity.class); + s3BucketEntity = mock(S3EventNotification.S3BucketEntity.class); + s3ObjectEntity = mock(S3EventNotification.S3ObjectEntity.class); + s3EventNotificationRecord = mock(S3EventNotification.S3EventNotificationRecord.class); - assertThat(parsedMessage.getMessage(), equalTo(message)); - assertThat(parsedMessage.getBucketName(), equalTo(testBucketName)); - assertThat(parsedMessage.getObjectKey(), equalTo(testDecodedObjectKey)); - assertThat(parsedMessage.getObjectSize(), equalTo((long) testSize)); - assertThat(parsedMessage.getDetailType(), equalTo(testDetailType)); - assertThat(parsedMessage.getEventTime(), equalTo(testEventTime)); - assertThat(parsedMessage.isFailedParsing(), equalTo(false)); - assertThat(parsedMessage.isEmptyNotification(), equalTo(false)); + when(s3EventNotificationRecord.getS3()).thenReturn(s3Entity); + when(s3Entity.getBucket()).thenReturn(s3BucketEntity); + when(s3Entity.getObject()).thenReturn(s3ObjectEntity); + when(s3ObjectEntity.getSizeAsLong()).thenReturn(testSize); + when(s3BucketEntity.getName()).thenReturn(testBucketName); + when(s3ObjectEntity.getUrlDecodedKey()).thenReturn(testDecodedObjectKey); + when(s3EventNotificationRecord.getEventName()).thenReturn(testEventName); + when(s3EventNotificationRecord.getEventTime()).thenReturn(testEventTime); + + s3EventNotificationRecords = List.of(s3EventNotificationRecord); + } + + private ParsedMessage createObjectUnderTest() { + return new ParsedMessage(message, s3EventNotificationRecords); + } + + @Test + void constructor_with_S3EventNotificationRecord_throws_if_Message_is_null() { + message = null; + assertThrows(NullPointerException.class, this::createObjectUnderTest); + } + + @Test + void test_parsed_message_with_S3EventNotificationRecord() { + final ParsedMessage parsedMessage = createObjectUnderTest(); + + assertThat(parsedMessage.getMessage(), equalTo(message)); + assertThat(parsedMessage.getBucketName(), equalTo(testBucketName)); + assertThat(parsedMessage.getObjectKey(), equalTo(testDecodedObjectKey)); + assertThat(parsedMessage.getObjectSize(), equalTo(testSize)); + assertThat(parsedMessage.getEventName(), equalTo(testEventName)); + assertThat(parsedMessage.getEventTime(), equalTo(testEventTime)); + assertThat(parsedMessage.isFailedParsing(), equalTo(false)); + assertThat(parsedMessage.isEmptyNotification(), equalTo(false)); + } + + @Test + void toString_with_messageId() { + final String messageId = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(messageId); + + final ParsedMessage parsedMessage = createObjectUnderTest(); + final String actualString = parsedMessage.toString(); + assertThat(actualString, notNullValue()); + assertThat(actualString, containsString(messageId)); + assertThat(actualString, containsString(testDecodedObjectKey)); + } + } + + @Nested + class WithS3EventBridgeNotification { + private String testDetailType; + private DateTime testEventTime; + private S3EventBridgeNotification s3EventBridgeNotification; + private S3EventBridgeNotification.Detail detail; + private S3EventBridgeNotification.Bucket bucket; + private S3EventBridgeNotification.Object object; + + @BeforeEach + void setUp() { + s3EventBridgeNotification = mock(S3EventBridgeNotification.class); + detail = mock(S3EventBridgeNotification.Detail.class); + bucket = mock(S3EventBridgeNotification.Bucket.class); + object = mock(S3EventBridgeNotification.Object.class); + + testDetailType = UUID.randomUUID().toString(); + testEventTime = DateTime.now(); + + when(s3EventBridgeNotification.getDetail()).thenReturn(detail); + when(s3EventBridgeNotification.getDetail().getBucket()).thenReturn(bucket); + when(s3EventBridgeNotification.getDetail().getObject()).thenReturn(object); + + when(bucket.getName()).thenReturn(testBucketName); + when(object.getUrlDecodedKey()).thenReturn(testDecodedObjectKey); + when(object.getSize()).thenReturn((int) testSize); + when(s3EventBridgeNotification.getDetailType()).thenReturn(testDetailType); + when(s3EventBridgeNotification.getTime()).thenReturn(testEventTime); + } + + private ParsedMessage createObjectUnderTest() { + return new ParsedMessage(message, s3EventBridgeNotification); + } + + @Test + void constructor_with_S3EventBridgeNotification_throws_if_Message_is_null() { + message = null; + assertThrows(NullPointerException.class, () -> createObjectUnderTest()); + } + + @Test + void test_parsed_message_with_S3EventBridgeNotification() { + final ParsedMessage parsedMessage = createObjectUnderTest(); + + assertThat(parsedMessage.getMessage(), equalTo(message)); + assertThat(parsedMessage.getBucketName(), equalTo(testBucketName)); + assertThat(parsedMessage.getObjectKey(), equalTo(testDecodedObjectKey)); + assertThat(parsedMessage.getObjectSize(), equalTo(testSize)); + assertThat(parsedMessage.getDetailType(), equalTo(testDetailType)); + assertThat(parsedMessage.getEventTime(), equalTo(testEventTime)); + assertThat(parsedMessage.isFailedParsing(), equalTo(false)); + assertThat(parsedMessage.isEmptyNotification(), equalTo(false)); + } + + @Test + void toString_with_messageId() { + final String messageId = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(messageId); + + final ParsedMessage parsedMessage = createObjectUnderTest(); + final String actualString = parsedMessage.toString(); + assertThat(actualString, notNullValue()); + assertThat(actualString, containsString(messageId)); + assertThat(actualString, containsString(testDecodedObjectKey)); + } } } diff --git a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventBridgeNotificationParserTest.java b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventBridgeNotificationParserTest.java index c779ec561f..db361d70e1 100644 --- a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventBridgeNotificationParserTest.java +++ b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventBridgeNotificationParserTest.java @@ -19,7 +19,7 @@ class S3EventBridgeNotificationParserTest { private final ObjectMapper objectMapper = new ObjectMapper(); - private final String EVENTBRIDGE_MESSAGE = "{\"version\":\"0\",\"id\":\"17793124-05d4-b198-2fde-7ededc63b103\",\"detail-type\":\"Object Created\"," + + static final String EVENTBRIDGE_MESSAGE = "{\"version\":\"0\",\"id\":\"17793124-05d4-b198-2fde-7ededc63b103\",\"detail-type\":\"Object Created\"," + "\"source\":\"aws.s3\",\"account\":\"111122223333\",\"time\":\"2021-11-12T00:00:00Z\"," + "\"region\":\"ca-central-1\",\"resources\":[\"arn:aws:s3:::DOC-EXAMPLE-BUCKET1\"]," + "\"detail\":{\"version\":\"0\",\"bucket\":{\"name\":\"DOC-EXAMPLE-BUCKET1\"}," + diff --git a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventNotificationParserTest.java b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventNotificationParserTest.java index a3d2c91679..c9e3a39da8 100644 --- a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventNotificationParserTest.java +++ b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/S3EventNotificationParserTest.java @@ -16,8 +16,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -class S3EventNotificationParserTest { - private static final String DIRECT_SQS_MESSAGE = +public class S3EventNotificationParserTest { + static final String DIRECT_SQS_MESSAGE = "{\"Records\":[{\"eventVersion\":\"2.1\",\"eventSource\":\"aws:s3\",\"awsRegion\":\"us-east-1\",\"eventTime\":\"2023-04-28T16:00:11.324Z\"," + "\"eventName\":\"ObjectCreated:Put\",\"userIdentity\":{\"principalId\":\"AWS:xyz\"},\"requestParameters\":{\"sourceIPAddress\":\"127.0.0.1\"}," + "\"responseElements\":{\"x-amz-request-id\":\"xyz\",\"x-amz-id-2\":\"xyz\"},\"s3\":{\"s3SchemaVersion\":\"1.0\"," + @@ -25,7 +25,7 @@ class S3EventNotificationParserTest { "\"arn\":\"arn:aws:s3:::my-bucket\"},\"object\":{\"key\":\"path/to/myfile.log.gz\",\"size\":3159112,\"eTag\":\"abcd123\"," + "\"sequencer\":\"000\"}}}]}"; - private static final String SNS_BASED_MESSAGE = "{\n" + + public static final String SNS_BASED_MESSAGE = "{\n" + " \"Type\" : \"Notification\",\n" + " \"MessageId\" : \"4e01e115-5b91-5096-8a74-bee95ed1e123\",\n" + " \"TopicArn\" : \"arn:aws:sns:us-east-1:123456789012:notifications\",\n" + diff --git a/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParserTest.java b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParserTest.java new file mode 100644 index 0000000000..d0dd711f7e --- /dev/null +++ b/data-prepper-plugins/s3-source/src/test/java/org/opensearch/dataprepper/plugins/source/s3/parser/SqsMessageParserTest.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.s3.parser; + +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.source.s3.S3SourceConfig; +import org.opensearch.dataprepper.plugins.source.s3.configuration.NotificationSourceOption; +import software.amazon.awssdk.services.sqs.model.Message; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class SqsMessageParserTest { + @Mock + private S3SourceConfig s3SourceConfig; + + private SqsMessageParser createObjectUnderTest() { + return new SqsMessageParser(s3SourceConfig); + } + + @ParameterizedTest + @ArgumentsSource(SourceArgumentsProvider.class) + void parseSqsMessages_returns_empty_for_empty_messages(final NotificationSourceOption sourceOption) { + when(s3SourceConfig.getNotificationSource()).thenReturn(sourceOption); + final Collection parsedMessages = createObjectUnderTest().parseSqsMessages(Collections.emptyList()); + + assertThat(parsedMessages, notNullValue()); + assertThat(parsedMessages, empty()); + } + + @ParameterizedTest + @ArgumentsSource(SourceArgumentsProvider.class) + void parseSqsMessages_parsed_messages(final NotificationSourceOption sourceOption, + final String messageBody, + final String replacementString) { + when(s3SourceConfig.getNotificationSource()).thenReturn(sourceOption); + final int numberOfMessages = 10; + List messages = IntStream.range(0, numberOfMessages) + .mapToObj(i -> messageBody.replaceAll(replacementString, replacementString + i)) + .map(SqsMessageParserTest::createMockMessage) + .collect(Collectors.toList()); + final Collection parsedMessages = createObjectUnderTest().parseSqsMessages(messages); + + assertThat(parsedMessages, notNullValue()); + assertThat(parsedMessages.size(), equalTo(numberOfMessages)); + + final Set bucketNames = parsedMessages.stream().map(ParsedMessage::getBucketName).collect(Collectors.toSet()); + assertThat("The bucket names are unique, so the bucketNames should match the numberOfMessages.", + bucketNames.size(), equalTo(numberOfMessages)); + } + + static class SourceArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(final ExtensionContext extensionContext) { + return Stream.of( + Arguments.arguments( + NotificationSourceOption.S3, + S3EventNotificationParserTest.DIRECT_SQS_MESSAGE, + "my-bucket"), + Arguments.arguments( + NotificationSourceOption.EVENTBRIDGE, + S3EventBridgeNotificationParserTest.EVENTBRIDGE_MESSAGE, + "DOC-EXAMPLE-BUCKET1") + ); + } + } + + private static Message createMockMessage(final String body) { + final Message message = mock(Message.class); + when(message.body()).thenReturn(body); + return message; + } +} \ No newline at end of file