diff --git a/data-prepper-plugins/mongodb/build.gradle b/data-prepper-plugins/mongodb/build.gradle
index 2c05cecc55..89eaeff185 100644
--- a/data-prepper-plugins/mongodb/build.gradle
+++ b/data-prepper-plugins/mongodb/build.gradle
@@ -12,10 +12,12 @@ dependencies {
     implementation project(path: ':data-prepper-plugins:aws-plugin-api')
     implementation project(path: ':data-prepper-plugins:buffer-common')
     implementation project(':data-prepper-plugins:http-common')
+    implementation project(path: ':data-prepper-plugins:common')
 
     testImplementation testLibs.mockito.inline
     testImplementation testLibs.bundles.junit
+    testImplementation testLibs.slf4j.simple
     testImplementation project(path: ':data-prepper-test-common')
 }
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
index 9d9ffa87d0..f710e0f5c6 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
@@ -44,6 +44,16 @@ public ExportConfig getExportConfig() {
         return this.exportConfig;
     }
 
+    public boolean isExportRequired() {
+        return this.ingestionMode == CollectionConfig.IngestionMode.EXPORT_STREAM ||
+                this.ingestionMode == CollectionConfig.IngestionMode.EXPORT;
+    }
+
+    public boolean isStreamRequired() {
+        return this.ingestionMode == CollectionConfig.IngestionMode.EXPORT_STREAM ||
+                this.ingestionMode == CollectionConfig.IngestionMode.STREAM;
+    }
+
     public static class ExportConfig {
         private static final int DEFAULT_ITEMS_PER_PARTITION = 4000;
         @JsonProperty("items_per_partition")
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/RecordConverter.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/RecordConverter.java
index 7a5b0636e7..d21258a7ca 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/RecordConverter.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/RecordConverter.java
@@ -12,8 +12,6 @@
 import org.opensearch.dataprepper.model.event.EventMetadata;
 import org.opensearch.dataprepper.model.opensearch.OpenSearchBulkActions;
 import org.opensearch.dataprepper.plugins.mongo.configuration.CollectionConfig;
-import org.opensearch.dataprepper.plugins.mongo.coordination.partition.ExportPartition;
-import org.opensearch.dataprepper.plugins.mongo.coordination.partition.StreamPartition;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -78,12 +76,7 @@ public Event convert(final String record,
         }
         final EventMetadata eventMetadata = event.getMetadata();
 
-        if (dataType.equals(ExportPartition.PARTITION_TYPE)) {
-            eventMetadata.setAttribute(MetadataKeyAttributes.INGESTION_EVENT_TYPE_ATTRIBUTE, ExportPartition.PARTITION_TYPE);
-        } else if (dataType.equals(StreamPartition.PARTITION_TYPE)) {
-            eventMetadata.setAttribute(MetadataKeyAttributes.INGESTION_EVENT_TYPE_ATTRIBUTE, StreamPartition.PARTITION_TYPE);
-        }
-
+        eventMetadata.setAttribute(MetadataKeyAttributes.INGESTION_EVENT_TYPE_ATTRIBUTE, dataType);
         eventMetadata.setAttribute(MetadataKeyAttributes.MONGODB_EVENT_COLLECTION_METADATA_ATTRIBUTE, collectionConfig.getCollection());
         eventMetadata.setAttribute(MetadataKeyAttributes.MONGODB_EVENT_TIMESTAMP_METADATA_ATTRIBUTE, eventCreationTimeMillis);
         eventMetadata.setAttribute(MetadataKeyAttributes.MONGODB_STREAM_EVENT_NAME_METADATA_ATTRIBUTE, eventName);
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
index 7898768d41..5063912fa6 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
@@ -1,11 +1,13 @@
 package org.opensearch.dataprepper.plugins.mongo.documentdb;
 
+import org.opensearch.dataprepper.common.concurrent.BackgroundThreadFactory;
 import org.opensearch.dataprepper.metrics.PluginMetrics;
 import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.model.buffer.Buffer;
 import org.opensearch.dataprepper.model.event.Event;
 import org.opensearch.dataprepper.model.record.Record;
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
+import org.opensearch.dataprepper.plugins.mongo.configuration.CollectionConfig;
 import org.opensearch.dataprepper.plugins.mongo.export.MongoDBExportPartitionSupplier;
 import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
 import org.opensearch.dataprepper.plugins.mongo.export.ExportScheduler;
@@ -15,6 +17,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
@@ -24,11 +28,7 @@ public class DocumentDBService {
     private final PluginMetrics pluginMetrics;
     private final MongoDBSourceConfig sourceConfig;
     private final AcknowledgementSetManager acknowledgementSetManager;
-    private final ExecutorService executor;
-    private ExportScheduler exportScheduler;
-    private ExportWorker exportWorker;
-    private LeaderScheduler leaderScheduler;
-    private StreamScheduler streamScheduler;
+    private ExecutorService executor;
     private final MongoDBExportPartitionSupplier mongoDBExportPartitionSupplier;
 
     public DocumentDBService(final EnhancedSourceCoordinator sourceCoordinator,
                              final MongoDBSourceConfig sourceConfig,
@@ -38,9 +38,7 @@ public DocumentDBService(final EnhancedSourceCoordinator sourceCoordinator,
         this.pluginMetrics = pluginMetrics;
         this.acknowledgementSetManager = acknowledgementSetManager;
         this.sourceConfig = sourceConfig;
-
        this.mongoDBExportPartitionSupplier = new MongoDBExportPartitionSupplier(sourceConfig);
-        executor = Executors.newFixedThreadPool(4);
     }
 
     /**
@@ -51,15 +49,25 @@ public DocumentDBService(final EnhancedSourceCoordinator sourceCoordinator,
      * @param buffer Data Prepper Buffer
      */
     public void start(Buffer<Record<Event>> buffer) {
-        this.exportScheduler = new ExportScheduler(sourceCoordinator, mongoDBExportPartitionSupplier, pluginMetrics);
-        this.exportWorker = new ExportWorker(sourceCoordinator, buffer, pluginMetrics, acknowledgementSetManager, sourceConfig);
-        this.leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig.getCollections());
-        this.streamScheduler = new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, pluginMetrics);
+        final List<Runnable> runnableList = new ArrayList<>();
+
+        final LeaderScheduler leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig.getCollections());
+        runnableList.add(leaderScheduler);
+
+        if (sourceConfig.getCollections().stream().anyMatch(CollectionConfig::isExportRequired)) {
+            final ExportScheduler exportScheduler = new ExportScheduler(sourceCoordinator, mongoDBExportPartitionSupplier, pluginMetrics);
+            final ExportWorker exportWorker = new ExportWorker(sourceCoordinator, buffer, pluginMetrics, acknowledgementSetManager, sourceConfig);
+            runnableList.add(exportScheduler);
+            runnableList.add(exportWorker);
+        }
+
+        if (sourceConfig.getCollections().stream().anyMatch(CollectionConfig::isStreamRequired)) {
+            final StreamScheduler streamScheduler = new StreamScheduler(sourceCoordinator, buffer, acknowledgementSetManager, sourceConfig, pluginMetrics);
+            runnableList.add(streamScheduler);
+        }
 
-        executor.submit(leaderScheduler);
-        executor.submit(exportScheduler);
-        executor.submit(exportWorker);
-        executor.submit(streamScheduler);
+        executor = Executors.newFixedThreadPool(runnableList.size(), BackgroundThreadFactory.defaultExecutorThreadFactory("documentdb-source"));
+        runnableList.forEach(executor::submit);
     }
 
     /**
@@ -67,7 +75,9 @@ public void start(Buffer<Record<Event>> buffer) {
      * Each scheduler must implement logic for gracefully shutdown.
      */
     public void shutdown() {
-        LOG.info("shutdown DocumentDB Service scheduler and worker");
-        executor.shutdownNow();
+        if (executor != null) {
+            LOG.info("shutdown DocumentDB Service scheduler and worker");
+            executor.shutdownNow();
+        }
     }
 }
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
index f4ab200e27..53d0b8d912 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
@@ -101,15 +101,7 @@ public void run() {
         }
     }
 
