diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnection.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnection.java
index fcf9aa83f1..a44c6bfa35 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnection.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnection.java
@@ -42,17 +42,24 @@ private static String encodeString(final String input) {
     private static String getConnectionString(final MongoDBSourceConfig sourceConfig) {
         final String username;
         try {
-            username = encodeString(sourceConfig.getCredentialsConfig().getUsername());
+            username = encodeString(sourceConfig.getAuthenticationConfig().getUsername());
         } catch (final Exception e) {
             throw new RuntimeException("Unsupported characters in username.");
         }
+
         final String password;
         try {
-            password = encodeString(sourceConfig.getCredentialsConfig().getPassword());
+            password = encodeString(sourceConfig.getAuthenticationConfig().getPassword());
         } catch (final Exception e) {
             throw new RuntimeException("Unsupported characters in password.");
         }
-        final String hostname = sourceConfig.getHost();
+
+        if (sourceConfig.getHosts() == null || sourceConfig.getHosts().length == 0) {
+            throw new RuntimeException("The hosts array should at least have one host.");
+        }
+
+        // Support for only single host
+        final String hostname = sourceConfig.getHosts()[0];
         final int port = sourceConfig.getPort();
         final String tls = sourceConfig.getTls().toString();
         final String invalidHostAllowed = sourceConfig.getSslInsecureDisableVerification().toString();
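Only the first entry of `hosts` is used to build the connection string, which is why an empty or missing array fails fast. A minimal standalone sketch of that assembly, assuming `encodeString` URL-encodes with UTF-8 and assuming the URI layout shown here (the query parameters and class name are illustrative, not taken from this diff):

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

public class ConnectionStringSketch {
    // Assumed encoding strategy; the plugin's encodeString may differ.
    private static String encode(final String input) {
        return URLEncoder.encode(input, StandardCharsets.UTF_8);
    }

    // Hypothetical URI layout, for illustration only.
    static String build(final String[] hosts, final int port, final String username,
                        final String password, final boolean tls, final boolean invalidHostAllowed) {
        if (hosts == null || hosts.length == 0) {
            throw new RuntimeException("The hosts array should at least have one host.");
        }
        final String hostname = hosts[0]; // only a single host is supported
        return String.format("mongodb://%s:%s@%s:%d/?tls=%s&tlsAllowInvalidHostnames=%s",
                encode(username), encode(password), hostname, port, tls, invalidHostAllowed);
    }

    public static void main(final String[] args) {
        System.out.println(build(new String[]{"docdb.example.com"}, 27017, "user", "p@ss/word", true, false));
        // mongodb://user:p%40ss%2Fword@docdb.example.com:27017/?tls=true&tlsAllowInvalidHostnames=false
    }
}
```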
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/AwsConfig.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/AwsConfig.java
new file mode 100644
index 0000000000..05e6223865
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/AwsConfig.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.dataprepper.plugins.mongo.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.constraints.Size;
+import software.amazon.awssdk.arns.Arn;
+import software.amazon.awssdk.regions.Region;
+
+import java.util.Map;
+import java.util.Optional;
+
+public class AwsConfig {
+    private static final String AWS_IAM_ROLE = "role";
+    private static final String AWS_IAM = "iam";
+
+    @JsonProperty("sts_role_arn")
+    @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
+    private String awsStsRoleArn;
+
+    @JsonProperty("sts_external_id")
+    @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
+    private String awsStsExternalId;
+
+    @JsonProperty("sts_header_overrides")
+    @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
+    private Map<String, String> awsStsHeaderOverrides;
+
+    private void validateStsRoleArn() {
+        final Arn arn = getArn();
+        if (!AWS_IAM.equals(arn.service())) {
+            throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
+        }
+        final Optional<String> resourceType = arn.resource().resourceType();
+        if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
+            throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
+        }
+    }
+
+    private Arn getArn() {
+        try {
+            return Arn.fromString(awsStsRoleArn);
+        } catch (final Exception e) {
+            throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
+        }
+    }
+
+    public String getAwsStsRoleArn() {
+        return awsStsRoleArn;
+    }
+
+    public String getAwsStsExternalId() {
+        return awsStsExternalId;
+    }
+
+    public Map<String, String> getAwsStsHeaderOverrides() {
+        return awsStsHeaderOverrides;
+    }
+}
+
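`validateStsRoleArn` leans on the AWS SDK ARN parser: the configured `sts_role_arn` must belong to the `iam` service and carry a `role` resource type. A small standalone sketch of the same checks (the class name and ARN values are illustrative):

```java
import java.util.Optional;
import software.amazon.awssdk.arns.Arn;

public class StsRoleArnCheck {
    private static final String AWS_IAM = "iam";
    private static final String AWS_IAM_ROLE = "role";

    static void validate(final String stsRoleArn) {
        final Arn arn;
        try {
            arn = Arn.fromString(stsRoleArn);
        } catch (final Exception e) {
            throw new IllegalArgumentException("Invalid ARN format: " + stsRoleArn);
        }
        final Optional<String> resourceType = arn.resource().resourceType();
        // Reject anything that is not an IAM role ARN.
        if (!AWS_IAM.equals(arn.service()) || resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
            throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
        }
    }

    public static void main(final String[] args) {
        validate("arn:aws:iam::123456789012:role/data-prepper-docdb");   // passes
        validate("arn:aws:s3:::my-bucket");                              // throws: not an IAM role
    }
}
```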
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
index 98ac4f524a..7814156b00 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/CollectionConfig.java
@@ -20,15 +20,6 @@ public class CollectionConfig {
     @JsonProperty("stream")
     private boolean stream;
 
-    @JsonProperty("s3_bucket")
-    private String s3Bucket;
-
-    @JsonProperty("s3_path_prefix")
-    private String s3PathPrefix;
-
-    @JsonProperty("s3_region")
-    private String s3Region;
-
     @JsonProperty("partition_count")
     private int partitionCount;
 
@@ -65,14 +56,6 @@ public boolean isStream() {
         return this.stream;
     }
 
-    public String getS3Bucket() {
-        return this.s3Bucket;
-    }
-
-    public String getS3PathPrefix() {
-        return this.s3PathPrefix;
-    }
-
     public int getPartitionCount() {
         return this.partitionCount;
     }
@@ -84,7 +67,4 @@ public int getExportBatchSize() {
     public int getStreamBatchSize() {
         return this.streamBatchSize;
     }
-    public String getS3Region() {
-        return this.s3Region;
-    }
 }
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/MongoDBSourceConfig.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/MongoDBSourceConfig.java
index a9ed0b6981..b7b7b94c31 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/MongoDBSourceConfig.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/configuration/MongoDBSourceConfig.java
@@ -1,9 +1,11 @@
 package org.opensearch.dataprepper.plugins.mongo.configuration;
 
 import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.Valid;
 import jakarta.validation.constraints.NotNull;
 
 import java.time.Duration;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -15,8 +17,11 @@ public class MongoDBSourceConfig {
     private static final String DEFAULT_READ_PREFERENCE = "primaryPreferred";
     private static final Boolean DEFAULT_DIRECT_CONNECT = true;
     private static final Duration DEFAULT_ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofHours(2);
-    @JsonProperty("host")
-    private @NotNull String host;
+    private static final String DATAPREPPER_SERVICE_NAME = "DATAPREPPER_SERVICE_NAME";
+
+    private static final long currentTimeInEpochMilli = Instant.now().toEpochMilli();
+    @JsonProperty("hosts")
+    private @NotNull String[] hosts;
     @JsonProperty("port")
     private int port = DEFAULT_PORT;
     @JsonProperty("trust_store_file_path")
@@ -35,6 +40,15 @@ public class MongoDBSourceConfig {
     @JsonProperty("acknowledgments")
     private Boolean acknowledgments = false;
 
+    @JsonProperty("s3_bucket")
+    private String s3Bucket;
+
+    @JsonProperty("s3_path_prefix")
+    private String s3PathPrefix;
+
+    @JsonProperty("s3_region")
+    private String s3Region;
+
     @JsonProperty
     private Duration partitionAcknowledgmentTimeout;
 
@@ -45,6 +59,11 @@ public class MongoDBSourceConfig {
     @JsonProperty("direct_connection")
     private Boolean directConnection;
 
+    @JsonProperty("aws")
+    @NotNull
+    @Valid
+    private AwsConfig awsConfig;
+
     public MongoDBSourceConfig() {
         this.snapshotFetchSize = DEFAULT_SNAPSHOT_FETCH_SIZE;
         this.readPreference = DEFAULT_READ_PREFERENCE;
@@ -55,12 +74,12 @@ public MongoDBSourceConfig() {
         this.partitionAcknowledgmentTimeout = DEFAULT_ACKNOWLEDGEMENT_SET_TIMEOUT;
     }
 
-    public AuthenticationConfig getCredentialsConfig() {
+    public AuthenticationConfig getAuthenticationConfig() {
         return this.authenticationConfig;
     }
 
-    public String getHost() {
-        return this.host;
+    public String[] getHosts() {
+        return this.hosts;
     }
 
     public int getPort() {
@@ -103,6 +122,32 @@ public Duration getPartitionAcknowledgmentTimeout() {
         return this.partitionAcknowledgmentTimeout;
     }
 
+    public String getS3Bucket() {
+        return this.s3Bucket;
+    }
+
+    public String getS3PathPrefix() {
+        return this.s3PathPrefix;
+    }
+
+    public String getTransformedS3PathPrefix(final String collection) {
+        final String serviceName = System.getenv(DATAPREPPER_SERVICE_NAME);
+        final String suffixPath = serviceName + "/" + collection + "/" + currentTimeInEpochMilli;
+        if (this.getS3PathPrefix() == null || this.getS3PathPrefix().trim().isBlank()) {
+            return suffixPath;
+        } else {
+            return this.s3PathPrefix + "/" + suffixPath;
+        }
+    }
+
+    public String getS3Region() {
+        return this.s3Region;
+    }
+
+    public AwsConfig getAwsConfig() {
+        return this.awsConfig;
+    }
+
     public static class AuthenticationConfig {
         @JsonProperty("username")
         private String username;
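The transformed prefix places each collection's objects under `<s3_path_prefix>/<DATAPREPPER_SERVICE_NAME>/<collection>/<epoch millis>`, dropping the leading `s3_path_prefix` segment when none is configured. A compact sketch of that composition (the prefix, service name, and collection values are illustrative):

```java
import java.time.Instant;

public class S3PathPrefixSketch {
    // Mirrors getTransformedS3PathPrefix: serviceName comes from the
    // DATAPREPPER_SERVICE_NAME environment variable and the timestamp is
    // captured once per process.
    static String transformedPrefix(final String s3PathPrefix, final String serviceName,
                                    final String collection, final long epochMilli) {
        final String suffixPath = serviceName + "/" + collection + "/" + epochMilli;
        if (s3PathPrefix == null || s3PathPrefix.isBlank()) {
            return suffixPath;
        }
        return s3PathPrefix + "/" + suffixPath;
    }

    public static void main(final String[] args) {
        final long now = Instant.now().toEpochMilli();
        System.out.println(transformedPrefix("docdb-exports", "my-pipeline", "db.collection1", now));
        // e.g. docdb-exports/my-pipeline/db.collection1/1712345678901
        System.out.println(transformedPrefix(null, "my-pipeline", "db.collection1", now));
        // e.g. my-pipeline/db.collection1/1712345678901
    }
}
```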
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/PartitionKeyRecordConverter.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/PartitionKeyRecordConverter.java
index 2cf4738325..d845ea0094 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/PartitionKeyRecordConverter.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/converter/PartitionKeyRecordConverter.java
@@ -9,10 +9,13 @@
 import java.util.List;
 
 public class PartitionKeyRecordConverter extends RecordConverter {
+    public static final String S3_PATH_DELIMITER = "/";
     private List<String> partitionNames = new ArrayList<>();
     private int partitionSize = 0;
-    public PartitionKeyRecordConverter(final String collection, final String partitionType) {
+    final String s3PathPrefix;
+    public PartitionKeyRecordConverter(final String collection, final String partitionType, final String s3PathPrefix) {
         super(collection, partitionType);
+        this.s3PathPrefix = s3PathPrefix;
     }
 
     public void initializePartitions(final List<String> partitionNames) {
@@ -28,7 +31,7 @@ public Event convert(final String record,
         final Event event = super.convert(record, eventCreationTimeMillis, eventVersionNumber, eventName);
         final EventMetadata eventMetadata = event.getMetadata();
         final String partitionKey = String.valueOf(eventMetadata.getAttribute(MetadataKeyAttributes.PARTITION_KEY_METADATA_ATTRIBUTE));
-        eventMetadata.setAttribute(MetadataKeyAttributes.EVENT_S3_PARTITION_KEY, hashKeyToPartition(partitionKey));
+        eventMetadata.setAttribute(MetadataKeyAttributes.EVENT_S3_PARTITION_KEY, s3PathPrefix + S3_PATH_DELIMITER + hashKeyToPartition(partitionKey));
         return event;
     }
 
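Each record's `EVENT_S3_PARTITION_KEY` metadata now carries the full path prefix in front of the hashed partition, i.e. `<s3PathPrefix>/<hashKeyToPartition(partitionKey)>`. A small sketch of the resulting key; the hash function and all values here are illustrative stand-ins, not the plugin's implementation:

```java
import java.util.List;

public class S3PartitionKeySketch {
    public static final String S3_PATH_DELIMITER = "/";

    // Illustrative stand-in for hashKeyToPartition: pick one of the
    // initialized partition names based on the record's partition key.
    static String hashKeyToPartition(final String partitionKey, final List<String> partitionNames) {
        final int index = Math.floorMod(partitionKey.hashCode(), partitionNames.size());
        return partitionNames.get(index);
    }

    public static void main(final String[] args) {
        final String s3PathPrefix = "docdb-exports/my-pipeline/db.collection1/1712345678901";
        final List<String> partitionNames = List.of("00", "01", "02", "03");
        final String partitionKey = "db.collection1|some-document-id";
        final String eventS3PartitionKey =
                s3PathPrefix + S3_PATH_DELIMITER + hashKeyToPartition(partitionKey, partitionNames);
        System.out.println(eventS3PartitionKey);
        // e.g. docdb-exports/my-pipeline/db.collection1/1712345678901/02
    }
}
```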
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
index 14c0a30e7e..a1f07e9d7e 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBService.java
@@ -53,7 +53,7 @@ public DocumentDBService(final EnhancedSourceCoordinator sourceCoordinator,
     public void start(Buffer<Record<Event>> buffer) {
         final List<Runnable> runnableList = new ArrayList<>();
 
-        final LeaderScheduler leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig.getCollections());
+        final LeaderScheduler leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig);
         runnableList.add(leaderScheduler);
 
         if (sourceConfig.getCollections().stream().anyMatch(CollectionConfig::isExport)) {
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/ExportWorker.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/ExportWorker.java
index 92984148e7..69d50a699b 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/ExportWorker.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/export/ExportWorker.java
@@ -108,8 +108,10 @@ public void run() {
             if (sourcePartition.isPresent()) {
                 dataQueryPartition = (DataQueryPartition) sourcePartition.get();
                 final AcknowledgementSet acknowledgementSet = createAcknowledgementSet(dataQueryPartition).orElse(null);
+                final String s3PathPrefix = sourceConfig.getTransformedS3PathPrefix(dataQueryPartition.getCollection());
                 final DataQueryPartitionCheckpoint partitionCheckpoint = new DataQueryPartitionCheckpoint(sourceCoordinator, dataQueryPartition);
-                final PartitionKeyRecordConverter recordConverter = new PartitionKeyRecordConverter(dataQueryPartition.getCollection(), ExportPartition.PARTITION_TYPE);
+                final PartitionKeyRecordConverter recordConverter = new PartitionKeyRecordConverter(dataQueryPartition.getCollection(),
+                        ExportPartition.PARTITION_TYPE, s3PathPrefix);
                 final ExportPartitionWorker exportPartitionWorker = new ExportPartitionWorker(recordBufferWriter, recordConverter,
                         dataQueryPartition, acknowledgementSet, sourceConfig, partitionCheckpoint, Instant.now().toEpochMilli(), pluginMetrics);
                 final CompletableFuture<Void> runLoader = CompletableFuture.runAsync(exportPartitionWorker, executor);
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
index 58c9746443..6e33910f51 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderScheduler.java
@@ -2,6 +2,7 @@
 
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
+import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.ExportPartition;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.GlobalState;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.LeaderPartition;
@@ -33,7 +34,7 @@ public class LeaderScheduler implements Runnable {
      */
     private static final Duration DEFAULT_LEASE_INTERVAL = Duration.ofMinutes(1);
 
-    private final List<CollectionConfig> collectionConfigs;
+    private final MongoDBSourceConfig sourceConfig;
 
     private final EnhancedSourceCoordinator coordinator;
 
@@ -41,14 +42,14 @@ public class LeaderScheduler implements Runnable {
 
     private LeaderPartition leaderPartition;
 
-    public LeaderScheduler(EnhancedSourceCoordinator coordinator, List<CollectionConfig> collectionConfigs) {
-        this(coordinator, collectionConfigs, DEFAULT_LEASE_INTERVAL);
+    public LeaderScheduler(final EnhancedSourceCoordinator coordinator, final MongoDBSourceConfig sourceConfig) {
+        this(coordinator, sourceConfig, DEFAULT_LEASE_INTERVAL);
     }
 
     LeaderScheduler(EnhancedSourceCoordinator coordinator,
-                    List<CollectionConfig> collectionConfigs,
+                    MongoDBSourceConfig sourceConfig,
                     Duration leaseInterval) {
-        this.collectionConfigs = collectionConfigs;
+        this.sourceConfig = sourceConfig;
         this.coordinator = coordinator;
         this.leaseInterval = leaseInterval;
     }
@@ -106,7 +107,7 @@ public void run() {
     private void init() {
         LOG.info("Try to initialize DocumentDB Leader Partition");
 
-        collectionConfigs.forEach(collectionConfig -> {
+        sourceConfig.getCollections().forEach(collectionConfig -> {
             // Create a Global state in the coordination table for the configuration.
             // Global State here is designed to be able to read whenever needed
             // So that the jobs can refer to the configuration.
@@ -119,7 +120,8 @@ private void init() {
                 createExportPartition(collectionConfig, startTime);
             }
 
-            createS3Partition(collectionConfig);
+            final String s3PathPrefix = sourceConfig.getTransformedS3PathPrefix(collectionConfig.getCollection() + "-" + Instant.now().toEpochMilli());
+            createS3Partition(sourceConfig.getS3Bucket(), sourceConfig.getS3Region(), s3PathPrefix, collectionConfig);
 
             if (collectionConfig.isStream()) {
                 createStreamPartition(collectionConfig, startTime, exportRequired);
@@ -137,10 +139,10 @@ private void init() {
      *
      * @param collectionConfig collection configuration object containing collection details
      */
-    private void createS3Partition(final CollectionConfig collectionConfig) {
+    private void createS3Partition(final String s3Bucket, final String s3Region, final String s3PathPrefix, final CollectionConfig collectionConfig) {
         LOG.info("Creating s3 folder global partition: {}", collectionConfig.getCollection());
-        coordinator.createPartition(new S3FolderPartition(collectionConfig.getS3Bucket(), collectionConfig.getS3PathPrefix(),
-                collectionConfig.getS3Region(), collectionConfig.getCollection(), collectionConfig.getPartitionCount()));
+        coordinator.createPartition(new S3FolderPartition(s3Bucket, s3PathPrefix,
+                s3Region, collectionConfig.getCollection(), collectionConfig.getPartitionCount()));
     }
 
     /**
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
index 01fe5f1cb0..25ddb84063 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/stream/StreamScheduler.java
@@ -90,7 +90,9 @@ private StreamWorker getStreamWorker (final StreamPartition streamPartition) {
         final DataStreamPartitionCheckpoint partitionCheckpoint = new DataStreamPartitionCheckpoint(sourceCoordinator, streamPartition);
         final StreamAcknowledgementManager streamAcknowledgementManager = new StreamAcknowledgementManager(acknowledgementSetManager, partitionCheckpoint,
                 sourceConfig.getPartitionAcknowledgmentTimeout(), DEFAULT_MONITOR_WAIT_TIME_MS, DEFAULT_CHECKPOINT_INTERVAL_MILLS);
-        final PartitionKeyRecordConverter recordConverter = new PartitionKeyRecordConverter(streamPartition.getCollection(),StreamPartition.PARTITION_TYPE);
+        final String s3PathPrefix = sourceConfig.getTransformedS3PathPrefix(streamPartition.getCollection());
+        final PartitionKeyRecordConverter recordConverter = new PartitionKeyRecordConverter(streamPartition.getCollection(),
+                StreamPartition.PARTITION_TYPE, s3PathPrefix);
         final CollectionConfig partitionCollectionConfig = sourceConfig.getCollections().stream()
                 .filter(collectionConfig -> collectionConfig.getCollection().equals(streamPartition.getCollection()))
                 .findFirst()
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnectionTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnectionTest.java
index 23ee2bc37d..1a427e0ee5 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnectionTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/client/MongoDBConnectionTest.java
@@ -1,7 +1,6 @@
 package org.opensearch.dataprepper.plugins.mongo.client;
 
 import com.mongodb.client.MongoClient;
-import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
@@ -18,6 +17,7 @@
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.when;
@@ -32,12 +32,11 @@ public class MongoDBConnectionTest {
 
     private final Random random = new Random();
 
-    @BeforeEach
     void setUp() {
-        when(mongoDBSourceConfig.getCredentialsConfig()).thenReturn(authenticationConfig);
+        when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
         when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
         when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
-        when(mongoDBSourceConfig.getHost()).thenReturn(UUID.randomUUID().toString());
+        when(mongoDBSourceConfig.getHosts()).thenReturn(new String[] { UUID.randomUUID().toString() });
         when(mongoDBSourceConfig.getPort()).thenReturn(getRandomInteger());
         when(mongoDBSourceConfig.getTls()).thenReturn(getRandomBoolean());
         when(mongoDBSourceConfig.getSslInsecureDisableVerification()).thenReturn(getRandomBoolean());
@@ -46,12 +45,14 @@ void setUp() {
 
     @Test
     public void getMongoClient() {
+        setUp();
         final MongoClient mongoClient = MongoDBConnection.getMongoClient(mongoDBSourceConfig);
         assertThat(mongoClient, is(notNullValue()));
     }
 
     @Test
     public void getMongoClientWithTLS() {
+        setUp();
         when(mongoDBSourceConfig.getTrustStoreFilePath()).thenReturn(UUID.randomUUID().toString());
         when(mongoDBSourceConfig.getTrustStorePassword()).thenReturn(UUID.randomUUID().toString());
         final Path path = mock(Path.class);
@@ -65,6 +66,24 @@ public void getMongoClientWithTLS() {
         }
     }
 
+    @Test
+    public void getMongoClientNullHost() {
+        when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
+        when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
+        when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
+        when(mongoDBSourceConfig.getHosts()).thenReturn(null);
+        assertThrows(RuntimeException.class, () -> MongoDBConnection.getMongoClient(mongoDBSourceConfig));
+    }
+
+    @Test
+    public void getMongoClientEmptyHost() {
+        when(mongoDBSourceConfig.getAuthenticationConfig()).thenReturn(authenticationConfig);
+        when(authenticationConfig.getUsername()).thenReturn("\uD800\uD800" + UUID.randomUUID());
+        when(authenticationConfig.getPassword()).thenReturn("aЯ ⾀sd?q=%%l€0£.lo" + UUID.randomUUID());
+        when(mongoDBSourceConfig.getHosts()).thenReturn(new String[]{});
+        assertThrows(RuntimeException.class, () -> MongoDBConnection.getMongoClient(mongoDBSourceConfig));
+    }
+
     private Boolean getRandomBoolean() {
         return random.nextBoolean();
     }
diff --git a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
index 39217b750d..c8eb1a0316 100644
--- a/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
+++ b/data-prepper-plugins/mongodb/src/test/java/org/opensearch/dataprepper/plugins/mongo/leader/LeaderSchedulerTest.java
@@ -7,6 +7,7 @@
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
 import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
 import org.opensearch.dataprepper.plugins.mongo.configuration.CollectionConfig;
+import org.opensearch.dataprepper.plugins.mongo.configuration.MongoDBSourceConfig;
 import org.opensearch.dataprepper.plugins.mongo.coordination.partition.LeaderPartition;
 
 import java.time.Duration;
@@ -37,27 +38,28 @@ public class LeaderSchedulerTest {
     @Mock
     private CollectionConfig collectionConfig;
+    @Mock
+    private MongoDBSourceConfig mongoDBSourceConfig;
 
     private LeaderScheduler leaderScheduler;
     private LeaderPartition leaderPartition;
 
     @Test
     void test_non_leader_run() {
-        leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
+        leaderScheduler = new LeaderScheduler(coordinator, mongoDBSourceConfig, Duration.ofMillis(100));
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.empty());
-
         final ExecutorService executorService = Executors.newSingleThreadExecutor();
         executorService.submit(() -> leaderScheduler.run());
 
         await()
             .atMost(Duration.ofSeconds(2))
-            .untilAsserted(() -> verifyNoInteractions(collectionConfig));
+            .untilAsserted(() -> verifyNoInteractions(mongoDBSourceConfig));
         executorService.shutdownNow();
     }
 
     @Test
     void test_should_init() {
-
-        leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
+        given(mongoDBSourceConfig.getCollections()).willReturn(List.of(collectionConfig));
+        leaderScheduler = new LeaderScheduler(coordinator, mongoDBSourceConfig, Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
         given(collectionConfig.isExport()).willReturn(true);
@@ -89,8 +91,8 @@ void test_should_init() {
 
     @Test
     void test_should_init_export() {
-
-        leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
+        given(mongoDBSourceConfig.getCollections()).willReturn(List.of(collectionConfig));
+        leaderScheduler = new LeaderScheduler(coordinator, mongoDBSourceConfig, Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
         given(collectionConfig.isExport()).willReturn(true);
@@ -121,8 +123,8 @@ void test_should_init_export() {
 
     @Test
     void test_should_init_stream() {
-
-        leaderScheduler = new LeaderScheduler(coordinator, List.of(collectionConfig), Duration.ofMillis(100));
+        given(mongoDBSourceConfig.getCollections()).willReturn(List.of(collectionConfig));
+        leaderScheduler = new LeaderScheduler(coordinator, mongoDBSourceConfig, Duration.ofMillis(100));
         leaderPartition = new LeaderPartition();
         given(coordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).willReturn(Optional.of(leaderPartition));
         given(collectionConfig.isStream()).willReturn(true);