From 5d0c2919d4a0c6308c19b638183a9df044fc1910 Mon Sep 17 00:00:00 2001
From: Dinu John <86094133+dinujoh@users.noreply.github.com>
Date: Tue, 2 Apr 2024 07:24:23 -0500
Subject: [PATCH] Add stop method to Stream worker to stop processing stream

Signed-off-by: Dinu John <86094133+dinujoh@users.noreply.github.com>
---
 data-prepper-plugins/mongodb/build.gradle     |   1 +
 .../stream/StreamAcknowledgementManager.java  |  22 +--
 .../plugins/mongo/stream/StreamScheduler.java |   7 +-
 .../plugins/mongo/stream/StreamWorker.java    |  28 ++--
 .../export/ExportPartitionWorkerTest.java     |   1 -
 .../StreamAcknowledgementManagerTest.java     |  23 ++-
 .../mongo/stream/StreamSchedulerTest.java     |   6 +-
 .../mongo/stream/StreamWorkerTest.java        | 157 +++++++++++-------
 8 files changed, 149 insertions(+), 96 deletions(-)

diff --git a/data-prepper-plugins/mongodb/build.gradle b/data-prepper-plugins/mongodb/build.gradle
index 2c05cecc55..d32efec8fa 100644
--- a/data-prepper-plugins/mongodb/build.gradle
+++ b/data-prepper-plugins/mongodb/build.gradle
@@ -16,6 +16,7 @@ dependencies {
 
     testImplementation testLibs.mockito.inline
     testImplementation testLibs.bundles.junit
+    testImplementation testLibs.slf4j.simple
     testImplementation project(path: ':data-prepper-test-common')
 }
 
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java
index f0dbe5403f..f25c2f6188 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManager.java
@@ -14,11 +14,12 @@
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.function.Consumer;
 
 public class StreamAcknowledgementManager {
     private static final Logger LOG = LoggerFactory.getLogger(StreamAcknowledgementManager.class);
-    private final ConcurrentLinkedQueue<CheckpointStatus> checkpoints = new ConcurrentLinkedQueue<>();
-    private final ConcurrentHashMap<String, CheckpointStatus> ackStatus = new ConcurrentHashMap<>();
+    private ConcurrentLinkedQueue<CheckpointStatus> checkpoints;
+    private ConcurrentHashMap<String, CheckpointStatus> ackStatus;
     private final AcknowledgementSetManager acknowledgementSetManager;
     private final DataStreamPartitionCheckpoint partitionCheckpoint;
 
@@ -43,13 +44,14 @@ public StreamAcknowledgementManager(final AcknowledgementSetManager acknowledgem
         executorService = Executors.newSingleThreadExecutor();
     }
 
-    void init() {
+    void init(final Consumer<Void> stopWorkerConsumer) {
         enableAcknowledgement = true;
-        final Thread currentThread = Thread.currentThread();
-        executorService.submit(() -> monitorCheckpoints(executorService, currentThread));
+        executorService.submit(() -> monitorCheckpoints(executorService, stopWorkerConsumer));
     }
 
-    private void monitorCheckpoints(final ExecutorService executorService, final Thread parentThread) {
+    private void monitorCheckpoints(final ExecutorService executorService, final Consumer<Void> stopWorkerConsumer) {
+        checkpoints = new ConcurrentLinkedQueue<>();
+        ackStatus = new ConcurrentHashMap<>();
         long lastCheckpointTime = System.currentTimeMillis();
         CheckpointStatus lastCheckpointStatus = null;
         while (!Thread.currentThread().isInterrupted()) {
@@ -67,14 +69,14 @@ private void monitorCheckpoints(final ExecutorService executorService, final Thr
LOG.debug("Checkpoint not complete for resume token {}", checkpointStatus.getResumeToken()); final Duration ackWaitDuration = Duration.between(Instant.ofEpochMilli(checkpointStatus.getCreateTimestamp()), Instant.now()); // Acknowledgement not received for the checkpoint after twice ack wait time - if (ackWaitDuration.getSeconds() > partitionAcknowledgmentTimeout.getSeconds() * 2) { + if (ackWaitDuration.getSeconds() >= partitionAcknowledgmentTimeout.getSeconds() * 2) { // Give up partition and should interrupt parent thread to stop processing stream if (lastCheckpointStatus != null && lastCheckpointStatus.isAcknowledged()) { partitionCheckpoint.checkpoint(lastCheckpointStatus.getResumeToken(), lastCheckpointStatus.getRecordCount()); } LOG.warn("Acknowledgement not received for the checkpoint {} past wait time. Giving up partition.", checkpointStatus.getResumeToken()); partitionCheckpoint.giveUpPartition(); - Thread.currentThread().interrupt(); + break; } } } @@ -82,10 +84,10 @@ private void monitorCheckpoints(final ExecutorService executorService, final Thr try { Thread.sleep(acknowledgementMonitorWaitTimeInMs); } catch (InterruptedException ex) { - Thread.currentThread().interrupt(); + break; } } - parentThread.interrupt(); + stopWorkerConsumer.accept(null); executorService.shutdown(); } diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java index 676b04ee95..4fd8ec0e0a 100644 --- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java +++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java @@ -24,6 +24,7 @@ public class StreamScheduler implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(StreamScheduler.class); private static final int DEFAULT_TAKE_LEASE_INTERVAL_MILLIS = 60_000; static final int DEFAULT_CHECKPOINT_INTERVAL_MILLS = 120_000; + private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000; /** * Number of records to accumulate before flushing to buffer */ @@ -62,8 +63,10 @@ public void run() { if (sourcePartition.isPresent()) { streamPartition = (StreamPartition) sourcePartition.get(); final DataStreamPartitionCheckpoint partitionCheckpoint = new DataStreamPartitionCheckpoint(sourceCoordinator, streamPartition); - final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, acknowledgementSetManager, - sourceConfig, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE, DEFAULT_CHECKPOINT_INTERVAL_MILLS); + final StreamAcknowledgementManager streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, + sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, DEFAULT_CHECKPOINT_INTERVAL_MILLS); + final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, sourceConfig, + streamAcknowledgementManager, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE, DEFAULT_CHECKPOINT_INTERVAL_MILLS); streamWorker.processStream(streamPartition); } try { diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java index b18e530063..71bc26c474 100644 --- 
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorker.java
@@ -13,7 +13,6 @@
 import org.bson.json.JsonWriterSettings;
 import org.opensearch.dataprepper.metrics.PluginMetrics;
 import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
-import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.plugins.mongo.buffer.RecordBufferWriter;
 import org.opensearch.dataprepper.plugins.mongo.client.MongoDBConnection;
 import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
@@ -32,8 +31,6 @@ public class StreamWorker {
     public static final String STREAM_PREFIX = "STREAM-";
     private static final Logger LOG = LoggerFactory.getLogger(StreamWorker.class);
     private static final int DEFAULT_EXPORT_COMPLETE_WAIT_INTERVAL_MILLIS = 90_000;
-    private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000;
-
     private static final String COLLECTION_SPLITTER = "\\.";
     static final String SUCCESS_ITEM_COUNTER_NAME = "streamRecordsSuccessTotal";
     static final String FAILURE_ITEM_COUNTER_NAME = "streamRecordsFailedTotal";
@@ -43,11 +40,12 @@ public class StreamWorker {
     private final MongoDBSourceConfig sourceConfig;
     private final Counter successItemsCounter;
     private final Counter failureItemsCounter;
-    private final AcknowledgementSetManager acknowledgementSetManager;
+    private final StreamAcknowledgementManager streamAcknowledgementManager;
     private final PluginMetrics pluginMetrics;
     private final int recordFlushBatchSize;
     final int checkPointIntervalInMs;
-    private final StreamAcknowledgementManager streamAcknowledgementManager;
+    private boolean stopWorker = false;
+
 
     private final JsonWriterSettings writerSettings = JsonWriterSettings.builder()
             .outputMode(JsonMode.RELAXED)
@@ -55,19 +53,19 @@ public class StreamWorker {
             .build();
 
     public static StreamWorker create(final RecordBufferWriter recordBufferWriter,
-                         final AcknowledgementSetManager acknowledgementSetManager,
                          final MongoDBSourceConfig sourceConfig,
+                         final StreamAcknowledgementManager streamAcknowledgementManager,
                          final DataStreamPartitionCheckpoint partitionCheckpoint,
                          final PluginMetrics pluginMetrics,
                          final int recordFlushBatchSize,
                          final int checkPointIntervalInMs
     ) {
-        return new StreamWorker(recordBufferWriter, acknowledgementSetManager,
-                sourceConfig, partitionCheckpoint, pluginMetrics, recordFlushBatchSize, checkPointIntervalInMs);
+        return new StreamWorker(recordBufferWriter, sourceConfig, streamAcknowledgementManager, partitionCheckpoint,
+                pluginMetrics, recordFlushBatchSize, checkPointIntervalInMs);
     }
 
     public StreamWorker(final RecordBufferWriter recordBufferWriter,
-                        final AcknowledgementSetManager acknowledgementSetManager,
                         final MongoDBSourceConfig sourceConfig,
+                        final StreamAcknowledgementManager streamAcknowledgementManager,
                         final DataStreamPartitionCheckpoint partitionCheckpoint,
                         final PluginMetrics pluginMetrics,
                         final int recordFlushBatchSize,
@@ -75,18 +73,16 @@ public StreamWorker(final RecordBufferWriter recordBufferWriter,
     ) {
         this.recordBufferWriter = recordBufferWriter;
         this.sourceConfig = sourceConfig;
+        this.streamAcknowledgementManager = streamAcknowledgementManager;
         this.partitionCheckpoint = partitionCheckpoint;
-        this.acknowledgementSetManager = acknowledgementSetManager;
         this.pluginMetrics = pluginMetrics;
         this.recordFlushBatchSize = recordFlushBatchSize;
         this.checkPointIntervalInMs = checkPointIntervalInMs;
         this.successItemsCounter = pluginMetrics.counter(SUCCESS_ITEM_COUNTER_NAME);
         this.failureItemsCounter = pluginMetrics.counter(FAILURE_ITEM_COUNTER_NAME);
-        streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint,
-                sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, checkPointIntervalInMs);
         if (sourceConfig.isAcknowledgmentsEnabled()) {
             // starts acknowledgement monitoring thread
-            streamAcknowledgementManager.init();
+            streamAcknowledgementManager.init((Void) -> stop());
         }
     }
 
@@ -138,7 +134,7 @@ public void processStream(final StreamPartition streamPartition) {
             }
         }
         long lastCheckpointTime = System.currentTimeMillis();
-        while (cursor.hasNext() && !Thread.currentThread().isInterrupted()) {
+        while (cursor.hasNext() && !Thread.currentThread().isInterrupted() && !stopWorker) {
             try {
                 final ChangeStreamDocument<Document> document = cursor.next();
                 final String record = document.getFullDocument().toJson(writerSettings);
@@ -190,4 +186,8 @@ public void processStream(final StreamPartition streamPartition) {
             }
         }
     }
+
+    void stop() {
+        stopWorker = true;
+    }
 }
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
index 2b5c6b885b..fe5447404c 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/export/ExportPartitionWorkerTest.java
@@ -132,6 +132,5 @@ public void testProcessPartitionSuccess(final String partitionKey) {
         verify(mockRecordBufferWriter).writeToBuffer(eq(mockAcknowledgementSet), any());
         verify(successItemsCounter, times(2)).increment();
         verify(failureItemsCounter, never()).increment();
-        executorService.shutdownNow();
     }
 }
\ No newline at end of file
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java
index 86dc447fcc..fdf404f5a2 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamAcknowledgementManagerTest.java
@@ -26,7 +26,6 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 
 @ExtendWith(MockitoExtension.class)
@@ -40,6 +39,8 @@ public class StreamAcknowledgementManagerTest {
     private Duration timeout;
     @Mock
     private AcknowledgementSet acknowledgementSet;
+    @Mock
+    private Consumer<Void> stopWorkerConsumer;
     private StreamAcknowledgementManager streamAckManager;
 
     @BeforeEach
@@ -55,7 +56,9 @@ public void createAcknowledgementSet_disabled_emptyAckSet() {
 
     @Test
     public void createAcknowledgementSet_enabled_ackSetWithAck() {
-        streamAckManager.init();
+        when(timeout.getSeconds()).thenReturn(10_000L);
+        streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
+        streamAckManager.init(stopWorkerConsumer);
         final String resumeToken = UUID.randomUUID().toString();
         final long recordCount = new Random().nextLong();
         when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
@@ -79,7 +82,9 @@ public void createAcknowledgementSet_enabled_ackSetWithAck() {
 
     @Test
     public void createAcknowledgementSet_enabled_multipleAckSetWithAck() {
-        streamAckManager.init();
+        when(timeout.getSeconds()).thenReturn(10_000L);
+        streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
+        streamAckManager.init(stopWorkerConsumer);
         final String resumeToken1 = UUID.randomUUID().toString();
         final long recordCount1 = new Random().nextLong();
         when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
@@ -113,7 +118,7 @@ public void createAcknowledgementSet_enabled_multipleAckSetWithAck() {
 
     @Test
     public void createAcknowledgementSet_enabled_multipleAckSetWithAckFailure() {
-        streamAckManager.init();
+        streamAckManager.init(stopWorkerConsumer);
         final String resumeToken1 = UUID.randomUUID().toString();
         final long recordCount1 = new Random().nextLong();
         when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
@@ -141,14 +146,15 @@ public void createAcknowledgementSet_enabled_multipleAckSetWithAckFailure() {
         assertThat(ackCheckpointStatus.isAcknowledged(), is(true));
         await()
             .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
-                verifyNoInteractions(partitionCheckpoint));
+                verify(partitionCheckpoint).giveUpPartition());
         assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
         assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
+        verify(stopWorkerConsumer).accept(null);
     }
 
     @Test
     public void createAcknowledgementSet_enabled_ackSetWithNoAck() {
-        streamAckManager.init();
+        streamAckManager.init(stopWorkerConsumer);
         final String resumeToken = UUID.randomUUID().toString();
         final long recordCount = new Random().nextLong();
         when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
@@ -164,5 +170,8 @@ public void createAcknowledgementSet_enabled_ackSetWithNoAck() {
         final ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
         final CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken);
         assertThat(ackCheckpointStatus.isAcknowledged(), is(false));
-    }
+        await()
+            .atMost(Duration.ofSeconds(10)).untilAsserted(() ->
+                verify(stopWorkerConsumer).accept(null));
+}
 }
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
index 57a645de97..b02e4ee20f 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamSchedulerTest.java
@@ -86,9 +86,9 @@ void test_stream_run() {
         final ExecutorService executorService = Executors.newSingleThreadExecutor();
         final Future<?> future = executorService.submit(() -> {
             try (MockedStatic<StreamWorker> streamWorkerMockedStatic = mockStatic(StreamWorker.class)) {
-                streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(acknowledgementSetManager),
-                        eq(sourceConfig), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100), eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS)))
-                        .thenReturn(streamWorker);
+                streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(sourceConfig),
+                        any(StreamAcknowledgementManager.class), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100), eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS)))
+                        .thenReturn(streamWorker);
                 streamScheduler.run();
             }
         });
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
index 2e2e6a4706..ca69d03b67 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamWorkerTest.java
@@ -19,16 +19,20 @@
 import org.mockito.MockedStatic;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.dataprepper.metrics.PluginMetrics;
-import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.plugins.mongo.buffer.RecordBufferWriter;
 import org.opensearch.dataprepper.plugins.mongo.client.MongoDBConnection;
 import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.StreamPartition;
 import org.opensearch.dataprepper.plugins.mongo.coordination.state.StreamProgressState;
 
+import java.time.Duration;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 
+import static org.awaitility.Awaitility.await;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -48,7 +52,7 @@ public class StreamWorkerTest {
     @Mock
     private RecordBufferWriter mockRecordBufferWriter;
     @Mock
-    private AcknowledgementSetManager mockAcknowledgementSetManager;
+    private StreamAcknowledgementManager mockStreamAcknowledgementManager;
     @Mock
     private MongoDBSourceConfig mockSourceConfig;
     @Mock
@@ -71,8 +75,10 @@ public class StreamWorkerTest {
     public void setup() {
         when(mockPluginMetrics.counter(SUCCESS_ITEM_COUNTER_NAME)).thenReturn(successItemsCounter);
         when(mockPluginMetrics.counter(FAILURE_ITEM_COUNTER_NAME)).thenReturn(failureItemsCounter);
-        streamWorker = new StreamWorker(mockRecordBufferWriter, mockAcknowledgementSetManager,
-                mockSourceConfig, mockPartitionCheckpoint, mockPluginMetrics, 2, 0);
+        when(mockSourceConfig.isAcknowledgmentsEnabled()).thenReturn(false);
+        Thread.interrupted();
+        streamWorker = new StreamWorker(mockRecordBufferWriter, mockSourceConfig, mockStreamAcknowledgementManager,
+                mockPartitionCheckpoint, mockPluginMetrics, 2, 0);
     }
 
     @Test
@@ -118,10 +124,10 @@ void test_processStream_success() {
                     .thenReturn(mongoClient);
             streamWorker.processStream(streamPartition);
         }
-        verify(mongoClient, times(1)).close();
+        verify(mongoClient).close();
         verify(mongoDatabase).getCollection(eq("collection"));
         verify(mockRecordBufferWriter).writeToBuffer(eq(null), any());
-        verify(successItemsCounter, times(1)).increment(2);
+        verify(successItemsCounter).increment(2);
         verify(failureItemsCounter, never()).increment();
         verify(mockPartitionCheckpoint, times(2)).checkpoint("{\"resumeToken2\": 234}", 2);
     }
 
@@ -143,67 +149,100 @@ void test_processStream_mongoClientFailure() {
         verifyNoInteractions(failureItemsCounter);
     }
 
-    //@Test
+    @Test
     void test_processStream_checkPointIntervalSuccess() {
+        when(streamProgressState.shouldWaitForExport()).thenReturn(false);
+        when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
+        when(streamPartition.getCollection()).thenReturn("database.collection");
+        MongoClient mongoClient = mock(MongoClient.class);
+        MongoDatabase mongoDatabase = mock(MongoDatabase.class);
+        MongoCollection col = mock(MongoCollection.class);
+        ChangeStreamIterable changeStreamIterable = mock(ChangeStreamIterable.class);
+        MongoCursor cursor = mock(MongoCursor.class);
+        when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
+        when(mongoDatabase.getCollection(anyString())).thenReturn(col);
+        when(col.watch()).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.iterator()).thenReturn(cursor);
+        when(cursor.hasNext()).thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);
+        ChangeStreamDocument streamDoc1 = mock(ChangeStreamDocument.class);
+        ChangeStreamDocument streamDoc2 = mock(ChangeStreamDocument.class);
+        ChangeStreamDocument streamDoc3 = mock(ChangeStreamDocument.class);
+        Document doc1 = mock(Document.class);
+        Document doc2 = mock(Document.class);
+        Document doc3 = mock(Document.class);
+        BsonDocument bsonDoc1 = mock(BsonDocument.class);
+        BsonDocument bsonDoc2 = mock(BsonDocument.class);
+        BsonDocument bsonDoc3 = mock(BsonDocument.class);
+        when(streamDoc1.getResumeToken()).thenReturn(bsonDoc1);
+        when(streamDoc2.getResumeToken()).thenReturn(bsonDoc2);
+        when(streamDoc3.getResumeToken()).thenReturn(bsonDoc3);
+        when(cursor.next())
+            .thenReturn(streamDoc1, streamDoc2, streamDoc3);
+        when(doc1.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(doc2.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(doc3.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
+        when(streamDoc1.getFullDocument()).thenReturn(doc1);
+        when(streamDoc2.getFullDocument()).thenReturn(doc2);
+        when(streamDoc3.getFullDocument()).thenReturn(doc3);
+        final String resumeToken1 = UUID.randomUUID().toString();
+        final String resumeToken2 = UUID.randomUUID().toString();
+        final String resumeToken3 = UUID.randomUUID().toString();
+        when(bsonDoc1.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken1);
+        when(bsonDoc2.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken2);
+        when(bsonDoc3.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken3);
+
         try (MockedStatic<MongoDBConnection> mongoDBConnectionMockedStatic = mockStatic(MongoDBConnection.class)) {
-            when(mockPartitionCheckpoint.getGlobalStreamLoadStatus()).thenReturn(Optional.empty());
-            when(mockSourceConfig.isAcknowledgmentsEnabled()).thenReturn(false);
-            when(streamProgressState.shouldWaitForExport()).thenReturn(false);
-            when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
-            when(streamPartition.getCollection()).thenReturn("database.collection");
-            MongoClient mongoClient = mock(MongoClient.class);
-            MongoDatabase mongoDatabase = mock(MongoDatabase.class);
-            MongoCollection col = mock(MongoCollection.class);
-            ChangeStreamIterable changeStreamIterable = mock(ChangeStreamIterable.class);
-            MongoCursor cursor = mock(MongoCursor.class);
-            when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
-            when(mongoDatabase.getCollection(anyString())).thenReturn(col);
-            when(col.watch()).thenReturn(changeStreamIterable);
-            when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
-            when(changeStreamIterable.iterator()).thenReturn(cursor);
-            when(cursor.hasNext()).thenReturn(true, true, true, false);
-            ChangeStreamDocument streamDoc1 = mock(ChangeStreamDocument.class);
-            ChangeStreamDocument streamDoc2 = mock(ChangeStreamDocument.class);
-            ChangeStreamDocument streamDoc3 = mock(ChangeStreamDocument.class);
-            Document doc1 = mock(Document.class);
-            Document doc2 = mock(Document.class);
-            Document doc3 = mock(Document.class);
-            BsonDocument bsonDoc1 = mock(BsonDocument.class); //new BsonDocument("resumeToken1", new BsonInt32(123));
-            BsonDocument bsonDoc2 = mock(BsonDocument.class); //new BsonDocument("resumeToken2", new BsonInt32(234));
-            BsonDocument bsonDoc3 = mock(BsonDocument.class); //new BsonDocument("resumeToken3", new BsonInt32(456));
-            when(streamDoc1.getResumeToken()).thenReturn(bsonDoc1);
-            when(streamDoc2.getResumeToken()).thenReturn(bsonDoc2);
-            when(streamDoc3.getResumeToken()).thenReturn(bsonDoc3);
-            when(cursor.next())
-                .thenReturn(streamDoc1, streamDoc2, streamDoc3);
-            when(doc1.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
-            when(doc2.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
-            when(doc3.toJson(any(JsonWriterSettings.class))).thenReturn(UUID.randomUUID().toString());
-            when(streamDoc1.getFullDocument()).thenReturn(doc1);
-            when(streamDoc2.getFullDocument()).thenReturn(doc2);
-            when(streamDoc3.getFullDocument()).thenReturn(doc3);
-            final String resumeToken1 = UUID.randomUUID().toString();
-            final String resumeToken2 = UUID.randomUUID().toString();
-            final String resumeToken3 = UUID.randomUUID().toString();
-            when(bsonDoc1.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken1);
-            when(bsonDoc2.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken2);
-            when(bsonDoc3.toJson(any(JsonWriterSettings.class))).thenReturn(resumeToken3);
             mongoDBConnectionMockedStatic.when(() -> MongoDBConnection.getMongoClient(any(MongoDBSourceConfig.class)))
                     .thenReturn(mongoClient);
             streamWorker.processStream(streamPartition);
-            verify(mongoClient, times(1)).close();
-            verify(mongoDatabase).getCollection(eq("collection"));
-            verify(cursor).close();
-            // TODO fix
-            // verify(cursor, times(4)).hasNext();
-            // verify(mockPartitionCheckpoint).checkpoint(resumeToken3, 3);
-            // verify(successItemsCounter).increment(1);
-            // verify(mockPartitionCheckpoint).checkpoint(resumeToken2, 2);
+        }
-            // verify(mockRecordBufferWriter, times(2)).writeToBuffer(eq(null), any());
-            // verify(successItemsCounter).increment(2);
+        verify(mongoClient, times(1)).close();
+        verify(mongoDatabase).getCollection(eq("collection"));
+        verify(cursor).close();
+        verify(cursor, times(4)).hasNext();
+        verify(mockPartitionCheckpoint).checkpoint(resumeToken3, 3);
+        verify(successItemsCounter).increment(1);
+        verify(mockPartitionCheckpoint).checkpoint(resumeToken2, 2);
+        verify(mockRecordBufferWriter, times(2)).writeToBuffer(eq(null), any());
+        verify(successItemsCounter).increment(2);
         verify(failureItemsCounter, never()).increment();
+    }
+
+    @Test
+    void test_processStream_stopWorker() {
+        when(streamProgressState.shouldWaitForExport()).thenReturn(false);
+        when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
+        when(streamPartition.getCollection()).thenReturn("database.collection");
+        MongoClient mongoClient = mock(MongoClient.class);
+        MongoDatabase mongoDatabase = mock(MongoDatabase.class);
+        MongoCollection col = mock(MongoCollection.class);
+        ChangeStreamIterable changeStreamIterable = mock(ChangeStreamIterable.class);
+        MongoCursor cursor = mock(MongoCursor.class);
+        when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase);
+        when(mongoDatabase.getCollection(anyString())).thenReturn(col);
+        when(col.watch()).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.fullDocument(FullDocument.UPDATE_LOOKUP)).thenReturn(changeStreamIterable);
+        when(changeStreamIterable.iterator()).thenReturn(cursor);
+        final ExecutorService executorService = Executors.newSingleThreadExecutor();
+        final Future<?> future = executorService.submit(() -> {
+            try (MockedStatic<MongoDBConnection> mongoDBConnectionMockedStatic = mockStatic(MongoDBConnection.class)) {
+                mongoDBConnectionMockedStatic.when(() -> MongoDBConnection.getMongoClient(any(MongoDBSourceConfig.class)))
+                        .thenReturn(mongoClient);
+                streamWorker.processStream(streamPartition);
+            }
+        });
+        streamWorker.stop();
+        await()
+            .atMost(Duration.ofSeconds(4))
+            .untilAsserted(() -> verify(mongoClient).close());
+        future.cancel(true);
+        executorService.shutdownNow();
+        verify(mongoDatabase).getCollection(eq("collection"));
     }
 }
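
Reviewer note: the hand-off this patch introduces spans three files, so here is a minimal, self-contained Java sketch of the pattern. The Monitor and Worker classes, the sleep-based loop, and the volatile modifier below are illustrative stand-ins, not code from the patch: StreamAcknowledgementManager's monitor thread no longer interrupts its parent thread; when its loop exits (for example after giving up the partition), it fires the Consumer<Void> passed to init(), which StreamWorker wires to its new stop() method, and processStream() observes the stopWorker flag on the next loop iteration.

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;
    import java.util.function.Consumer;

    // Worker mirrors StreamWorker's cooperative-stop loop.
    class Worker {
        // The patch declares a plain boolean field; volatile is used in this
        // sketch so the monitor thread's write is guaranteed visible here.
        private volatile boolean stopWorker = false;

        void processStream() throws InterruptedException {
            while (!Thread.currentThread().isInterrupted() && !stopWorker) {
                TimeUnit.MILLISECONDS.sleep(100); // stands in for cursor.hasNext()/next()
            }
            System.out.println("worker stopped cleanly");
        }

        void stop() {
            stopWorker = true; // mirrors StreamWorker#stop()
        }
    }

    // Monitor mirrors StreamAcknowledgementManager's monitor thread.
    class Monitor {
        private final ExecutorService executorService = Executors.newSingleThreadExecutor();

        // mirrors StreamAcknowledgementManager#init(Consumer<Void>)
        void init(final Consumer<Void> stopWorkerConsumer) {
            executorService.submit(() -> {
                // ... checkpoint monitoring would run here; on exit or give-up:
                stopWorkerConsumer.accept(null); // replaces the old parentThread.interrupt()
                executorService.shutdown();
            });
        }
    }

    public class StopHandoffSketch {
        public static void main(final String[] args) throws InterruptedException {
            final Worker worker = new Worker();
            new Monitor().init((Void) -> worker.stop()); // same wiring as the patch
            worker.processStream(); // returns once the monitor signals stop
        }
    }

One design point worth flagging: because the worker now exits through the stopWorker flag rather than through thread interruption, the stream loop terminates normally instead of via an interrupt, which is what lets the new test_processStream_stopWorker test assert that the cursor and MongoClient are closed cleanly.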