-    private boolean isExportRequired(final CollectionConfig.IngestionMode ingestionMode) {
-        return ingestionMode == CollectionConfig.IngestionMode.EXPORT_STREAM ||
-                ingestionMode == CollectionConfig.IngestionMode.EXPORT;
-    }
 
-    private boolean isStreamRequired(final CollectionConfig.IngestionMode ingestionMode) {
-        return ingestionMode == CollectionConfig.IngestionMode.EXPORT_STREAM ||
-                ingestionMode == CollectionConfig.IngestionMode.STREAM;
-    }
 
     private void init() {
         LOG.info("Try to initialize DocumentDB Leader Partition");
@@ -120,13 +112,13 @@ private void init() {
             coordinator.createPartition(new GlobalState(collectionConfig.getCollection(), null));
 
             final Instant startTime = Instant.now();
-            final boolean exportRequired = isExportRequired(collectionConfig.getIngestionMode());
+            final boolean exportRequired = collectionConfig.isExportRequired();
             LOG.info("Ingestion mode {} for Collection {}", collectionConfig.getIngestionMode(), collectionConfig.getCollection());
             if (exportRequired) {
                 createExportPartition(collectionConfig, startTime);
             }
 
-            if (isStreamRequired(collectionConfig.getIngestionMode())) {
+            if (collectionConfig.isStreamRequired()) {
                 createStreamPartition(collectionConfig, startTime, exportRequired);
             }
 
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/model/CheckpointStatus.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/model/CheckpointStatus.java
new file mode 100644
index 0000000000..c7aafa556f
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/model/CheckpointStatus.java
@@ -0,0 +1,45 @@
+package org.opensearch.dataprepper.plugins.mongo.model;
+
+public class CheckpointStatus {
+    private final String resumeToken;
+    private final long recordCount;
+    private boolean acknowledged;
+    private final long createTimestamp;
+    private Long acknowledgedTimestamp;
+
+    public CheckpointStatus(final String resumeToken, final long recordCount, final long createTimestamp) {
+        this.resumeToken = resumeToken;
+        this.recordCount = recordCount;
+        this.acknowledged = false;
+        this.createTimestamp = createTimestamp;
+    }
+
+    public void setAcknowledgedTimestamp(final Long acknowledgedTimestamp) {
+        this.acknowledgedTimestamp = acknowledgedTimestamp;
+    }
+
+    public void setAcknowledged(boolean acknowledged) {
+        this.acknowledged = acknowledged;
+    }
+
+    public String getResumeToken() {
+        return resumeToken;
+    }
+    public long getRecordCount() {
+        return recordCount;
+    }
+
+    public boolean isAcknowledged() {
+        return acknowledged;
+    }
+
+    public long getCreateTimestamp() {
+        return createTimestamp;
+    }
+
+    public long getAcknowledgedTimestamp() {
+        return acknowledgedTimestamp;
+    }
+
+
+}
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/DataStreamPartitionCheckpoint.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/DataStreamPartitionCheckpoint.java
index 642fb659c9..9d6b9a2e67 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/DataStreamPartitionCheckpoint.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/DataStreamPartitionCheckpoint.java
@@ -50,16 +50,16 @@ private void setProgressState(final String resumeToken, final long recordNumber) {
     }
 
     /**
-     * This method is to do a checkpoint with latest sequence number processed.
-     * Note that this should be called on a regular basis even there are no changes to sequence number
+     * This method is to do a checkpoint with latest resume token processed.
+     * Note that this should be called on a regular basis even if there are no changes to the resume token
      * As the checkpoint will also extend the timeout for the lease
      *
-     * @param resumeToken
-     * @param recordNumber The last record number
+     * @param resumeToken checkpoint token to start resuming the stream when MongoDB/DocumentDB cursor is open
+     * @param recordCount The last processed record count
      */
-    public void checkpoint(final String resumeToken, final long recordNumber) {
-        LOG.debug("Checkpoint stream partition for collection " + streamPartition.getCollection() + " with record number " + recordNumber);
-        setProgressState(resumeToken, recordNumber);
+    public void checkpoint(final String resumeToken, final long recordCount) {
+        LOG.debug("Checkpoint stream partition for collection " + streamPartition.getCollection() + " with record number " + recordCount);
+        setProgressState(resumeToken, recordCount);
         enhancedSourceCoordinator.saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE);
     }
 
@@ -76,4 +76,8 @@ public Optional getGlobalStreamLoadStatus() {
     public void updateStreamPartitionForAcknowledgmentWait(final Duration acknowledgmentSetTimeout) {
         enhancedSourceCoordinator.saveProgressStateForPartition(streamPartition, acknowledgmentSetTimeout);
     }
+
+    public void giveUpPartition() {
+        enhancedSourceCoordinator.giveUpPartition(streamPartition);
+    }
 }
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java
new file mode 100644
index 0000000000..b1fe8d3529
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java
@@ -0,0 +1,127 @@
+package org.opensearch.dataprepper.plugins.mongo.stream;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.opensearch.dataprepper.common.concurrent.BackgroundThreadFactory;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.plugins.mongo.model.CheckpointStatus;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+
+public class StreamAcknowledgementManager {
+    private static final Logger LOG = LoggerFactory.getLogger(StreamAcknowledgementManager.class);
+    private final ConcurrentLinkedQueue<CheckpointStatus> checkpoints = new ConcurrentLinkedQueue<>();
+    private final ConcurrentHashMap<String, CheckpointStatus> ackStatus = new ConcurrentHashMap<>();
+    private final AcknowledgementSetManager acknowledgementSetManager;
+    private final DataStreamPartitionCheckpoint partitionCheckpoint;
+
+    private final Duration partitionAcknowledgmentTimeout;
+    private final int acknowledgementMonitorWaitTimeInMs;
+    private final int checkPointIntervalInMs;
+    private final ExecutorService executorService;
+
+    private boolean enableAcknowledgement = false;
+
+    public StreamAcknowledgementManager(final AcknowledgementSetManager acknowledgementSetManager,
+                                        final DataStreamPartitionCheckpoint partitionCheckpoint,
+                                        final Duration partitionAcknowledgmentTimeout,
+                                        final int acknowledgementMonitorWaitTimeInMs,
+                                        final int checkPointIntervalInMs) {
+        this.acknowledgementSetManager = acknowledgementSetManager;
+        this.partitionCheckpoint = partitionCheckpoint;
+        this.partitionAcknowledgmentTimeout = partitionAcknowledgmentTimeout;
+        this.acknowledgementMonitorWaitTimeInMs = acknowledgementMonitorWaitTimeInMs;
+        this.checkPointIntervalInMs = checkPointIntervalInMs;
+        executorService = Executors.newSingleThreadExecutor(BackgroundThreadFactory.defaultExecutorThreadFactory("mongodb-stream-ack-monitor"));
+    }
+
+    void init(final Consumer<Void> stopWorkerConsumer) {
+        enableAcknowledgement = true;
+        executorService.submit(() -> monitorAcknowledgment(executorService, stopWorkerConsumer));
+    }
+
+    private void monitorAcknowledgment(final ExecutorService executorService, final Consumer<Void> stopWorkerConsumer) {
+        long lastCheckpointTime = System.currentTimeMillis();
+        CheckpointStatus lastCheckpointStatus = null;
+        while (!Thread.currentThread().isInterrupted()) {
+            final CheckpointStatus checkpointStatus = checkpoints.peek();
+            if (checkpointStatus != null) {
+                if (checkpointStatus.isAcknowledged()) {
+                    lastCheckpointStatus = checkpoints.poll();
+                    ackStatus.remove(checkpointStatus.getResumeToken());
+                    if (System.currentTimeMillis() - lastCheckpointTime >= checkPointIntervalInMs) {
+                        LOG.debug("Perform regular checkpointing for resume token {} at record count {}", checkpointStatus.getResumeToken(), checkpointStatus.getRecordCount());
+                        partitionCheckpoint.checkpoint(checkpointStatus.getResumeToken(), checkpointStatus.getRecordCount());
+                        lastCheckpointTime = System.currentTimeMillis();
+                    }
+                } else {
+                    LOG.debug("Checkpoint not complete for resume token {}", checkpointStatus.getResumeToken());
+                    final Duration ackWaitDuration = Duration.between(Instant.ofEpochMilli(checkpointStatus.getCreateTimestamp()), Instant.now());
+                    // Acknowledgement not received for the checkpoint after twice the ack wait time
+                    if (ackWaitDuration.getSeconds() >= partitionAcknowledgmentTimeout.getSeconds() * 2) {
+                        // Give up the partition and interrupt the parent thread to stop processing the stream
+                        if (lastCheckpointStatus != null && lastCheckpointStatus.isAcknowledged()) {
+                            partitionCheckpoint.checkpoint(lastCheckpointStatus.getResumeToken(), lastCheckpointStatus.getRecordCount());
+                        }
+                        LOG.warn("Acknowledgement not received for the checkpoint {} past wait time. Giving up partition.", checkpointStatus.getResumeToken());
+                        partitionCheckpoint.giveUpPartition();
+                        break;
+                    }
+                }
+            }
+
+            try {
+                Thread.sleep(acknowledgementMonitorWaitTimeInMs);
+            } catch (InterruptedException ex) {
+                break;
+            }
+        }
+        stopWorkerConsumer.accept(null);
+        executorService.shutdown();
+    }
+
+    Optional<AcknowledgementSet> createAcknowledgementSet(final String resumeToken, final long recordNumber) {
+        if (!enableAcknowledgement) {
+            return Optional.empty();
+        }
+
+        final CheckpointStatus checkpointStatus = new CheckpointStatus(resumeToken, recordNumber, Instant.now().toEpochMilli());
+        checkpoints.add(checkpointStatus);
+        ackStatus.put(resumeToken, checkpointStatus);
+        return Optional.of(acknowledgementSetManager.create((result) -> {
+            if (result) {
+                final CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken);
+                ackCheckpointStatus.setAcknowledgedTimestamp(Instant.now().toEpochMilli());
+                ackCheckpointStatus.setAcknowledged(true);
+                LOG.debug("Received acknowledgment of completion from sink for checkpoint {}", resumeToken);
+            } else {
+                LOG.warn("Negative acknowledgment received for checkpoint {}, resetting checkpoint", resumeToken);
+                // Default CheckpointStatus acknowledged value is false. The acknowledgement monitor will time out
+                // and reprocess the stream from the last successful checkpoint, in order.
+            }
+        }, partitionAcknowledgmentTimeout));
+    }
+
+    void shutdown() {
+        executorService.shutdown();
+    }
+
+    @VisibleForTesting
+    ConcurrentHashMap<String, CheckpointStatus> getAcknowledgementStatus() {
+        return ackStatus;
+    }
+
+    @VisibleForTesting
+    ConcurrentLinkedQueue<CheckpointStatus> getCheckpoints() {
+        return checkpoints;
+    }
+}
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
index 60752fa0d3..4fd8ec0e0a 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
@@ -23,6 +23,8 @@ public class StreamScheduler implements Runnable {
     private static final Logger LOG = LoggerFactory.getLogger(StreamScheduler.class);
 
     private static final int DEFAULT_TAKE_LEASE_INTERVAL_MILLIS = 60_000;
+    static final int DEFAULT_CHECKPOINT_INTERVAL_MILLS = 120_000;
+    private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000;
     /**
     * Number of records to accumulate before flushing to buffer
     */
@@ -30,7 +32,7 @@ public class StreamScheduler implements Runnable {
     /**
     * Number of stream records to accumulate to write to buffer and checkpoint
     */
-    private static final int DEFAULT_STREAM_BATCH_SIZE = 100;
+    static final int DEFAULT_STREAM_BATCH_SIZE = 100;
     static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(60);
     private final EnhancedSourceCoordinator sourceCoordinator;
     private final RecordBufferWriter recordBufferWriter;
@@ -61,8 +63,10 @@ public void run() {
             if (sourcePartition.isPresent()) {
                 streamPartition = (StreamPartition) sourcePartition.get();
                 final DataStreamPartitionCheckpoint partitionCheckpoint = new DataStreamPartitionCheckpoint(sourceCoordinator, streamPartition);
-                final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, acknowledgementSetManager,
-                        sourceConfig, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE);
+                final StreamAcknowledgementManager streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint,
+                        sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
+                final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, sourceConfig,
+                        streamAcknowledgementManager, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
                 streamWorker.processStream(streamPartition);
             }
             try {
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java
index 1ff882b8c6..f823bceb98 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java
@@ -13,7 +13,6 @@
 import org.bson.json.JsonWriterSettings;
 import org.opensearch.dataprepper.metrics.PluginMetrics;
 import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
-import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.plugins.mongo.buffer.RecordBufferWriter;
 import org.opensearch.dataprepper.plugins.mongo.client.MongoDBConnection;
 import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
@@ -41,9 +40,12 @@ public class StreamWorker {
     private final MongoDBSourceConfig sourceConfig;
     private final Counter successItemsCounter;
     private final Counter failureItemsCounter;
-    private final AcknowledgementSetManager acknowledgementSetManager;
+    private final StreamAcknowledgementManager streamAcknowledgementManager;
     private final PluginMetrics pluginMetrics;
-    private final int defaultFlushBatchSize;
+    private final int recordFlushBatchSize;
+    final int checkPointIntervalInMs;
+    private boolean stopWorker = false;
+
     private final JsonWriterSettings writerSettings = JsonWriterSettings.builder()
             .outputMode(JsonMode.RELAXED)
             .build();
 
     public static StreamWorker create(final RecordBufferWriter recordBufferWriter,
-                                      final AcknowledgementSetManager acknowledgementSetManager,
                                       final MongoDBSourceConfig sourceConfig,
+                                      final StreamAcknowledgementManager streamAcknowledgementManager,
                                       final DataStreamPartitionCheckpoint partitionCheckpoint,
                                       final PluginMetrics pluginMetrics,
-                                      final int defaultFlushBatchSize
+                                      final int recordFlushBatchSize,
+                                      final int checkPointIntervalInMs
     ) {
-        return new StreamWorker(recordBufferWriter, acknowledgementSetManager,
-                sourceConfig, partitionCheckpoint, pluginMetrics, defaultFlushBatchSize);
+        return new StreamWorker(recordBufferWriter, sourceConfig, streamAcknowledgementManager, partitionCheckpoint,
+                pluginMetrics, recordFlushBatchSize, checkPointIntervalInMs);
     }
 
     public StreamWorker(final RecordBufferWriter recordBufferWriter,
-                        final AcknowledgementSetManager acknowledgementSetManager,
                         final MongoDBSourceConfig sourceConfig,
+                        final StreamAcknowledgementManager streamAcknowledgementManager,
                         final DataStreamPartitionCheckpoint partitionCheckpoint,
                         final PluginMetrics pluginMetrics,
-                        final int defaultFlushBatchSize
+                        final int recordFlushBatchSize,
+                        final int checkPointIntervalInMs
     ) {
         this.recordBufferWriter = recordBufferWriter;
         this.sourceConfig = sourceConfig;
+        this.streamAcknowledgementManager = streamAcknowledgementManager;
         this.partitionCheckpoint = partitionCheckpoint;
-        this.acknowledgementSetManager = acknowledgementSetManager;
         this.pluginMetrics = pluginMetrics;
-        this.defaultFlushBatchSize = defaultFlushBatchSize;
+        this.recordFlushBatchSize = recordFlushBatchSize;
+        this.checkPointIntervalInMs = checkPointIntervalInMs;
         this.successItemsCounter = pluginMetrics.counter(SUCCESS_ITEM_COUNTER_NAME);
         this.failureItemsCounter = pluginMetrics.counter(FAILURE_ITEM_COUNTER_NAME);
+        if (sourceConfig.isAcknowledgmentsEnabled()) {
+            // starts acknowledgement monitoring thread
+            streamAcknowledgementManager.init((Void) -> stop());
+        }
     }
 
     private MongoCursor<ChangeStreamDocument<Document>> getChangeStreamCursor(final MongoCollection<Document> collection,
@@ -101,7 +110,7 @@ public void processStream(final StreamPartition streamPartition) {
         if (collectionDBNameList.size() < 2) {
             throw new IllegalArgumentException("Invalid Collection Name. Must be in db.collection format");
         }
-        int recordCount = 0;
+        long recordCount = 0;
         final List<String> records = new ArrayList<>();
         // TODO: create acknowledgementSet
         AcknowledgementSet acknowledgementSet = null;
@@ -120,11 +129,12 @@ public void processStream(final StreamPartition streamPartition) {
                         Thread.sleep(DEFAULT_EXPORT_COMPLETE_WAIT_INTERVAL_MILLIS);
                     } catch (final InterruptedException ex) {
                         LOG.info("The StreamScheduler was interrupted while waiting to retry, stopping processing");
+                        Thread.currentThread().interrupt();
                         break;
                     }
                 }
-
-                while (cursor.hasNext() && !Thread.currentThread().isInterrupted()) {
+                long lastCheckpointTime = System.currentTimeMillis();
+                while (cursor.hasNext() && !Thread.currentThread().isInterrupted() && !stopWorker) {
                     try {
                         final ChangeStreamDocument<Document> document = cursor.next();
                         final String record = document.getFullDocument().toJson(writerSettings);
@@ -134,17 +144,24 @@ public void processStream(final StreamPartition streamPartition) {
                         records.add(record);
                         recordCount += 1;
 
-                        if (recordCount % defaultFlushBatchSize == 0) {
-                            LOG.debug("Write to buffer for line " + (recordCount - defaultFlushBatchSize) + " to " + recordCount);
+                        if (recordCount % recordFlushBatchSize == 0) {
+                            LOG.debug("Write to buffer for line {} to {}", (recordCount - recordFlushBatchSize), recordCount);
+                            acknowledgementSet = streamAcknowledgementManager.createAcknowledgementSet(checkPointToken, recordCount).orElse(null);
                             recordBufferWriter.writeToBuffer(acknowledgementSet, records);
+                            successItemsCounter.increment(records.size());
                             records.clear();
-                            LOG.debug("Perform regular checkpointing for stream Loader");
-                            partitionCheckpoint.checkpoint(checkPointToken, recordCount);
-                            successItemsCounter.increment();
+                            if (!sourceConfig.isAcknowledgmentsEnabled() && (System.currentTimeMillis() - lastCheckpointTime >= checkPointIntervalInMs)) {
+                                LOG.debug("Perform regular checkpointing for resume token {} at record count {}", checkPointToken, recordCount);
+                                partitionCheckpoint.checkpoint(checkPointToken, recordCount);
+                                lastCheckpointTime = System.currentTimeMillis();
+                            }
                         }
                     } catch (Exception e) {
-                        LOG.error("Failed to add record to buffer with error {}", e.getMessage());
-                        failureItemsCounter.increment();
+                        // TODO handle documents with size > 10 MB.
+                        // This will only happen if writing to the buffer gets interrupted from shutdown;
+                        // otherwise it is infinite backoff and retry.
+                        LOG.error("Failed to add records to buffer with error {}", e.getMessage());
+                        failureItemsCounter.increment(records.size());
                     }
                 }
             }
@@ -152,12 +169,25 @@ public void processStream(final StreamPartition streamPartition) {
             LOG.error("Exception connecting to cluster and processing stream", e);
             throw new RuntimeException(e);
         } finally {
-            LOG.info("Checkpointing processing stream");
             if (!records.isEmpty()) {
+                LOG.info("Flushing and checkpointing last processed record batch from the stream before terminating");
+                acknowledgementSet = streamAcknowledgementManager.createAcknowledgementSet(checkPointToken, recordCount).orElse(null);
                 recordBufferWriter.writeToBuffer(acknowledgementSet, records);
+                successItemsCounter.increment(records.size());
             }
             // Do final checkpoint.
-            partitionCheckpoint.checkpoint(checkPointToken, recordCount);
+            if (!sourceConfig.isAcknowledgmentsEnabled()) {
+                partitionCheckpoint.checkpoint(checkPointToken, recordCount);
+            }
+
+            // shutdown acknowledgement monitoring thread
+            if (streamAcknowledgementManager != null) {
+                streamAcknowledgementManager.shutdown();
+            }
         }
     }
+
+    void stop() {
+        stopWorker = true;
+    }
 }
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
index 2b5c6b885b..fe5447404c 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
@@ -132,6 +132,5 @@ public void testProcessPartitionSuccess(final String partitionKey) {
         verify(mockRecordBufferWriter).writeToBuffer(eq(mockAcknowledgementSet), any());
         verify(successItemsCounter, times(2)).increment();
         verify(failureItemsCounter, never()).increment();
-        executorService.shutdownNow();
     }
 }
\ No newline at end of file
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
index afae4cdc8d..1387ed24f8 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
@@ -63,7 +63,8 @@ void test_should_init() {
         leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
-        given(collectionConfig.getIngestionMode()).willReturn(CollectionConfig.IngestionMode.EXPORT_STREAM);
+        given(collectionConfig.isExportRequired()).willReturn(true);
+        given(collectionConfig.isStreamRequired()).willReturn(true);
         given(collectionConfig.getExportConfig()).willReturn(exportConfig);
         given(exportConfig.getItemsPerPartition()).willReturn(new Random().nextInt());
         given(collectionConfig.getCollection()).willReturn(UUID.randomUUID().toString());
@@ -96,7 +97,7 @@ void test_should_init_export() {
         leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
-        given(collectionConfig.getIngestionMode()).willReturn(CollectionConfig.IngestionMode.EXPORT);
+        given(collectionConfig.isExportRequired()).willReturn(true);
         given(collectionConfig.getExportConfig()).willReturn(exportConfig);
         given(exportConfig.getItemsPerPartition()).willReturn(new Random().nextInt());
         given(collectionConfig.getCollection()).willReturn(UUID.randomUUID().toString());
@@ -129,7 +130,7 @@ void test_should_init_stream() {
         leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
-        given(collectionConfig.getIngestionMode()).willReturn(CollectionConfig.IngestionMode.STREAM);
+        given(collectionConfig.isStreamRequired()).willReturn(true);
         given(collectionConfig.getCollection()).willReturn(UUID.randomUUID().toString());
 
         final ExecutorService executorService = Executors.newSingleThreadExecutor();
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java
new file mode 100644
index 0000000000..40d556ec35
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java
@@ -0,0 +1,178 @@
+package org.opensearch.dataprepper.plugins.mongo.stream;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.plugins.mongo.model.CheckpointStatus;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+
+import static org.awaitility.Awaitility.await;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.CoreMatchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StreamAcknowledgementManagerTest {
+
+    @Mock
+    private AcknowledgementSetManager acknowledgementSetManager;
+    @Mock
+    private DataStreamPartitionCheckpoint partitionCheckpoint;
+    @Mock
+    private Duration timeout;
+    @Mock
+    private AcknowledgementSet acknowledgementSet;
+    @Mock
+    private Consumer<Void> stopWorkerConsumer;
+    private StreamAcknowledgementManager streamAckManager;
+
+    @BeforeEach
+    public void setup() {
+        streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
+    }
+
+    @Test
+    public void createAcknowledgementSet_disabled_emptyAckSet() {
+        final Optional<AcknowledgementSet> ackSet = streamAckManager.createAcknowledgementSet(UUID.randomUUID().toString(), new Random().nextInt());
+        assertThat(ackSet.isEmpty(), is(true));
+    }
+
+    @Test
+    public void createAcknowledgementSet_enabled_ackSetWithAck() {
+        lenient().when(timeout.getSeconds()).thenReturn(10_000L);
+        streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
+        streamAckManager.init(stopWorkerConsumer);
+        final String resumeToken = UUID.randomUUID().toString();
+        final long recordCount = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        final Optional<AcknowledgementSet> ackSet = streamAckManager.createAcknowledgementSet(resumeToken, recordCount);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount));
+        final ArgumentCaptor<Consumer<Boolean>> argumentCaptor = ArgumentCaptor.forClass(Consumer.class);
+        verify(acknowledgementSetManager).create(argumentCaptor.capture(), eq(timeout));
+        final Consumer<Boolean> consumer = argumentCaptor.getValue();
+        consumer.accept(true);
+        final ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
+        final CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken);
+        assertThat(ackCheckpointStatus.isAcknowledged(), is(true));
+        await()
+            .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
+                verify(partitionCheckpoint).checkpoint(resumeToken, recordCount));
+        assertThat(streamAckManager.getCheckpoints().peek(), is(nullValue()));
+    }
+
+    @Test
+    public void createAcknowledgementSet_enabled_multipleAckSetWithAck() {
+        when(timeout.getSeconds()).thenReturn(10_000L);
+        streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
+        streamAckManager.init(stopWorkerConsumer);
+        final String resumeToken1 = UUID.randomUUID().toString();
+        final long recordCount1 = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        Optional<AcknowledgementSet> ackSet = streamAckManager.createAcknowledgementSet(resumeToken1, recordCount1);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+
+        final String resumeToken2 = UUID.randomUUID().toString();
+        final long recordCount2 = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        ackSet = streamAckManager.createAcknowledgementSet(resumeToken2, recordCount2);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+        ArgumentCaptor<Consumer<Boolean>> argumentCaptor = ArgumentCaptor.forClass(Consumer.class);
+        verify(acknowledgementSetManager, times(2)).create(argumentCaptor.capture(), eq(timeout));
+        List<Consumer<Boolean>> consumers = argumentCaptor.getAllValues();
+        consumers.get(0).accept(true);
+        consumers.get(1).accept(true);
+        ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
+        CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken2);
+        assertThat(ackCheckpointStatus.isAcknowledged(), is(true));
+        await()
+            .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
+                verify(partitionCheckpoint).checkpoint(resumeToken2, recordCount2));
+        assertThat(streamAckManager.getCheckpoints().peek(), is(nullValue()));
+    }
+
+    @Test
+    public void createAcknowledgementSet_enabled_multipleAckSetWithAckFailure() {
+        streamAckManager.init(stopWorkerConsumer);
+        final String resumeToken1 = UUID.randomUUID().toString();
+        final long recordCount1 = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        Optional<AcknowledgementSet> ackSet = streamAckManager.createAcknowledgementSet(resumeToken1, recordCount1);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+
+        final String resumeToken2 = UUID.randomUUID().toString();
+        final long recordCount2 = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        ackSet = streamAckManager.createAcknowledgementSet(resumeToken2, recordCount2);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+        ArgumentCaptor<Consumer<Boolean>> argumentCaptor = ArgumentCaptor.forClass(Consumer.class);
+        verify(acknowledgementSetManager, times(2)).create(argumentCaptor.capture(), eq(timeout));
+        List<Consumer<Boolean>> consumers = argumentCaptor.getAllValues();
+        consumers.get(0).accept(false);
+        consumers.get(1).accept(true);
+        ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
+        CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken2);
+        assertThat(ackCheckpointStatus.isAcknowledged(), is(true));
+        await()
+            .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
+                verify(partitionCheckpoint).giveUpPartition());
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+        verify(stopWorkerConsumer).accept(null);
+    }
+
+    @Test
+    public void createAcknowledgementSet_enabled_ackSetWithNoAck() {
+        streamAckManager.init(stopWorkerConsumer);
+        final String resumeToken = UUID.randomUUID().toString();
+        final long recordCount = new Random().nextLong();
+        when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
+        final Optional<AcknowledgementSet> ackSet = streamAckManager.createAcknowledgementSet(resumeToken, recordCount);
+        assertThat(ackSet.isEmpty(), is(false));
+        assertThat(ackSet.get(), is(acknowledgementSet));
+        assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken));
+        assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount));
+        final ArgumentCaptor<Consumer<Boolean>> argumentCaptor = ArgumentCaptor.forClass(Consumer.class);
+        verify(acknowledgementSetManager).create(argumentCaptor.capture(), eq(timeout));
+        final Consumer<Boolean> consumer = argumentCaptor.getValue();
+        consumer.accept(false);
+        final ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
+        final CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken);
+        assertThat(ackCheckpointStatus.isAcknowledged(), is(false));
+        await()
+            .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
+                verify(stopWorkerConsumer).accept(null));
+    }
+}
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
index dced676ceb..b02e4ee20f 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
@@ -32,6 +32,7 @@
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
+import static org.opensearch.dataprepper.plugins.mongo.stream.StreamScheduler.DEFAULT_CHECKPOINT_INTERVAL_MILLS;
 
 
 @ExtendWith(MockitoExtension.class)
