Skip to content

Commit

Permalink
Add stop method to Stream worker to stop processing stream
Browse files Browse the repository at this point in the history
Signed-off-by: Dinu John <[email protected]>
  • Loading branch information
dinujoh committed Apr 2, 2024
1 parent 99a73d3 commit 5d0c291
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 96 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/mongodb/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies {

testImplementation testLibs.mockito.inline
testImplementation testLibs.bundles.junit
testImplementation testLibs.slf4j.simple
testImplementation project(path: ':data-prepper-test-common')

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

public class StreamAcknowledgementManager {
private static final Logger LOG = LoggerFactory.getLogger(StreamAcknowledgementManager.class);
private final ConcurrentLinkedQueue<CheckpointStatus> checkpoints = new ConcurrentLinkedQueue<>();
private final ConcurrentHashMap<String, CheckpointStatus> ackStatus = new ConcurrentHashMap<>();
private ConcurrentLinkedQueue<CheckpointStatus> checkpoints;
private ConcurrentHashMap<String, CheckpointStatus> ackStatus;

private final AcknowledgementSetManager acknowledgementSetManager;
private final DataStreamPartitionCheckpoint partitionCheckpoint;
Expand All @@ -43,13 +44,14 @@ public StreamAcknowledgementManager(final AcknowledgementSetManager acknowledgem
executorService = Executors.newSingleThreadExecutor();
}

void init() {
void init(final Consumer<Void> stopWorkerConsumer) {
enableAcknowledgement = true;
final Thread currentThread = Thread.currentThread();
executorService.submit(() -> monitorCheckpoints(executorService, currentThread));
executorService.submit(() -> monitorCheckpoints(executorService, stopWorkerConsumer));
}

private void monitorCheckpoints(final ExecutorService executorService, final Thread parentThread) {
private void monitorCheckpoints(final ExecutorService executorService, final Consumer<Void> stopWorkerConsumer) {
checkpoints = new ConcurrentLinkedQueue<>();
ackStatus = new ConcurrentHashMap<>();
long lastCheckpointTime = System.currentTimeMillis();
CheckpointStatus lastCheckpointStatus = null;
while (!Thread.currentThread().isInterrupted()) {
Expand All @@ -67,25 +69,25 @@ private void monitorCheckpoints(final ExecutorService executorService, final Thr
LOG.debug("Checkpoint not complete for resume token {}", checkpointStatus.getResumeToken());
final Duration ackWaitDuration = Duration.between(Instant.ofEpochMilli(checkpointStatus.getCreateTimestamp()), Instant.now());
// Acknowledgement not received for the checkpoint after twice ack wait time
if (ackWaitDuration.getSeconds() > partitionAcknowledgmentTimeout.getSeconds() * 2) {
if (ackWaitDuration.getSeconds() >= partitionAcknowledgmentTimeout.getSeconds() * 2) {
// Give up the partition and leave the monitor loop so the stream worker is signalled to stop processing the stream
if (lastCheckpointStatus != null && lastCheckpointStatus.isAcknowledged()) {
partitionCheckpoint.checkpoint(lastCheckpointStatus.getResumeToken(), lastCheckpointStatus.getRecordCount());
}
LOG.warn("Acknowledgement not received for the checkpoint {} past wait time. Giving up partition.", checkpointStatus.getResumeToken());
partitionCheckpoint.giveUpPartition();
Thread.currentThread().interrupt();
break;
}
}
}

try {
Thread.sleep(acknowledgementMonitorWaitTimeInMs);
} catch (InterruptedException ex) {
Thread.currentThread().interrupt();
break;
}
}
parentThread.interrupt();
stopWorkerConsumer.accept(null);
executorService.shutdown();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class StreamScheduler implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(StreamScheduler.class);
private static final int DEFAULT_TAKE_LEASE_INTERVAL_MILLIS = 60_000;
static final int DEFAULT_CHECKPOINT_INTERVAL_MILLS = 120_000;
private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000;
/**
* Number of records to accumulate before flushing to buffer
*/
Expand Down Expand Up @@ -62,8 +63,10 @@ public void run() {
if (sourcePartition.isPresent()) {
streamPartition = (StreamPartition) sourcePartition.get();
final DataStreamPartitionCheckpoint partitionCheckpoint = new DataStreamPartitionCheckpoint(sourceCoordinator, streamPartition);
final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, acknowledgementSetManager,
sourceConfig, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
final StreamAcknowledgementManager streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint,
sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
final StreamWorker streamWorker = StreamWorker.create(recordBufferWriter, sourceConfig,
streamAcknowledgementManager, partitionCheckpoint, pluginMetrics, DEFAULT_STREAM_BATCH_SIZE, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
streamWorker.processStream(streamPartition);
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.bson.json.JsonWriterSettings;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.plugins.mongo.buffer.RecordBufferWriter;
import org.opensearch.dataprepper.plugins.mongo.client.MongoDBConnection;
import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
Expand All @@ -32,8 +31,6 @@ public class StreamWorker {
public static final String STREAM_PREFIX = "STREAM-";
private static final Logger LOG = LoggerFactory.getLogger(StreamWorker.class);
private static final int DEFAULT_EXPORT_COMPLETE_WAIT_INTERVAL_MILLIS = 90_000;
private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000;

private static final String COLLECTION_SPLITTER = "\\.";
static final String SUCCESS_ITEM_COUNTER_NAME = "streamRecordsSuccessTotal";
static final String FAILURE_ITEM_COUNTER_NAME = "streamRecordsFailedTotal";
Expand All @@ -43,50 +40,49 @@ public class StreamWorker {
private final MongoDBSourceConfig sourceConfig;
private final Counter successItemsCounter;
private final Counter failureItemsCounter;
private final AcknowledgementSetManager acknowledgementSetManager;
private final StreamAcknowledgementManager streamAcknowledgementManager;
private final PluginMetrics pluginMetrics;
private final int recordFlushBatchSize;
final int checkPointIntervalInMs;
private final StreamAcknowledgementManager streamAcknowledgementManager;
// Written by stop(), which is handed to the acknowledgement manager as a callback and run
// from its monitor task on a separate executor thread, and read by the processStream loop.
// Declared volatile so the stop request is guaranteed to become visible to the reading thread.
private volatile boolean stopWorker = false;


private final JsonWriterSettings writerSettings = JsonWriterSettings.builder()
.outputMode(JsonMode.RELAXED)
.objectIdConverter((value, writer) -> writer.writeString(value.toHexString()))
.build();

public static StreamWorker create(final RecordBufferWriter recordBufferWriter,
final AcknowledgementSetManager acknowledgementSetManager,
final MongoDBSourceConfig sourceConfig,
final StreamAcknowledgementManager streamAcknowledgementManager,
final DataStreamPartitionCheckpoint partitionCheckpoint,
final PluginMetrics pluginMetrics,
final int recordFlushBatchSize,
final int checkPointIntervalInMs
) {
return new StreamWorker(recordBufferWriter, acknowledgementSetManager,
sourceConfig, partitionCheckpoint, pluginMetrics, recordFlushBatchSize, checkPointIntervalInMs);
return new StreamWorker(recordBufferWriter, sourceConfig, streamAcknowledgementManager, partitionCheckpoint,
pluginMetrics, recordFlushBatchSize, checkPointIntervalInMs);
}
public StreamWorker(final RecordBufferWriter recordBufferWriter,
final AcknowledgementSetManager acknowledgementSetManager,
final MongoDBSourceConfig sourceConfig,
final StreamAcknowledgementManager streamAcknowledgementManager,
final DataStreamPartitionCheckpoint partitionCheckpoint,
final PluginMetrics pluginMetrics,
final int recordFlushBatchSize,
final int checkPointIntervalInMs
) {
this.recordBufferWriter = recordBufferWriter;
this.sourceConfig = sourceConfig;
this.streamAcknowledgementManager = streamAcknowledgementManager;
this.partitionCheckpoint = partitionCheckpoint;
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginMetrics = pluginMetrics;
this.recordFlushBatchSize = recordFlushBatchSize;
this.checkPointIntervalInMs = checkPointIntervalInMs;
this.successItemsCounter = pluginMetrics.counter(SUCCESS_ITEM_COUNTER_NAME);
this.failureItemsCounter = pluginMetrics.counter(FAILURE_ITEM_COUNTER_NAME);
streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint,
sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, checkPointIntervalInMs);
if (sourceConfig.isAcknowledgmentsEnabled()) {
// starts acknowledgement monitoring thread
streamAcknowledgementManager.init();
streamAcknowledgementManager.init((Void) -> stop());
}
}

Expand Down Expand Up @@ -138,7 +134,7 @@ public void processStream(final StreamPartition streamPartition) {
}
}
long lastCheckpointTime = System.currentTimeMillis();
while (cursor.hasNext() && !Thread.currentThread().isInterrupted()) {
while (cursor.hasNext() && !Thread.currentThread().isInterrupted() && !stopWorker) {
try {
final ChangeStreamDocument<Document> document = cursor.next();
final String record = document.getFullDocument().toJson(writerSettings);
Expand Down Expand Up @@ -190,4 +186,8 @@ public void processStream(final StreamPartition streamPartition) {
}
}
}

/**
 * Requests that this worker stop processing the stream by setting the {@code stopWorker} flag.
 * The {@code processStream} loop checks the flag in its loop condition, so the loop ends
 * the next time that condition is evaluated rather than immediately.
 * The constructor passes this method to {@code StreamAcknowledgementManager.init} as the
 * stop callback when acknowledgments are enabled.
 */
void stop() {
stopWorker = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,5 @@ public void testProcessPartitionSuccess(final String partitionKey) {
verify(mockRecordBufferWriter).writeToBuffer(eq(mockAcknowledgementSet), any());
verify(successItemsCounter, times(2)).increment();
verify(failureItemsCounter, never()).increment();
executorService.shutdownNow();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand All @@ -40,6 +39,8 @@ public class StreamAcknowledgementManagerTest {
private Duration timeout;
@Mock
private AcknowledgementSet acknowledgementSet;
@Mock
private Consumer<Void> stopWorkerConsumer;
private StreamAcknowledgementManager streamAckManager;

@BeforeEach
Expand All @@ -55,7 +56,9 @@ public void createAcknowledgementSet_disabled_emptyAckSet() {

@Test
public void createAcknowledgementSet_enabled_ackSetWithAck() {
streamAckManager.init();
when(timeout.getSeconds()).thenReturn(10_000L);
streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
streamAckManager.init(stopWorkerConsumer);
final String resumeToken = UUID.randomUUID().toString();
final long recordCount = new Random().nextLong();
when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
Expand All @@ -79,7 +82,9 @@ public void createAcknowledgementSet_enabled_ackSetWithAck() {

@Test
public void createAcknowledgementSet_enabled_multipleAckSetWithAck() {
streamAckManager.init();
when(timeout.getSeconds()).thenReturn(10_000L);
streamAckManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint, timeout, 0, 0);
streamAckManager.init(stopWorkerConsumer);
final String resumeToken1 = UUID.randomUUID().toString();
final long recordCount1 = new Random().nextLong();
when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
Expand Down Expand Up @@ -113,7 +118,7 @@ public void createAcknowledgementSet_enabled_multipleAckSetWithAck() {

@Test
public void createAcknowledgementSet_enabled_multipleAckSetWithAckFailure() {
streamAckManager.init();
streamAckManager.init(stopWorkerConsumer);
final String resumeToken1 = UUID.randomUUID().toString();
final long recordCount1 = new Random().nextLong();
when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
Expand Down Expand Up @@ -141,14 +146,15 @@ public void createAcknowledgementSet_enabled_multipleAckSetWithAckFailure() {
assertThat(ackCheckpointStatus.isAcknowledged(), is(true));
await()
.atMost(Duration.ofSeconds(10)).untilAsserted(() ->
verifyNoInteractions(partitionCheckpoint));
verify(partitionCheckpoint).giveUpPartition());
assertThat(streamAckManager.getCheckpoints().peek().getResumeToken(), is(resumeToken1));
assertThat(streamAckManager.getCheckpoints().peek().getRecordCount(), is(recordCount1));
verify(stopWorkerConsumer).accept(null);
}

@Test
public void createAcknowledgementSet_enabled_ackSetWithNoAck() {
streamAckManager.init();
streamAckManager.init(stopWorkerConsumer);
final String resumeToken = UUID.randomUUID().toString();
final long recordCount = new Random().nextLong();
when(acknowledgementSetManager.create(any(Consumer.class), eq(timeout))).thenReturn(acknowledgementSet);
Expand All @@ -164,5 +170,8 @@ public void createAcknowledgementSet_enabled_ackSetWithNoAck() {
final ConcurrentHashMap<String, CheckpointStatus> ackStatus = streamAckManager.getAcknowledgementStatus();
final CheckpointStatus ackCheckpointStatus = ackStatus.get(resumeToken);
assertThat(ackCheckpointStatus.isAcknowledged(), is(false));
}
await()
.atMost(Duration.ofSeconds(10)).untilAsserted(() ->
verify(stopWorkerConsumer).accept(null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ void test_stream_run() {
final ExecutorService executorService = Executors.newSingleThreadExecutor();
final Future<?> future = executorService.submit(() -> {
try (MockedStatic<StreamWorker> streamWorkerMockedStatic = mockStatic(StreamWorker.class)) {
streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(acknowledgementSetManager),
eq(sourceConfig), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100), eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS)))
.thenReturn(streamWorker);
streamWorkerMockedStatic.when(() -> StreamWorker.create(any(RecordBufferWriter.class), eq(sourceConfig),
any(StreamAcknowledgementManager.class), any(DataStreamPartitionCheckpoint.class), eq(pluginMetrics), eq(100), eq(DEFAULT_CHECKPOINT_INTERVAL_MILLS)))
.thenReturn(streamWorker);
streamScheduler.run();
}
});
Expand Down
Loading

0 comments on commit 5d0c291

Please sign in to comment.