Skip to content

Commit

Permalink
Improve the SQS shutdown process such that it does not prevent the pi…
Browse files Browse the repository at this point in the history
…peline from shutting down and no longer results in failures. Resolves opensearch-project#4575 (opensearch-project#4748)

The previous approach to shutting down the SQS thread closed the SqsClient. However, with acknowledgments enabled, asynchronous callbacks would result in further attempts to either ChangeVisibilityTimeout or DeleteMessages. These were failing because the client was closed. Also, the threads would remain and prevent Data Prepper from correctly shutting down. With this change, we correctly stop each processing thread. Then we close the client. Additionally, the SqsWorker now checks that it is not stopped before attempting to change the message visibility or delete messages.

Additionally, I found some missing test cases. Also, modifying this code and especially unit testing it is becoming more difficult, so I performed some refactoring to move message parsing out of the SqsWorker.

Signed-off-by: David Venable <[email protected]>
Signed-off-by: Krishna Kondaka <[email protected]>
  • Loading branch information
dlvenable authored and Krishna Kondaka committed Jul 23, 2024
1 parent 7ced3c5 commit ba49029
Show file tree
Hide file tree
Showing 9 changed files with 559 additions and 236 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import software.amazon.awssdk.services.sqs.SqsClient;

import java.time.Duration;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class SqsService {
private static final Logger LOG = LoggerFactory.getLogger(SqsService.class);
Expand All @@ -34,6 +37,7 @@ public class SqsService {
private final PluginMetrics pluginMetrics;
private final AcknowledgementSetManager acknowledgementSetManager;
private final ExecutorService executorService;
private final List<SqsWorker> sqsWorkers;

public SqsService(final AcknowledgementSetManager acknowledgementSetManager,
final S3SourceConfig s3SourceConfig,
Expand All @@ -46,18 +50,20 @@ public SqsService(final AcknowledgementSetManager acknowledgementSetManager,
this.acknowledgementSetManager = acknowledgementSetManager;
this.sqsClient = createSqsClient(credentialsProvider);
executorService = Executors.newFixedThreadPool(s3SourceConfig.getNumWorkers(), BackgroundThreadFactory.defaultExecutorThreadFactory("s3-source-sqs"));
}

public void start() {
final Backoff backoff = Backoff.exponential(INITIAL_DELAY, MAXIMUM_DELAY).withJitter(JITTER_RATE)
.withMaxAttempts(Integer.MAX_VALUE);
for (int i = 0; i < s3SourceConfig.getNumWorkers(); i++) {
executorService.submit(new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff));
}
sqsWorkers = IntStream.range(0, s3SourceConfig.getNumWorkers())
.mapToObj(i -> new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff))
.collect(Collectors.toList());
}

public void start() {
sqsWorkers.forEach(executorService::submit);
}

SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) {
LOG.info("Creating SQS client");
LOG.debug("Creating SQS client");
return SqsClient.builder()
.region(s3SourceConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(credentialsProvider)
Expand All @@ -68,8 +74,8 @@ SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) {
}

public void stop() {
sqsClient.close();
executorService.shutdown();
sqsWorkers.forEach(SqsWorker::stop);
try {
if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)) {
LOG.warn("Failed to terminate SqsWorkers");
Expand All @@ -82,5 +88,7 @@ public void stop() {
Thread.currentThread().interrupt();
}
}

sqsClient.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.dataprepper.plugins.source.s3;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.linecorp.armeria.client.retry.Backoff;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
Expand All @@ -20,8 +19,7 @@
import org.opensearch.dataprepper.plugins.source.s3.filter.S3EventFilter;
import org.opensearch.dataprepper.plugins.source.s3.filter.S3ObjectCreatedFilter;
import org.opensearch.dataprepper.plugins.source.s3.parser.ParsedMessage;
import org.opensearch.dataprepper.plugins.source.s3.parser.S3EventBridgeNotificationParser;
import org.opensearch.dataprepper.plugins.source.s3.parser.S3EventNotificationParser;
import org.opensearch.dataprepper.plugins.source.s3.parser.SqsMessageParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.exception.SdkException;
Expand Down Expand Up @@ -75,11 +73,10 @@ public class SqsWorker implements Runnable {
private final Counter sqsVisibilityTimeoutChangeFailedCount;
private final Timer sqsMessageDelayTimer;
private final Backoff standardBackoff;
private final SqsMessageParser sqsMessageParser;
private int failedAttemptCount;
private final boolean endToEndAcknowledgementsEnabled;
private final AcknowledgementSetManager acknowledgementSetManager;

private final ObjectMapper objectMapper = new ObjectMapper();
private volatile boolean isStopped = false;
private Map<ParsedMessage, Integer> parsedMessageVisibilityTimesMap;

Expand All @@ -98,6 +95,7 @@ public SqsWorker(final AcknowledgementSetManager acknowledgementSetManager,
sqsOptions = s3SourceConfig.getSqsOptions();
objectCreatedFilter = new S3ObjectCreatedFilter();
evenBridgeObjectCreatedFilter = new EventBridgeObjectCreatedFilter();
sqsMessageParser = new SqsMessageParser(s3SourceConfig);
failedAttemptCount = 0;
parsedMessageVisibilityTimesMap = new HashMap<>();

Expand Down Expand Up @@ -139,7 +137,7 @@ int processSqsMessages() {
if (!sqsMessages.isEmpty()) {
sqsMessagesReceivedCounter.increment(sqsMessages.size());

final Collection<ParsedMessage> s3MessageEventNotificationRecords = getS3MessageEventNotificationRecords(sqsMessages);
final Collection<ParsedMessage> s3MessageEventNotificationRecords = sqsMessageParser.parseSqsMessages(sqsMessages);

// build s3ObjectReference from S3EventNotificationRecord if event name starts with ObjectCreated
final List<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntries = processS3EventNotificationRecords(s3MessageEventNotificationRecords);
Expand Down Expand Up @@ -191,22 +189,6 @@ private ReceiveMessageRequest createReceiveMessageRequest() {
.build();
}

private Collection<ParsedMessage> getS3MessageEventNotificationRecords(final List<Message> sqsMessages) {
return sqsMessages.stream()
.map(this::convertS3EventMessages)
.collect(Collectors.toList());
}

private ParsedMessage convertS3EventMessages(final Message message) {
if (s3SourceConfig.getNotificationSource().equals(NotificationSourceOption.S3)) {
return new S3EventNotificationParser().parseMessage(message, objectMapper);
}
else if (s3SourceConfig.getNotificationSource().equals(NotificationSourceOption.EVENTBRIDGE)) {
return new S3EventBridgeNotificationParser().parseMessage(message, objectMapper);
}
return new ParsedMessage(message, true);
}

private List<DeleteMessageBatchRequestEntry> processS3EventNotificationRecords(final Collection<ParsedMessage> s3EventNotificationRecords) {
final List<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntryCollection = new ArrayList<>();
final List<ParsedMessage> parsedMessagesToRead = new ArrayList<>();
Expand Down Expand Up @@ -276,21 +258,7 @@ && isEventBridgeEventTypeCreated(parsedMessage)) {
return;
}
parsedMessageVisibilityTimesMap.put(parsedMessage, newValue);
final ChangeMessageVisibilityRequest changeMessageVisibilityRequest = ChangeMessageVisibilityRequest.builder()
.visibilityTimeout(newVisibilityTimeoutSeconds)
.queueUrl(sqsOptions.getSqsUrl())
.receiptHandle(parsedMessage.getMessage().receiptHandle())
.build();

try {
sqsClient.changeMessageVisibility(changeMessageVisibilityRequest);
sqsVisibilityTimeoutChangedCount.increment();
LOG.debug("Set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds);
} catch (Exception e) {
LOG.error("Failed to set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds, e);
sqsVisibilityTimeoutChangeFailedCount.increment();
}

increaseVisibilityTimeout(parsedMessage, newVisibilityTimeoutSeconds);
},
Duration.ofSeconds(progressCheckInterval));
}
Expand All @@ -308,6 +276,27 @@ && isEventBridgeEventTypeCreated(parsedMessage)) {
return deleteMessageBatchRequestEntryCollection;
}

private void increaseVisibilityTimeout(final ParsedMessage parsedMessage, final int newVisibilityTimeoutSeconds) {
if(isStopped) {
LOG.info("Some messages are pending completion of acknowledgments. Data Prepper will not increase the visibility timeout because it is shutting down. {}", parsedMessage);
return;
}
final ChangeMessageVisibilityRequest changeMessageVisibilityRequest = ChangeMessageVisibilityRequest.builder()
.visibilityTimeout(newVisibilityTimeoutSeconds)
.queueUrl(sqsOptions.getSqsUrl())
.receiptHandle(parsedMessage.getMessage().receiptHandle())
.build();

try {
sqsClient.changeMessageVisibility(changeMessageVisibilityRequest);
sqsVisibilityTimeoutChangedCount.increment();
LOG.debug("Set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds);
} catch (Exception e) {
LOG.error("Failed to set visibility timeout for message {} to {}", parsedMessage.getMessage().messageId(), newVisibilityTimeoutSeconds, e);
sqsVisibilityTimeoutChangeFailedCount.increment();
}
}

private Optional<DeleteMessageBatchRequestEntry> processS3Object(
final ParsedMessage parsedMessage,
final S3ObjectReference s3ObjectReference,
Expand All @@ -328,6 +317,8 @@ private Optional<DeleteMessageBatchRequestEntry> processS3Object(
}

private void deleteSqsMessages(final List<DeleteMessageBatchRequestEntry> deleteMessageBatchRequestEntryCollection) {
if(isStopped)
return;
if (deleteMessageBatchRequestEntryCollection.size() == 0) {
return;
}
Expand Down Expand Up @@ -396,6 +387,5 @@ private S3ObjectReference populateS3Reference(final String bucketName, final Str

void stop() {
isStopped = true;
Thread.currentThread().interrupt();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import software.amazon.awssdk.services.sqs.model.Message;

import java.util.List;
import java.util.Objects;

public class ParsedMessage {
private final Message message;
Expand All @@ -24,14 +25,14 @@ public class ParsedMessage {
private String detailType;

public ParsedMessage(final Message message, final boolean failedParsing) {
this.message = message;
this.message = Objects.requireNonNull(message);
this.failedParsing = failedParsing;
this.emptyNotification = true;
}

// S3EventNotification contains only one S3EventNotificationRecord
ParsedMessage(final Message message, final List<S3EventNotification.S3EventNotificationRecord> notificationRecords) {
this.message = message;
this.message = Objects.requireNonNull(message);
// S3EventNotification contains only one S3EventNotificationRecord
this.bucketName = notificationRecords.get(0).getS3().getBucket().getName();
this.objectKey = notificationRecords.get(0).getS3().getObject().getUrlDecodedKey();
this.objectSize = notificationRecords.get(0).getS3().getObject().getSizeAsLong();
Expand All @@ -42,7 +43,7 @@ public ParsedMessage(final Message message, final boolean failedParsing) {
}

ParsedMessage(final Message message, final S3EventBridgeNotification eventBridgeNotification) {
this.message = message;
this.message = Objects.requireNonNull(message);
this.bucketName = eventBridgeNotification.getDetail().getBucket().getName();
this.objectKey = eventBridgeNotification.getDetail().getObject().getUrlDecodedKey();
this.objectSize = eventBridgeNotification.getDetail().getObject().getSize();
Expand Down Expand Up @@ -85,4 +86,12 @@ public boolean isEmptyNotification() {
public String getDetailType() {
return detailType;
}

@Override
public String toString() {
return "Message{" +
"messageId=" + message.messageId() +
", objectKey=" + objectKey +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.s3.parser;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.plugins.source.s3.S3SourceConfig;
import software.amazon.awssdk.services.sqs.model.Message;

import java.util.Collection;
import java.util.stream.Collectors;

public class SqsMessageParser {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private final S3SourceConfig s3SourceConfig;
private final S3NotificationParser s3NotificationParser;

public SqsMessageParser(final S3SourceConfig s3SourceConfig) {
this.s3SourceConfig = s3SourceConfig;
s3NotificationParser = createNotificationParser(s3SourceConfig);
}

public Collection<ParsedMessage> parseSqsMessages(final Collection<Message> sqsMessages) {
return sqsMessages.stream()
.map(this::convertS3EventMessages)
.collect(Collectors.toList());
}

private ParsedMessage convertS3EventMessages(final Message message) {
return s3NotificationParser.parseMessage(message, OBJECT_MAPPER);
}

private static S3NotificationParser createNotificationParser(final S3SourceConfig s3SourceConfig) {
switch (s3SourceConfig.getNotificationSource()) {
case EVENTBRIDGE:
return new S3EventBridgeNotificationParser();
case S3:
default:
return new S3EventNotificationParser();
}
}
}
Loading

0 comments on commit ba49029

Please sign in to comment.