@@ -85,9 +86,9 @@ void test_stream_run() {
         final ExecutorService executorService = Executors.newSingleThreadExecutor();
         final Future<?> future = executorService.submit(() -> {
             try (MockedStatic<StreamWorker> streamWorkerMockedStatic = mockStatic(StreamWorker.class)) {
-                streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(acknowledgementSetManager),
-                                eq(sourceConfig), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100)))
-                        .thenReturn(streamWorker);
+                streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(sourceConfig),
+                                any(StreamAcknowledgementManager.class), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100), eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS)))
+                        .thenReturn(streamWorker);
                 streamScheduler.run();
             }
         });
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
index 9f0ba53a63..ca69d03b67 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
@@ -11,6 +11,7 @@
 import org.bson.BsonDocument;
 import org.bson.BsonInt32;
 import org.bson.Document;
+import org.bson.json.JsonWriterSettings;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
@@ -18,21 +19,24 @@
 import org.mockito.MockedStatic;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.dataprepper.metrics.PluginMetrics;
-import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.plugins.mongo.buffer.RecordBufferWriter;
 import org.opensearch.dataprepper.plugins.mongo.client.MongoDBConnection;
 import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.StreamPartition;
 import org.opensearch.dataprepper.plugins.mongo.coordination.state.StreamProgressState;
 
+import java.time.Duration;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 
+import static org.awaitility.Awaitility.await;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.never;
@@ -48,7 +52,7 @@ public class StreamWorkerTest {
     @Mock
     private RecordBufferWriter mockRecordBufferWriter;
     @Mock
-    private AcknowledgementSetManager mockAcknowledgementSetManager;
+    private StreamAcknowledgementManager mockStreamAcknowledgementManager;
     @Mock
     private MongoDBSourceConfig mockSourceConfig;
     @Mock
@@ -71,8 +75,10 @@ public class StreamWorkerTest {
     public void setup() {
         when(mockPluginMetrics.counter(SUCCESS_ITEM_COUNTER_NAME)).thenReturn(successItemsCounter);
         when(mockPluginMetrics.counter(FAILURE_ITEM_COUNTER_NAME)).thenReturn(failureItemsCounter);
-        streamWorker = new StreamWorker(mockRecordBufferWriter, mockAcknowledgementSetManager,
-                mockSourceConfig, mockPartitionCheckpoint, mockPluginMetrics, 2);
+        when(mockSourceConfig.isAcknowledgmentsEnabled()).thenReturn(false);
+        Thread.interrupted();
+        streamWorker = new StreamWorker(mockRecordBufferWriter, mockSourceConfig, mockStreamAcknowledgementManager,
+                mockPartitionCheckpoint, mockPluginMetrics, 2, 0);
     }
 
     @Test
@@ -91,23 +97,25 @@ void test_processStream_success() {
         MongoCollection<Document> col = mock(MongoCollection.class);
         ChangeStreamIterable<Document> changeStreamIterable = mock(ChangeStreamIterable.class);
         MongoCursor<ChangeStreamDocument<Document>> cursor = mock(MongoCursor.class);
-        lenient().when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
-        lenient().when(mongoDatabase.getCollection(anyString())).thenReturn(col);
-        lenient().when(col.watch()).thenReturn(changeStreamIterable);
-        lenient().when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
-        lenient().when(changeStreamIterable.iterator()).thenReturn(cursor);
-        lenient().when(cursor.hasNext()).thenReturn(true, true, false);
-        ChangeStreamDocument<Document> streamDoc1 = mock(ChangeStreamDocument.class);
-        ChangeStreamDocument<Document> streamDoc2 = mock(ChangeStreamDocument.class);
+        when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
+        when(mongoDatabase.getCollection(anyString())).thenReturn(col);
+        when(col.watch()).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.iterator()).thenReturn(cursor);
+        when(cursor.hasNext()).thenReturn(true, true, false);
+        ChangeStreamDocument<Document> streamDoc1 = mock(ChangeStreamDocument.class);
+        ChangeStreamDocument<Document> streamDoc2 = mock(ChangeStreamDocument.class);
         Document doc1 = mock(Document.class);
-        Document doc2 = mock(Document.class);
+        Document doc2 = mock(Document.class);
         BsonDocument bsonDoc1 = new BsonDocument("resumeToken1", new BsonInt32(123));
         BsonDocument bsonDoc2 = new BsonDocument("resumeToken2", new BsonInt32(234));
         when(streamDoc1.getResumeToken()).thenReturn(bsonDoc1);
         when(streamDoc2.getResumeToken()).thenReturn(bsonDoc2);
-        lenient().when(cursor.next())
-            .thenReturn(streamDoc1)
-            .thenReturn(streamDoc2);
+        when(cursor.next())
+            .thenReturn(streamDoc1)
+            .thenReturn(streamDoc2);
+        when(doc1.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(doc2.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
         when(streamDoc1.getFullDocument()).thenReturn(doc1);
         when(streamDoc2.getFullDocument()).thenReturn(doc2);
@@ -116,11 +124,12 @@
                     .thenReturn(mongoClient);
             streamWorker.processStream(streamPartition);
         }
-        verify(mongoClient, times(1)).close();
+        verify(mongoClient).close();
         verify(mongoDatabase).getCollection(eq("collection"));
         verify(mockRecordBufferWriter).writeToBuffer(eq(null), any());
-        verify(successItemsCounter, times(1)).increment();
+        verify(successItemsCounter).increment(2);
         verify(failureItemsCounter, never()).increment();
+        verify(mockPartitionCheckpoint, times(2)).checkpoint("{\"resumeToken2\": 234}", 2);
     }
 
@@ -139,4 +148,101 @@ void test_processStream_mongoClientFailure() {
         verifyNoInteractions(successItemsCounter);
         verifyNoInteractions(failureItemsCounter);
     }
+
+    @Test
+    void test_processStream_checkPointIntervalSuccess() {
+        when(streamProgressState.shouldWaitForExport()).thenReturn(false);
+        when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
+        when(streamPartition.getCollection()).thenReturn("database.collection");
+        MongoClient mongoClient = mock(MongoClient.class);
+        MongoDatabase mongoDatabase = mock(MongoDatabase.class);
+        MongoCollection<Document> col = mock(MongoCollection.class);
+        ChangeStreamIterable<Document> changeStreamIterable = mock(ChangeStreamIterable.class);
+        MongoCursor<ChangeStreamDocument<Document>> cursor = mock(MongoCursor.class);
+        when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
+        when(mongoDatabase.getCollection(anyString())).thenReturn(col);
+        when(col.watch()).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.iterator()).thenReturn(cursor);
+        when(cursor.hasNext()).thenReturn(true)
+                .thenReturn(true)
+                .thenReturn(true)
+                .thenReturn(false);
+        ChangeStreamDocument<Document> streamDoc1 = mock(ChangeStreamDocument.class);
+        ChangeStreamDocument<Document> streamDoc2 = mock(ChangeStreamDocument.class);
+        ChangeStreamDocument<Document> streamDoc3 = mock(ChangeStreamDocument.class);
+        Document doc1 = mock(Document.class);
+        Document doc2 = mock(Document.class);
+        Document doc3 = mock(Document.class);
+        BsonDocument bsonDoc1 = mock(BsonDocument.class);
+        BsonDocument bsonDoc2 = mock(BsonDocument.class);
+        BsonDocument bsonDoc3 = mock(BsonDocument.class);
+        when(streamDoc1.getResumeToken()).thenReturn(bsonDoc1);
+        when(streamDoc2.getResumeToken()).thenReturn(bsonDoc2);
+        when(streamDoc3.getResumeToken()).thenReturn(bsonDoc3);
+        when(cursor.next())
+                .thenReturn(streamDoc1, streamDoc2, streamDoc3);
+        when(doc1.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(doc2.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(doc3.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(streamDoc1.getFullDocument()).thenReturn(doc1);
+        when(streamDoc2.getFullDocument()).thenReturn(doc2);
+        when(streamDoc3.getFullDocument()).thenReturn(doc3);
+        final String resumeToken1 = UUID.randomUUID().toString();
+        final String resumeToken2 = UUID.randomUUID().toString();
+        final String resumeToken3 = UUID.randomUUID().toString();
+        when(bsonDoc1.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken1);
+        when(bsonDoc2.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken2);
+        when(bsonDoc3.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken3);
+
+        try (MockedStatic<MongoDBConnection> mongoDBConnectionMockedStatic = mockStatic(MongoDBConnection.class)) {
+
+            mongoDBConnectionMockedStatic.when(() -> MongoDBConnection.getMongoClient(any(MongoDBSourceConfig.class)))
+                    .thenReturn(mongoClient);
+            streamWorker.processStream(streamPartition);
+
+        }
+        verify(mongoClient, times(1)).close();
+        verify(mongoDatabase).getCollection(eq("collection"));
+        verify(cursor).close();
+        verify(cursor, times(4)).hasNext();
+        verify(mockPartitionCheckpoint).checkpoint(resumeToken3, 3);
+        verify(successItemsCounter).increment(1);
+        verify(mockPartitionCheckpoint).checkpoint(resumeToken2, 2);
+        verify(mockRecordBufferWriter, times(2)).writeToBuffer(eq(null), any());
+        verify(successItemsCounter).increment(2);
+        verify(failureItemsCounter, never()).increment();
+    }
+
+    @Test
+    void test_processStream_stopWorker() {
+        when(streamProgressState.shouldWaitForExport()).thenReturn(false);
+        when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
+        when(streamPartition.getCollection()).thenReturn("database.collection");
+        MongoClient mongoClient = mock(MongoClient.class);
+        MongoDatabase mongoDatabase = mock(MongoDatabase.class);
+        MongoCollection<Document> col = mock(MongoCollection.class);
+        ChangeStreamIterable<Document> changeStreamIterable = mock(ChangeStreamIterable.class);
+        MongoCursor<ChangeStreamDocument<Document>> cursor = mock(MongoCursor.class);
+        when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
+        when(mongoDatabase.getCollection(anyString())).thenReturn(col);
+        when(col.watch()).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.iterator()).thenReturn(cursor);
+        final ExecutorService executorService = Executors.newSingleThreadExecutor();
+        final Future<?> future = executorService.submit(() -> {
+            try (MockedStatic<MongoDBConnection> mongoDBConnectionMockedStatic = mockStatic(MongoDBConnection.class)) {
+                mongoDBConnectionMockedStatic.when(() -> MongoDBConnection.getMongoClient(any(MongoDBSourceConfig.class)))
+                        .thenReturn(mongoClient);
+                streamWorker.processStream(streamPartition);
+            }
+        });
+        streamWorker.stop();
+        await()
+            .atMost(Duration.ofSeconds(4))
+            .untilAsserted(() ->  verify(mongoClient).close());
+        future.cancel(true);
+        executorService.shutdownNow();
+        verify(mongoDatabase).getCollection(eq("collection"));
+    }
 }