diff --git a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/AbstractSinkTest.java b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/AbstractSinkTest.java
index 474d1880c4..e4b19cf7ca 100644
--- a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/AbstractSinkTest.java
+++ b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/AbstractSinkTest.java
@@ -14,6 +14,7 @@
 import org.opensearch.dataprepper.model.configuration.PluginSetting;
 import org.opensearch.dataprepper.model.record.Record;
 
+import java.time.Duration;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -21,6 +22,8 @@
 import java.util.StringJoiner;
 import java.util.UUID;
 
+import static org.awaitility.Awaitility.await;
+
 public class AbstractSinkTest {
     @Test
     public void testMetrics() {
@@ -51,8 +54,8 @@ public void testMetrics() {
         Assert.assertEquals(1.0, MetricsTestUtil.getMeasurementFromList(elapsedTimeMeasurements, Statistic.COUNT).getValue(), 0);
         Assert.assertTrue(MetricsTestUtil.isBetween(
                 MetricsTestUtil.getMeasurementFromList(elapsedTimeMeasurements, Statistic.TOTAL_TIME).getValue(),
-                0.5,
-                0.6));
+                0.2,
+                0.3));
         Assert.assertEquals(abstractSink.getRetryThreadState(), null);
         abstractSink.shutdown();
     }
@@ -71,14 +74,8 @@ public void testSinkNotReady() {
         // Do another intialize to make sure the sink is still not ready
         abstractSink.initialize();
         Assert.assertEquals(abstractSink.isReady(), false);
-        while (!abstractSink.isReady()) {
-            try {
-                Thread.sleep(1000);
-            } catch (Exception e) {}
-        }
-        try {
-            Thread.sleep(2000);
-        } catch (Exception e) {}
+        await().atMost(Duration.ofSeconds(5))
+                .until(abstractSink::isReady);
         Assert.assertEquals(abstractSink.getRetryThreadState(), Thread.State.TERMINATED);
         abstractSink.shutdown();
     }
@@ -92,7 +89,7 @@ public AbstractSinkImpl(PluginSetting pluginSetting) {
        @Override
        public void doOutput(Collection> records) {
            try {
-                Thread.sleep(500);
+                Thread.sleep(200);
            } catch (InterruptedException e) {
            }
@@ -126,7 +123,7 @@ public AbstractSinkNotReadyImpl(PluginSetting pluginSetting) {
        @Override
        public void doOutput(Collection> records) {
            try {
-                Thread.sleep(500);
+                Thread.sleep(100);
            } catch (InterruptedException e) {
            }
diff --git a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/SinkThreadTest.java b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/SinkThreadTest.java
index 6f39084a25..1660484a72 100644
--- a/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/SinkThreadTest.java
+++ b/data-prepper-api/src/test/java/org/opensearch/dataprepper/model/sink/SinkThreadTest.java
@@ -10,7 +10,6 @@
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -19,14 +18,14 @@
 @ExtendWith(MockitoExtension.class)
 public class SinkThreadTest {
     @Mock
-    AbstractSink sink;
+    private AbstractSink sink;
 
-    SinkThread sinkThread;
+    private SinkThread sinkThread;
 
     @Test
     public void testSinkThread() {
         when(sink.isReady()).thenReturn(true);
-        sinkThread = new SinkThread(sink, 5, 1000);
+        sinkThread = new SinkThread(sink, 5, 100);
         sinkThread.run();
         verify(sink, times(1)).isReady();
     }
@@ -34,25 +33,18 @@ public void testSinkThread() {
     @Test
     public void testSinkThread2() {
         when(sink.isReady()).thenReturn(false);
-        sinkThread = new SinkThread(sink, 5, 1000);
+        sinkThread = new SinkThread(sink, 5, 100);
         sinkThread.run();
         verify(sink, times(6)).isReady();
-        try {
-            doAnswer((i) -> {
-                return null;
-            }).when(sink).doInitialize();
-            verify(sink, times(5)).doInitialize();
-        } catch (Exception e){}
+        verify(sink, times(5)).doInitialize();
         when(sink.isReady()).thenReturn(false).thenReturn(true);
         sinkThread.run();
         verify(sink, times(8)).isReady();
         when(sink.isReady()).thenReturn(false).thenReturn(true);
-        try {
-            lenient().doAnswer((i) -> {
-                throw new InterruptedException("Fake interrupt");
-            }).when(sink).doInitialize();
-            sinkThread.run();
-            verify(sink, times(7)).doInitialize();
-        } catch (Exception e){}
+        lenient().doAnswer((i) -> {
+            throw new InterruptedException("Fake interrupt");
+        }).when(sink).doInitialize();
+        sinkThread.run();
+        verify(sink, times(7)).doInitialize();
     }
 }
diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/AcknowledgementSetMonitorTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/AcknowledgementSetMonitorTests.java
index 3022d8687b..8c9d065704 100644
--- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/AcknowledgementSetMonitorTests.java
+++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/AcknowledgementSetMonitorTests.java
@@ -20,7 +20,7 @@
 @ExtendWith(MockitoExtension.class)
 public class AcknowledgementSetMonitorTests {
-    private static final int DEFAULT_WAIT_TIME_MS = 2000;
+    private static final int DEFAULT_WAIT_TIME_MS = 500;
     @Mock
     DefaultAcknowledgementSet acknowledgementSet1;
     @Mock
diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java
index 9b015aea72..486617e9a0 100644
--- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java
+++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/acknowledgements/DefaultAcknowledgementSetManagerTests.java
@@ -14,6 +14,8 @@
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
+
+import static org.awaitility.Awaitility.await;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.doAnswer;
@@ -27,7 +29,7 @@
 @ExtendWith(MockitoExtension.class)
 class DefaultAcknowledgementSetManagerTests {
-    private static final Duration TEST_TIMEOUT_MS = Duration.ofMillis(1000);
+    private static final Duration TEST_TIMEOUT = Duration.ofMillis(400);
     DefaultAcknowledgementSetManager acknowledgementSetManager;
     private ExecutorService callbackExecutor;
@@ -61,21 +63,25 @@ void setup() {
         lenient().when(event2.getEventHandle()).thenReturn(eventHandle2);
         acknowledgementSetManager = createObjectUnderTest();
-        AcknowledgementSet acknowledgementSet1 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS);
+        AcknowledgementSet acknowledgementSet1 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT);
         acknowledgementSet1.add(event1);
         acknowledgementSet1.add(event2);
         acknowledgementSet1.complete();
     }
 
     DefaultAcknowledgementSetManager createObjectUnderTest() {
-        return new DefaultAcknowledgementSetManager(callbackExecutor, Duration.ofMillis(TEST_TIMEOUT_MS.toMillis() * 2));
+        return new
DefaultAcknowledgementSetManager(callbackExecutor, Duration.ofMillis(TEST_TIMEOUT.toMillis() * 2)); } @Test - void testBasic() throws InterruptedException { + void testBasic() { acknowledgementSetManager.releaseEventReference(eventHandle2, true); acknowledgementSetManager.releaseEventReference(eventHandle1, true); - Thread.sleep(TEST_TIMEOUT_MS.toMillis() * 5); + await().atMost(TEST_TIMEOUT.multipliedBy(5)) + .untilAsserted(() -> { + assertThat(acknowledgementSetManager.getAcknowledgementSetMonitor().getSize(), equalTo(0)); + assertThat(result, equalTo(true)); + }); assertThat(acknowledgementSetManager.getAcknowledgementSetMonitor().getSize(), equalTo(0)); assertThat(result, equalTo(true)); } @@ -83,13 +89,13 @@ void testBasic() throws InterruptedException { @Test void testExpirations() throws InterruptedException { acknowledgementSetManager.releaseEventReference(eventHandle2, true); - Thread.sleep(TEST_TIMEOUT_MS.toMillis() * 5); + Thread.sleep(TEST_TIMEOUT.multipliedBy(5).toMillis()); assertThat(acknowledgementSetManager.getAcknowledgementSetMonitor().getSize(), equalTo(0)); assertThat(result, equalTo(null)); } @Test - void testMultipleAcknowledgementSets() throws InterruptedException { + void testMultipleAcknowledgementSets() { event3 = mock(JacksonEvent.class); doAnswer((i) -> { eventHandle3 = i.getArgument(0); @@ -97,13 +103,17 @@ void testMultipleAcknowledgementSets() throws InterruptedException { }).when(event3).setEventHandle(any()); lenient().when(event3.getEventHandle()).thenReturn(eventHandle3); - AcknowledgementSet acknowledgementSet2 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT_MS); + AcknowledgementSet acknowledgementSet2 = acknowledgementSetManager.create((flag) -> { result = flag; }, TEST_TIMEOUT); acknowledgementSet2.add(event3); acknowledgementSet2.complete(); acknowledgementSetManager.releaseEventReference(eventHandle2, true); acknowledgementSetManager.releaseEventReference(eventHandle3, true); - Thread.sleep(TEST_TIMEOUT_MS.toMillis() * 5); + await().atMost(TEST_TIMEOUT.multipliedBy(5)) + .untilAsserted(() -> { + assertThat(acknowledgementSetManager.getAcknowledgementSetMonitor().getSize(), equalTo(0)); + assertThat(result, equalTo(true)); + }); assertThat(acknowledgementSetManager.getAcknowledgementSetMonitor().getSize(), equalTo(0)); assertThat(result, equalTo(true)); } diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarderReceiveBufferTest.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarderReceiveBufferTest.java index 372272e416..15555776d6 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarderReceiveBufferTest.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarderReceiveBufferTest.java @@ -32,7 +32,7 @@ class PeerForwarderReceiveBufferTest { private static final int TEST_BATCH_SIZE = 3; private static final int TEST_BUFFER_SIZE = 13; private static final int TEST_WRITE_TIMEOUT = 100; - private static final int TEST_BATCH_READ_TIMEOUT = 5_000; + private static final int TEST_BATCH_READ_TIMEOUT = 200; private static final ExecutorService EXECUTOR = Executors.newSingleThreadExecutor(); private static final String PIPELINE_NAME = UUID.randomUUID().toString(); private static final String PLUGIN_ID = UUID.randomUUID().toString(); @@ -204,7 +204,7 @@ void testNonZeroBatchDelayReturnsAllRecords() throws Exception { final Collection> testRecords2 = 
generateBatchRecords(1); EXECUTOR.submit(() -> { try { - Thread.sleep(1000); + Thread.sleep(TEST_BATCH_READ_TIMEOUT / 5); peerForwarderReceiveBuffer.writeAll(testRecords2, TEST_WRITE_TIMEOUT); } catch (final Exception e) { throw new RuntimeException(e); @@ -229,7 +229,7 @@ void testZeroBatchDelayReturnsAvailableRecords() throws Exception { final Collection> testRecords2 = generateBatchRecords(1); EXECUTOR.submit(() -> { try { - Thread.sleep(1000); + Thread.sleep(TEST_BATCH_READ_TIMEOUT / 5); peerForwarderReceiveBuffer.writeAll(testRecords2, TEST_WRITE_TIMEOUT); } catch (final Exception e) { throw new RuntimeException(e); diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarder_ClientServerIT.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarder_ClientServerIT.java index e132acc813..ea0c018ee4 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarder_ClientServerIT.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/PeerForwarder_ClientServerIT.java @@ -151,7 +151,7 @@ private Collection> getServerSideRecords(final PeerForwarderProvid assertThat(pluginBufferMap, notNullValue()); final PeerForwarderReceiveBuffer> receiveBuffer = pluginBufferMap.get(pluginId); - final Map.Entry>, CheckpointState> bufferEntry = receiveBuffer.read(1000); + final Map.Entry>, CheckpointState> bufferEntry = receiveBuffer.read(400); return bufferEntry.getKey(); } diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/RemotePeerForwarderTest.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/RemotePeerForwarderTest.java index 12b13d24dc..64c16ff82b 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/RemotePeerForwarderTest.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/peerforwarder/RemotePeerForwarderTest.java @@ -66,12 +66,12 @@ class RemotePeerForwarderTest { private static final int TEST_BUFFER_CAPACITY = 20; private static final int TEST_BATCH_SIZE = 20; - private static final int TEST_BATCH_DELAY = 3_000; - private static final int TEST_LOCAL_WRITE_TIMEOUT = 500; - private static final int TEST_TIMEOUT_IN_MILLIS = 500; + private static final int TEST_BATCH_DELAY = 800; + private static final int TEST_LOCAL_WRITE_TIMEOUT = 400; + private static final int TEST_TIMEOUT_IN_MILLIS = 400; private static final int FORWARDING_BATCH_SIZE = 5; private static final int FORWARDING_BATCH_QUEUE_DEPTH = 1; - private static final Duration FORWARDING_BATCH_TIMEOUT = Duration.of(3, ChronoUnit.SECONDS); + private static final Duration FORWARDING_BATCH_TIMEOUT = Duration.of(800, ChronoUnit.MILLIS); private static final int PIPELINE_WORKER_THREADS = 3; private static final String PIPELINE_NAME = UUID.randomUUID().toString(); private static final String PLUGIN_ID = UUID.randomUUID().toString(); diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java index 7828f88a16..57b14171ec 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java @@ -46,6 +46,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.awaitility.Awaitility.await; import static 
org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.CoreMatchers.notNullValue; @@ -67,7 +68,7 @@ import static org.mockito.Mockito.when; class PipelineTests { - private static final int TEST_READ_BATCH_TIMEOUT = 3000; + private static final int TEST_READ_BATCH_TIMEOUT = 500; private static final int TEST_PROCESSOR_THREADS = 1; private static final String TEST_PIPELINE_NAME = "test-pipeline"; @@ -168,9 +169,9 @@ void testPipelineStateWithProcessor() { @Test void testPipelineDelayedReady() throws InterruptedException { - final int delayTimeSeconds = 10; + final Duration delayTime = Duration.ofMillis(2000); final Source> testSource = new TestSource(); - final TestSink testSink = new TestSink(delayTimeSeconds); + final TestSink testSink = new TestSink(delayTime); final DataFlowComponent sinkDataFlowComponent = mock(DataFlowComponent.class); final TestProcessor testProcessor = new TestProcessor(new PluginSetting("test_processor", new HashMap<>())); when(sinkDataFlowComponent.getComponent()).thenReturn(testSink); @@ -182,11 +183,11 @@ void testPipelineDelayedReady() throws InterruptedException { Instant startTime = Instant.now(); testPipeline.execute(); assertFalse(testPipeline.isReady()); - for (int i = 0; i < delayTimeSeconds + 2; i++) { - Thread.sleep(1000); - } + await().atMost(Duration.ofSeconds(2).plus(delayTime)) + .pollInterval(Duration.ofMillis(200)) + .until(testPipeline::isReady); assertTrue(testPipeline.isReady()); - assertThat(Duration.between(startTime, Instant.now()), greaterThanOrEqualTo(Duration.ofSeconds(delayTimeSeconds))); + assertThat(Duration.between(startTime, Instant.now()), greaterThanOrEqualTo(delayTime)); assertThat("Pipeline isStopRequested is expected to be false", testPipeline.isStopRequested(), is(false)); testPipeline.shutdown(); assertThat("Pipeline isStopRequested is expected to be true", testPipeline.isStopRequested(), is(true)); @@ -196,9 +197,9 @@ void testPipelineDelayedReady() throws InterruptedException { @Test void testPipelineDelayedReadyShutdownBeforeReady() throws InterruptedException { - final int delayTimeSeconds = 10; + final Duration delayTime = Duration.ofSeconds(2); final Source> testSource = new TestSource(); - final TestSink testSink = new TestSink(delayTimeSeconds); + final TestSink testSink = new TestSink(delayTime); final DataFlowComponent sinkDataFlowComponent = mock(DataFlowComponent.class); final TestProcessor testProcessor = new TestProcessor(new PluginSetting("test_processor", new HashMap<>())); when(sinkDataFlowComponent.getComponent()).thenReturn(testSink); diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/plugins/TestSink.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/plugins/TestSink.java index 66854f645f..88facea689 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/plugins/TestSink.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/plugins/TestSink.java @@ -9,6 +9,7 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.sink.Sink; +import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -29,11 +30,11 @@ public TestSink() { this.ready = true; } - public TestSink(int readyAfterSecs) { + public TestSink(Duration readyAfter) { this.ready = false; this.failSinkForTest = false; this.collectedRecords = new ArrayList<>(); - this.readyTime = Instant.now().plusSeconds(readyAfterSecs); + 
this.readyTime = Instant.now().plus(readyAfter); } public TestSink(boolean failSinkForTest) { diff --git a/data-prepper-plugins/aggregate-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorIT.java b/data-prepper-plugins/aggregate-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorIT.java index e6ee923bbf..9e68187255 100644 --- a/data-prepper-plugins/aggregate-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorIT.java +++ b/data-prepper-plugins/aggregate-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorIT.java @@ -137,7 +137,7 @@ private AggregateProcessor createObjectUnderTest() { return new AggregateProcessor(aggregateProcessorConfig, pluginMetrics, pluginFactory, expressionEvaluator); } - @RepeatedTest(value = 10) + @RepeatedTest(value = 2) void aggregateWithNoConcludingGroupsReturnsExpectedResult() throws InterruptedException { aggregateAction = new RemoveDuplicatesAggregateAction(); when(pluginFactory.loadPlugin(eq(AggregateAction.class), any(PluginSetting.class))) @@ -260,7 +260,7 @@ void aggregateWithPutAllActionAndCondition() throws InterruptedException { } @ParameterizedTest - @ValueSource(doubles = {5.0, 15.0, 33.0, 55.0, 70.0, 85.0, 92.0, 99.0}) + @ValueSource(doubles = {5.0, 15.0, 55.0, 92.0, 99.0}) void aggregateWithPercentSamplerAction(double testPercent) throws InterruptedException, NoSuchFieldException, IllegalAccessException { PercentSamplerAggregateActionConfig percentSamplerAggregateActionConfig = new PercentSamplerAggregateActionConfig(); setField(PercentSamplerAggregateActionConfig.class, percentSamplerAggregateActionConfig, "percent", testPercent); diff --git a/data-prepper-plugins/csv-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessor.java b/data-prepper-plugins/csv-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessor.java index 04f353bc74..55bd7301cb 100644 --- a/data-prepper-plugins/csv-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessor.java +++ b/data-prepper-plugins/csv-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessor.java @@ -59,6 +59,11 @@ public Collection> doExecute(final Collection> recor final Event event = record.getData(); final String message = event.get(config.getSource(), String.class); + + if (Objects.isNull(message)) { + continue; + } + final boolean userDidSpecifyHeaderEventKey = Objects.nonNull(config.getColumnNamesSourceKey()); final boolean thisEventHasHeaderSource = userDidSpecifyHeaderEventKey && event.containsKey(config.getColumnNamesSourceKey()); diff --git a/data-prepper-plugins/csv-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessorTest.java b/data-prepper-plugins/csv-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessorTest.java index fa5761b937..be6da7e80f 100644 --- a/data-prepper-plugins/csv-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessorTest.java +++ b/data-prepper-plugins/csv-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/csv/CsvProcessorTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; @@ -59,6 +60,18 @@ private 
CsvProcessor createObjectUnderTest() { return new CsvProcessor(pluginMetrics, processorConfig); } + @Test + void do_nothing_when_source_is_null_value_or_does_not_exist_in_the_Event() { + final Record eventUnderTest = createMessageEvent(""); + when(processorConfig.getSource()).thenReturn(UUID.randomUUID().toString()); + + + final List> editedEvents = (List>) csvProcessor.doExecute(Collections.singletonList(eventUnderTest)); + final Event parsedEvent = getSingleEvent(editedEvents); + + assertThat(parsedEvent, equalTo(eventUnderTest.getData())); + } + @Test void test_when_messageIsEmpty_then_notParsed() { final Record eventUnderTest = createMessageEvent(""); diff --git a/data-prepper-plugins/kafka-plugins/README.md b/data-prepper-plugins/kafka-plugins/README.md index 30fc97f876..19153dd4be 100644 --- a/data-prepper-plugins/kafka-plugins/README.md +++ b/data-prepper-plugins/kafka-plugins/README.md @@ -166,12 +166,18 @@ Command to start kafka server bin/kafka-server-start.sh config/server.properties ``` -3. Command to run integration tests +3. Command to run multi auth type integration tests ``` -./gradlew data-prepper-plugins:kafka-plugins:integrationTest -Dtests.kafka.bootstrap_servers="localhost:9092" -Dtests.kafka.trust_store_location="/home/krishkdk/kafka/kafka-3.4.1-src/sec/client.truststore.jks" -Dtests.kafka.trust_store_password="kafkaks" -Dtests.kafka.saslssl_bootstrap_servers="localhost:9093" -Dtests.kafka.ssl_bootstrap_servers="localhost:9094" -Dtests.kafka.saslplain_bootstrap_servers="localhost:9095" -Dtests.kafka.username="admin" -Dtests.kafka.password="admin1" --tests "*KafkaSourceMultipleAuthTypeIT*" +./gradlew data-prepper-plugins:kafka-plugins:integrationTest -Dtests.kafka.bootstrap_servers= -Dtests.kafka.trust_store_location= -Dtests.kafka.trust_store_password= -Dtests.kafka.saslssl_bootstrap_servers= -Dtests.kafka.ssl_bootstrap_servers= -Dtests.kafka.saslplain_bootstrap_servers= -Dtests.kafka.username= -Dtests.kafka.password= --tests "*KafkaSourceMultipleAuthTypeIT*" ``` +4. 
Command to run msk glue integration tests + +``` +./gradlew data-prepper-plugins:kafka-plugins:integrationTest -Dtests.kafka.bootstrap_servers= -Dtests.kafka.glue_registry_name= -Dtests.kafka.glue_avro_schema_name= -Dtests.kafka.glue_json_schema_name= -Dtests.msk.region= -Dtests.msk.arn= --tests "*TestAvroRecordConsumer*" + +``` ## Developer Guide diff --git a/data-prepper-plugins/kafka-plugins/build.gradle b/data-prepper-plugins/kafka-plugins/build.gradle index dbdecda90c..899061b611 100644 --- a/data-prepper-plugins/kafka-plugins/build.gradle +++ b/data-prepper-plugins/kafka-plugins/build.gradle @@ -24,6 +24,8 @@ dependencies { implementation 'software.amazon.awssdk:sts:2.20.103' implementation 'software.amazon.awssdk:auth:2.20.103' implementation 'software.amazon.awssdk:kafka:2.20.103' + implementation 'software.amazon.glue:schema-registry-serde:1.1.15' + implementation 'com.amazonaws:aws-java-sdk-glue:1.12.506' implementation 'io.confluent:kafka-json-schema-serializer:7.4.0' testImplementation 'org.mockito:mockito-inline:4.1.0' testImplementation 'org.yaml:snakeyaml:2.0' @@ -40,6 +42,7 @@ dependencies { testImplementation testLibs.junit.vintage testImplementation 'org.apache.kafka:kafka-clients:3.4.0:test' testImplementation 'org.apache.kafka:connect-json:3.4.0' + testImplementation('com.kjetland:mbknor-jackson-jsonschema_2.13:1.0.39') } test { @@ -74,6 +77,11 @@ task integrationTest(type: Test) { systemProperty 'tests.kafka.saslplain_bootstrap_servers', System.getProperty('tests.kafka.saslplain_bootstrap_servers') systemProperty 'tests.kafka.username', System.getProperty('tests.kafka.username') systemProperty 'tests.kafka.password', System.getProperty('tests.kafka.password') + systemProperty 'tests.kafka.glue_registry_name', System.getProperty('tests.kafka.glue_registry_name') + systemProperty 'tests.kafka.glue_json_schema_name', System.getProperty('tests.kafka.glue_json_schema_name') + systemProperty 'tests.kafka.glue_avro_schema_name', System.getProperty('tests.kafka.glue_avro_schema_name') + systemProperty 'tests.msk.region', System.getProperty('tests.msk.region') + systemProperty 'tests.msk.arn', System.getProperty('tests.msk.arn') filter { includeTestsMatching '*IT' diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSourceMultipleAuthTypeIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSourceMultipleAuthTypeIT.java index 2088f43335..0376662452 100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSourceMultipleAuthTypeIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSourceMultipleAuthTypeIT.java @@ -25,6 +25,7 @@ import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; import static org.mockito.Mockito.when; import org.mockito.Mock; @@ -74,9 +75,6 @@ public class KafkaSourceMultipleAuthTypeIT { @Mock private AuthConfig.SaslAuthConfig saslAuthConfig; - @Mock - private AuthConfig.SslAuthConfig sslAuthConfig; - @Mock private PlainTextAuthConfig plainTextAuthConfig; @@ -128,7 +126,10 @@ public void setup() { 
when(plainTextTopic.getName()).thenReturn(testTopic); when(plainTextTopic.getGroupId()).thenReturn(testGroup); when(plainTextTopic.getWorkers()).thenReturn(1); + when(plainTextTopic.getSessionTimeOut()).thenReturn(15000); + when(plainTextTopic.getHeartBeatInterval()).thenReturn(Duration.ofSeconds(3)); when(plainTextTopic.getAutoCommit()).thenReturn(false); + when(plainTextTopic.getSerdeFormat()).thenReturn(MessageFormat.PLAINTEXT); when(plainTextTopic.getAutoOffsetReset()).thenReturn("earliest"); when(plainTextTopic.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1)); bootstrapServers = System.getProperty("tests.kafka.bootstrap_servers"); @@ -143,7 +144,7 @@ public void setup() { @Test public void TestPlainTextWithNoAuthKafkaNoEncryptionWithNoAuthSchemaRegistry() throws Exception { final int numRecords = 1; - when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.PLAINTEXT); + when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.NONE); when(plainTextTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); when(sourceConfig.getTopics()).thenReturn(List.of(plainTextTopic)); when(sourceConfig.getAuthConfig()).thenReturn(null); @@ -193,7 +194,7 @@ public void TestPlainTextWithAuthKafkaNoEncryptionWithNoAuthSchemaRegistry() thr authConfig = mock(AuthConfig.class); saslAuthConfig = mock(AuthConfig.SaslAuthConfig.class); plainTextAuthConfig = mock(PlainTextAuthConfig.class); - when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.PLAINTEXT); + when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.NONE); when(plainTextTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); when(sourceConfig.getTopics()).thenReturn(List.of(plainTextTopic)); plainTextAuthConfig = mock(PlainTextAuthConfig.class); @@ -204,7 +205,6 @@ public void TestPlainTextWithAuthKafkaNoEncryptionWithNoAuthSchemaRegistry() thr when(authConfig.getInsecure()).thenReturn(true); when(saslAuthConfig.getPlainTextAuthConfig()).thenReturn(plainTextAuthConfig); when(sourceConfig.getBootStrapServers()).thenReturn(saslplainBootstrapServers); - when(authConfig.getSslAuthConfig()).thenReturn(null); kafkaSource = createObjectUnderTest(); Properties props = new Properties(); @@ -250,14 +250,9 @@ public void TestPlainTextWithNoAuthKafkaEncryptionWithNoAuthSchemaRegistry() thr final int numRecords = 1; authConfig = mock(AuthConfig.class); saslAuthConfig = mock(AuthConfig.SaslAuthConfig.class); - sslAuthConfig = mock(AuthConfig.SslAuthConfig.class); - plainTextAuthConfig = mock(PlainTextAuthConfig.class); - when(plainTextAuthConfig.getUsername()).thenReturn(kafkaUsername); - when(plainTextAuthConfig.getPassword()).thenReturn(kafkaPassword); when(sourceConfig.getAuthConfig()).thenReturn(authConfig); when(authConfig.getSaslAuthConfig()).thenReturn(null); when(authConfig.getInsecure()).thenReturn(true); - when(authConfig.getSslAuthConfig()).thenReturn(sslAuthConfig); when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.SSL); when(plainTextTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); when(sourceConfig.getBootStrapServers()).thenReturn(sslBootstrapServers); @@ -307,14 +302,12 @@ public void TestPlainTextWithAuthKafkaEncryptionWithNoAuthSchemaRegistry() throw final int numRecords = 1; authConfig = mock(AuthConfig.class); saslAuthConfig = mock(AuthConfig.SaslAuthConfig.class); - sslAuthConfig = mock(AuthConfig.SslAuthConfig.class); plainTextAuthConfig = mock(PlainTextAuthConfig.class); when(plainTextAuthConfig.getUsername()).thenReturn(kafkaUsername); 
when(plainTextAuthConfig.getPassword()).thenReturn(kafkaPassword); when(sourceConfig.getAuthConfig()).thenReturn(authConfig); when(authConfig.getInsecure()).thenReturn(true); when(authConfig.getSaslAuthConfig()).thenReturn(saslAuthConfig); - when(authConfig.getSslAuthConfig()).thenReturn(sslAuthConfig); when(saslAuthConfig.getPlainTextAuthConfig()).thenReturn(plainTextAuthConfig); when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.SSL); when(plainTextTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/MskGlueRegistryMultiTypeIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/MskGlueRegistryMultiTypeIT.java new file mode 100644 index 0000000000..f354a1fb45 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/source/MskGlueRegistryMultiTypeIT.java @@ -0,0 +1,428 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.kafka.source; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.common.serialization.StringSerializer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; +import org.mockito.Mock; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.AwsIamAuthConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.AuthConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.AwsConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaSourceConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaRegistryType; +import org.opensearch.dataprepper.plugins.kafka.configuration.EncryptionType; +import org.opensearch.dataprepper.plugins.kafka.configuration.MskBrokerConnectionType; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import com.amazonaws.services.schemaregistry.serializers.GlueSchemaRegistryKafkaSerializer; +import com.amazonaws.services.schemaregistry.utils.AWSSchemaRegistryConstants; +import com.amazonaws.services.schemaregistry.serializers.json.JsonDataWithSchema; +import org.apache.avro.Schema; + +import static org.mockito.Mockito.when; +import org.mockito.Mock; +import static org.mockito.Mockito.mock; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.equalTo; +import org.apache.commons.lang3.RandomStringUtils; + +import io.micrometer.core.instrument.Counter; 
+import java.util.List; +import java.util.Map; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; +import java.io.File; +import java.io.IOException; + +import java.time.Duration; + +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericData; +import org.apache.kafka.common.errors.SerializationException; + +public class MskGlueRegistryMultiTypeIT { + private static final String TEST_USER = "user"; + private static final String TEST_MESSAGE = "test message "; + private static final Long TEST_TIMESTAMP = 1366154481L; + private static final Integer TEST_TIMESTAMP_INT = 12345; + @Mock + private KafkaSourceConfig sourceConfig; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private PipelineDescription pipelineDescription; + + @Mock + private Buffer> buffer; + + private List topicList; + + @Mock + private TopicConfig plainTextTopic; + + @Mock + private AuthConfig authConfig; + + private AwsConfig awsConfig; + + @Mock + private AuthConfig.SaslAuthConfig saslAuthConfig; + + @Mock + private AwsConfig.AwsMskConfig awsMskConfig; + + private KafkaSource kafkaSource; + private TopicConfig jsonTopic; + private TopicConfig avroTopic; + + private Counter counter; + + private List receivedRecords; + + private String bootstrapServers; + + private String testRegistryName; + + private String testAvroSchemaName; + + private String testJsonSchemaName; + + private String testMskArn; + + private String testMskRegion; + + @Mock + SchemaConfig schemaConfig; + + + public KafkaSource createObjectUnderTest() { + return new KafkaSource(sourceConfig, pluginMetrics, acknowledgementSetManager, pipelineDescription); + } + + @BeforeEach + public void setup() { + sourceConfig = mock(KafkaSourceConfig.class); + pluginMetrics = mock(PluginMetrics.class); + counter = mock(Counter.class); + buffer = mock(Buffer.class); + awsConfig = mock(AwsConfig.class); + awsMskConfig = mock(AwsConfig.AwsMskConfig.class); + authConfig = mock(AuthConfig.class); + saslAuthConfig = mock(AuthConfig.SaslAuthConfig.class); + schemaConfig = mock(SchemaConfig.class); + receivedRecords = new ArrayList<>(); + acknowledgementSetManager = mock(AcknowledgementSetManager.class); + pipelineDescription = mock(PipelineDescription.class); + when(sourceConfig.getAcknowledgementsEnabled()).thenReturn(false); + when(sourceConfig.getAcknowledgementsTimeout()).thenReturn(KafkaSourceConfig.DEFAULT_ACKNOWLEDGEMENTS_TIMEOUT); + when(sourceConfig.getSchemaConfig()).thenReturn(schemaConfig); + when(schemaConfig.getType()).thenReturn(SchemaRegistryType.GLUE); + when(pluginMetrics.counter(anyString())).thenReturn(counter); + when(pipelineDescription.getPipelineName()).thenReturn("testPipeline"); + try { + doAnswer(args -> { + Collection> bufferedRecords = (Collection>)args.getArgument(0); + receivedRecords.addAll(bufferedRecords); + Record r = receivedRecords.get(0); + return null; + }).when(buffer).writeAll(any(Collection.class), any(Integer.class)); + } catch (Exception e){} + + final String testGroup = "TestGroup_"+RandomStringUtils.randomAlphabetic(6); + final String testTopic = "TestTopic_"+RandomStringUtils.randomAlphabetic(5); + avroTopic = mock(TopicConfig.class); + jsonTopic = mock(TopicConfig.class); + when(avroTopic.getName()).thenReturn(testTopic); + 
when(avroTopic.getGroupId()).thenReturn(testGroup); + when(avroTopic.getWorkers()).thenReturn(1); + when(avroTopic.getAutoCommit()).thenReturn(false); + when(avroTopic.getAutoOffsetReset()).thenReturn("earliest"); + when(avroTopic.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1)); + when(avroTopic.getSessionTimeOut()).thenReturn(15000); + when(avroTopic.getHeartBeatInterval()).thenReturn(Duration.ofSeconds(3)); + when(jsonTopic.getName()).thenReturn(testTopic); + when(jsonTopic.getGroupId()).thenReturn(testGroup); + when(jsonTopic.getWorkers()).thenReturn(1); + when(jsonTopic.getAutoCommit()).thenReturn(false); + when(jsonTopic.getAutoOffsetReset()).thenReturn("earliest"); + when(jsonTopic.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1)); + when(jsonTopic.getSessionTimeOut()).thenReturn(15000); + when(jsonTopic.getHeartBeatInterval()).thenReturn(Duration.ofSeconds(3)); + bootstrapServers = System.getProperty("tests.kafka.bootstrap_servers"); + testRegistryName = System.getProperty("tests.kafka.glue_registry_name"); + testJsonSchemaName = System.getProperty("tests.kafka.glue_json_schema_name"); + testAvroSchemaName = System.getProperty("tests.kafka.glue_avro_schema_name"); + testMskArn = System.getProperty("tests.msk.arn"); + testMskRegion = System.getProperty("tests.msk.region"); + when(sourceConfig.getBootStrapServers()).thenReturn(bootstrapServers); + System.setProperty("software.amazon.awssdk.http.service.impl", "software.amazon.awssdk.http.urlconnection.UrlConnectionSdkHttpService"); + } + + @Test + public void TestJsonRecordConsumer() throws Exception { + final int numRecords = 1; + when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.SSL); + when(jsonTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); + when(sourceConfig.getTopics()).thenReturn(List.of(jsonTopic)); + when(sourceConfig.getAuthConfig()).thenReturn(authConfig); + when(authConfig.getSaslAuthConfig()).thenReturn(saslAuthConfig); + when(saslAuthConfig.getAwsIamAuthConfig()).thenReturn(AwsIamAuthConfig.DEFAULT); + when(sourceConfig.getAwsConfig()).thenReturn(awsConfig); + when(awsConfig.getRegion()).thenReturn(testMskRegion); + when(awsConfig.getAwsMskConfig()).thenReturn(awsMskConfig); + when(awsMskConfig.getArn()).thenReturn(testMskArn); + when(awsMskConfig.getBrokerConnectionType()).thenReturn(MskBrokerConnectionType.PUBLIC); + kafkaSource = createObjectUnderTest(); + + Properties props = new Properties(); + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + props.put("security.protocol","SASL_SSL"); + props.put("sasl.mechanism", "AWS_MSK_IAM"); + props.put("sasl.jaas.config", "software.amazon.msk.auth.iam.IAMLoginModule required;"); + props.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); + AtomicBoolean created = new AtomicBoolean(false); + final String topicName = jsonTopic.getName(); + try (AdminClient adminClient = AdminClient.create(props)) { + try { + adminClient.createTopics( + Collections.singleton(new NewTopic(topicName, 1, (short)1))) + .all().get(30, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + created.set(true); + } + while (created.get() != true) { + Thread.sleep(1000); + } + kafkaSource.start(buffer); + produceJsonRecords(bootstrapServers, topicName, numRecords); + int numRetries = 0; + while (numRetries++ < 10 && (receivedRecords.size() != numRecords)) { + Thread.sleep(1000); + } + assertThat(receivedRecords.size(), equalTo(numRecords)); + for (int i = 0; i 
< numRecords; i++) { + Record record = receivedRecords.get(i); + Event event = (Event)record.getData(); + Map val = event.get("message-"+i, Map.class); + assertThat(val.get("username"), equalTo(TEST_USER+i)); + assertThat(val.get("message"), equalTo(TEST_MESSAGE+i)); + assertThat(((Number)val.get("timestamp")).intValue(), equalTo(TEST_TIMESTAMP_INT+i)); + } + try (AdminClient adminClient = AdminClient.create(props)) { + try { + adminClient.deleteTopics(Collections.singleton(topicName)) + .all().get(30, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + created.set(false); + } + while (created.get() != false) { + Thread.sleep(1000); + } + + } + + @Test + public void TestAvroRecordConsumer() throws Exception { + final int numRecords = 1; + when(sourceConfig.getEncryptionType()).thenReturn(EncryptionType.SSL); + when(avroTopic.getConsumerMaxPollRecords()).thenReturn(numRecords); + when(sourceConfig.getTopics()).thenReturn(List.of(avroTopic)); + when(sourceConfig.getAuthConfig()).thenReturn(authConfig); + when(authConfig.getSaslAuthConfig()).thenReturn(saslAuthConfig); + when(saslAuthConfig.getAwsIamAuthConfig()).thenReturn(AwsIamAuthConfig.DEFAULT); + when(sourceConfig.getAwsConfig()).thenReturn(awsConfig); + when(awsConfig.getRegion()).thenReturn(testMskRegion); + when(awsConfig.getAwsMskConfig()).thenReturn(awsMskConfig); + when(awsMskConfig.getArn()).thenReturn(testMskArn); + when(awsMskConfig.getBrokerConnectionType()).thenReturn(MskBrokerConnectionType.PUBLIC); + kafkaSource = createObjectUnderTest(); + + Properties props = new Properties(); + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + props.put("security.protocol","SASL_SSL"); + props.put("sasl.mechanism", "AWS_MSK_IAM"); + props.put("sasl.jaas.config", "software.amazon.msk.auth.iam.IAMLoginModule required;"); + props.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); + AtomicBoolean created = new AtomicBoolean(false); + final String topicName = avroTopic.getName(); + try (AdminClient adminClient = AdminClient.create(props)) { + try { + adminClient.createTopics( + Collections.singleton(new NewTopic(topicName, 1, (short)1))) + .all().get(30, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + created.set(true); + } + while (created.get() != true) { + Thread.sleep(1000); + } + kafkaSource.start(buffer); + produceAvroRecords(bootstrapServers, topicName, numRecords); + int numRetries = 0; + while (numRetries++ < 10 && (receivedRecords.size() != numRecords)) { + Thread.sleep(1000); + } + assertThat(receivedRecords.size(), equalTo(numRecords)); + for (int i = 0; i < numRecords; i++) { + Record record = receivedRecords.get(i); + Event event = (Event)record.getData(); + Map val = event.get(TEST_USER+i, Map.class); + assertThat(val.get("username"), equalTo(TEST_USER+i)); + assertThat(val.get("message"), equalTo(TEST_MESSAGE+i)); + assertThat(((Number)val.get("timestamp")).longValue(), equalTo(TEST_TIMESTAMP+i)); + } + try (AdminClient adminClient = AdminClient.create(props)) { + try { + adminClient.deleteTopics(Collections.singleton(topicName)) + .all().get(30, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + created.set(false); + } + while (created.get() != false) { + Thread.sleep(1000); + } + + } + + public void produceJsonRecords(final String servers, final String topic, int numRecords) { + Properties properties = new Properties(); + 
properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, servers); + properties.put("security.protocol", "SASL_SSL"); + properties.put("sasl.mechanism", "AWS_MSK_IAM"); + properties.put("region", testMskRegion); + properties.put("sasl.jaas.config", "software.amazon.msk.auth.iam.IAMLoginModule required;"); + properties.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); + properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, GlueSchemaRegistryKafkaSerializer.class.getName()); + properties.put(AWSSchemaRegistryConstants.DATA_FORMAT, "json"); + properties.put(AWSSchemaRegistryConstants.AWS_REGION, awsConfig.getRegion()); + properties.put(AWSSchemaRegistryConstants.REGISTRY_NAME, testRegistryName); + properties.put(AWSSchemaRegistryConstants.SCHEMA_NAME, testJsonSchemaName); + properties.put(AWSSchemaRegistryConstants.SCHEMA_AUTO_REGISTRATION_SETTING, true); // If not passed, defaults to false + + String jsonSchema = "{\n"+ + "\"$schema\": \"http://json-schema.org/draft-07/schema#\",\n"+ + "\"type\": \"object\",\n"+ + "\"properties\": {\n"+ + "\"username\": {\n"+ + "\"type\": \"string\",\n"+ + "\"description\": \"user name.\"\n"+ + "},\n"+ + "\"message\": {\n"+ + "\"type\": \"string\",\n"+ + "\"description\": \"message.\"\n"+ + "},\n"+ + "\"timestamp\": {\n"+ + "\"type\": \"integer\",\n"+ + "\"description\": \"timestamp\"\n"+ + "}\n"+ + "}\n"+ + "}"; + try (KafkaProducer producer = new KafkaProducer(properties)) { + for (int i = 0; i < numRecords; i++) { + String jsonPayLoad = "{\n" + + " \"username\": \""+TEST_USER+i+"\",\n" + + " \"message\": \""+TEST_MESSAGE+i+"\",\n"+ + " \"timestamp\": "+(TEST_TIMESTAMP_INT+i)+"\n"+ + "}"; + JsonDataWithSchema testRecord = JsonDataWithSchema.builder(jsonSchema, jsonPayLoad).build(); + final ProducerRecord record; + String topicKey = new String("message-"+i); + record = new ProducerRecord(topic, topicKey, testRecord); + producer.send(record); + Thread.sleep(1000L); + } + producer.flush(); + } catch (final InterruptedException | SerializationException e) { + e.printStackTrace(); + } + + } + + + public void produceAvroRecords(String servers, String topic, int numRecords) { + Properties properties = new Properties(); + properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, servers); + properties.put("security.protocol", "SASL_SSL"); + properties.put("sasl.mechanism", "AWS_MSK_IAM"); + properties.put("region", testMskRegion); + properties.put("sasl.jaas.config", "software.amazon.msk.auth.iam.IAMLoginModule required;"); + properties.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); + properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, GlueSchemaRegistryKafkaSerializer.class.getName()); + properties.put(AWSSchemaRegistryConstants.DATA_FORMAT, "avro"); + properties.put(AWSSchemaRegistryConstants.AWS_REGION, awsConfig.getRegion()); + properties.put(AWSSchemaRegistryConstants.REGISTRY_NAME, testRegistryName); + properties.put(AWSSchemaRegistryConstants.SCHEMA_NAME, testAvroSchemaName); + + Schema testSchema = null; + Schema.Parser parser = new Schema.Parser(); + try { + testSchema = parser.parse(new File("src/integrationTest/resources/test.avsc")); + } catch (IOException e) { + e.printStackTrace(); + } + + List testRecords = new ArrayList<>(); + for (int i = 0; i < 
numRecords; i++) { + GenericRecord testRecord = new GenericData.Record(testSchema); + testRecord.put("username", TEST_USER+i); + testRecord.put("message", TEST_MESSAGE+i); + testRecord.put("timestamp", TEST_TIMESTAMP+i); + testRecords.add(testRecord); + } + + try (KafkaProducer producer = new KafkaProducer(properties)) { + for (int i = 0; i < testRecords.size(); i++) { + GenericRecord r = testRecords.get(i); + + final ProducerRecord record; + record = new ProducerRecord(topic, r.get("username").toString(), r); + + producer.send(record); + Thread.sleep(1000L); + } + producer.flush(); + + } catch (final InterruptedException | SerializationException e) { + e.printStackTrace(); + } + } + +} diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/resources/test.avsc b/data-prepper-plugins/kafka-plugins/src/integrationTest/resources/test.avsc new file mode 100644 index 0000000000..7b031640f8 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/resources/test.avsc @@ -0,0 +1,19 @@ +{ + "type" : "record", + "name" : "test_schema", + "namespace" : "com.amazon.avro", + "fields" : [ { + "name" : "username", + "type" : "string", + "doc" : "Name of the user" + }, { + "name" : "message", + "type" : "string", + "doc" : "The content message" + }, { + "name" : "timestamp", + "type" : "long", + "doc" : "Unix epoch time in seconds" + } ], + "doc:" : "A basic schema for storing user messages" +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AuthConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AuthConfig.java index ae037ba13e..883e33f630 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AuthConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AuthConfig.java @@ -76,9 +76,13 @@ public Boolean getInsecure() { return insecure; } + /* + * Currently SSL config is not supported. 
Commenting this for future use + * @AssertTrue(message = "Only one of SSL or SASL auth config must be specified") public boolean hasSaslOrSslConfig() { return Stream.of(sslAuthConfig, saslAuthConfig).filter(n -> n!=null).count() == 1; } + */ } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/EncryptionType.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/EncryptionType.java index 8142d88bc6..ea644737c7 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/EncryptionType.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/EncryptionType.java @@ -12,8 +12,8 @@ import java.util.stream.Collectors; public enum EncryptionType { - SSL("ssl"), - PLAINTEXT("plaintext"); + NONE("none"), + SSL("ssl"); private static final Map OPTIONS_MAP = Arrays.stream(EncryptionType.values()) .collect(Collectors.toMap( diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfig.java index 136472eff1..b5e120b7d2 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfig.java @@ -9,7 +9,6 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; -import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; import java.util.List; import java.util.Objects; @@ -52,9 +51,6 @@ public class KafkaSourceConfig { @JsonProperty("acknowledgments_timeout") private Duration acknowledgementsTimeout = DEFAULT_ACKNOWLEDGEMENTS_TIMEOUT; - @JsonProperty("serde_format") - private String serdeFormat= MessageFormat.PLAINTEXT.toString(); - @JsonProperty("client_dns_lookup") private String clientDnsLookup; @@ -62,9 +58,6 @@ public String getClientDnsLookup() { return clientDnsLookup; } - public String getSerdeFormat() { - return serdeFormat; - } public Boolean getAcknowledgementsEnabled() { return acknowledgementsEnabled; } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaConfig.java index 26054b62a7..ee4e8c316e 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaConfig.java @@ -16,6 +16,9 @@ public class SchemaConfig { private static final int SESSION_TIME_OUT = 45000; + @JsonProperty("type") + private SchemaRegistryType type; + @JsonProperty("registry_url") private String registryURL; @@ -50,6 +53,10 @@ public int getVersion() { return version; } + public SchemaRegistryType getType() { + return type; + } + public String getSchemaRegistryApiKey() { return schemaRegistryApiKey; } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaRegistryType.java 
b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaRegistryType.java new file mode 100644 index 0000000000..91cfdebbc4 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/SchemaRegistryType.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.kafka.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.util.Map; +import java.util.Arrays; +import java.util.stream.Collectors; + +public enum SchemaRegistryType { + GLUE("glue"), + CONFLUENT("confluent"); + + private static final Map OPTIONS_MAP = Arrays.stream(SchemaRegistryType.values()) + .collect(Collectors.toMap( + value -> value.type, + value -> value + )); + + private final String type; + + SchemaRegistryType(final String type) { + this.type = type; + } + + @JsonCreator + static SchemaRegistryType fromTypeValue(final String type) { + return OPTIONS_MAP.get(type.toLowerCase()); + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java index 6a57d37a27..40a8b7b934 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java @@ -10,6 +10,8 @@ import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; +import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; + import java.time.Duration; /** * * A helper class that helps to read consumer configuration values from @@ -59,6 +61,9 @@ public class TopicConfig { @Size(min = 1) private Duration maxRetryDelay = MAX_RETRY_DELAY; + @JsonProperty("serde_format") + private MessageFormat serdeFormat= MessageFormat.PLAINTEXT; + @JsonProperty("auto_commit") private Boolean autoCommit = false; @@ -132,6 +137,10 @@ public void setMaxRetryAttempts(Integer maxRetryAttempts) { this.maxRetryAttempts = maxRetryAttempts; } + public MessageFormat getSerdeFormat() { + return serdeFormat; + } + public Boolean getAutoCommit() { return autoCommit; } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java index 2f7378a39d..4fe1e8d415 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaSourceCustomConsumer.java @@ -14,7 +14,9 @@ import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.common.errors.AuthenticationException; import org.apache.kafka.common.TopicPartition; +import org.apache.avro.generic.GenericRecord; import org.opensearch.dataprepper.model.log.JacksonLog; import org.opensearch.dataprepper.metrics.PluginMetrics; import 
org.opensearch.dataprepper.model.buffer.Buffer; @@ -38,6 +40,7 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; +import com.amazonaws.services.schemaregistry.serializers.json.JsonDataWithSchema; import org.apache.commons.lang3.Range; /** @@ -108,7 +111,7 @@ public void updateOffsetsToCommit(final TopicPartition partition, final OffsetAn } private AcknowledgementSet createAcknowledgementSet(Map> offsets) { - AcknowledgementSet acknowledgementSet = + AcknowledgementSet acknowledgementSet = acknowledgementSetManager.create((result) -> { if (result == true) { positiveAcknowledgementSetCounter.increment(); @@ -117,7 +120,7 @@ private AcknowledgementSet createAcknowledgementSet(Map void consumeRecords() throws Exception { - ConsumerRecords records = - consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2); - if (!records.isEmpty() && records.count() > 0) { - Map> offsets = new HashMap<>(); - AcknowledgementSet acknowledgementSet = null; - if (acknowledgementsEnabled) { - acknowledgementSet = createAcknowledgementSet(offsets); - } - iterateRecordPartitions(records, acknowledgementSet, offsets); - if (!acknowledgementsEnabled) { - offsets.forEach((partition, offsetRange) -> - updateOffsetsToCommit(partition, new OffsetAndMetadata(offsetRange.getMaximum() + 1))); - } else { - acknowledgementSet.complete(); + try { + ConsumerRecords records = + consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2); + if (!records.isEmpty() && records.count() > 0) { + Map> offsets = new HashMap<>(); + AcknowledgementSet acknowledgementSet = null; + if (acknowledgementsEnabled) { + acknowledgementSet = createAcknowledgementSet(offsets); + } + iterateRecordPartitions(records, acknowledgementSet, offsets); + if (!acknowledgementsEnabled) { + offsets.forEach((partition, offsetRange) -> + updateOffsetsToCommit(partition, new OffsetAndMetadata(offsetRange.getMaximum() + 1))); + } else { + acknowledgementSet.complete(); + } } + } catch (AuthenticationException e) { + LOG.warn("Authentication Error while doing poll(). 
Will retry after 10 seconds", e); + Thread.sleep(10000); } } @@ -170,7 +178,7 @@ private void commitOffsets() { offsetsToCommit.clear(); lastCommitTime = currentTimeMillis; } catch (CommitFailedException e) { - LOG.error("Failed to commit offsets in topic "+topicName); + LOG.error("Failed to commit offsets in topic "+topicName, e); } } } @@ -195,28 +203,26 @@ public void run() { private Record getRecord(ConsumerRecord consumerRecord) { Map data = new HashMap<>(); Event event; - Object value; + Object value = consumerRecord.value(); String key = (String)consumerRecord.key(); if (Objects.isNull(key)) { key = DEFAULT_KEY; } - if (schema == MessageFormat.JSON || schema == MessageFormat.AVRO) { - value = new HashMap<>(); - try { - if(schema == MessageFormat.JSON){ - value = consumerRecord.value(); - }else if(schema == MessageFormat.AVRO) { - final JsonParser jsonParser = jsonFactory.createParser((String)consumerRecord.value().toString()); - value = objectMapper.readValue(jsonParser, Map.class); - } - } catch (Exception e){ - LOG.error("Failed to parse JSON or AVRO record"); - data.put(key, value); + try { + if (value instanceof JsonDataWithSchema) { + JsonDataWithSchema j = (JsonDataWithSchema)consumerRecord.value(); + value = objectMapper.readValue(j.getPayload(), Map.class); + } else if (schema == MessageFormat.AVRO || value instanceof GenericRecord) { + final JsonParser jsonParser = jsonFactory.createParser((String)consumerRecord.value().toString()); + value = objectMapper.readValue(jsonParser, Map.class); + } else if (schema == MessageFormat.PLAINTEXT) { + value = (String)consumerRecord.value(); } - } else { - value = (String)consumerRecord.value(); + } catch (Exception e){ + LOG.error("Failed to parse JSON or AVRO record", e); } data.put(key, value); + event = JacksonLog.builder().withData(data).build(); return new Record(event); } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSource.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSource.java index f1e69d7382..8b395e5a67 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSource.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/KafkaSource.java @@ -35,31 +35,21 @@ import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.plugins.kafka.configuration.AuthConfig; -import org.opensearch.dataprepper.plugins.kafka.configuration.AwsConfig; -import org.opensearch.dataprepper.plugins.kafka.configuration.AwsIamAuthConfig; -import org.opensearch.dataprepper.plugins.kafka.configuration.EncryptionType; import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaSourceConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.OAuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.PlainTextAuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaRegistryType; import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; -import org.opensearch.dataprepper.plugins.kafka.configuration.OAuthConfig; import org.opensearch.dataprepper.plugins.kafka.consumer.KafkaSourceCustomConsumer; import org.opensearch.dataprepper.plugins.kafka.util.ClientDNSLookupType; import 
org.opensearch.dataprepper.plugins.kafka.util.KafkaSourceJsonDeserializer; import org.opensearch.dataprepper.plugins.kafka.util.KafkaSourceSecurityConfigurer; import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; -import software.amazon.awssdk.services.kafka.KafkaClient; -import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest; -import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse; -import software.amazon.awssdk.services.kafka.model.InternalServerErrorException; -import software.amazon.awssdk.services.kafka.model.ConflictException; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.regions.Region; +import com.amazonaws.services.schemaregistry.deserializers.GlueSchemaRegistryKafkaDeserializer; +import com.amazonaws.services.schemaregistry.utils.AWSSchemaRegistryConstants; +import com.amazonaws.services.schemaregistry.utils.AvroRecordType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -76,7 +66,6 @@ import java.net.Socket; import java.util.Map; import java.util.List; -import java.util.UUID; import java.util.Objects; import java.util.Comparator; import java.util.Properties; @@ -99,7 +88,6 @@ @DataPrepperPlugin(name = "kafka", pluginType = Source.class, pluginConfigurationType = KafkaSourceConfig.class) public class KafkaSource implements Source> { private static final String KAFKA_WORKER_THREAD_PROCESSING_ERRORS = "kafkaWorkerThreadProcessingErrors"; - private static final int MAX_KAFKA_CLIENT_RETRIES = 10; private static final Logger LOG = LoggerFactory.getLogger(KafkaSource.class); private final KafkaSourceConfig sourceConfig; private AtomicBoolean shutdownInProgress; @@ -112,7 +100,6 @@ public class KafkaSource implements Source> { private String schemaType = MessageFormat.PLAINTEXT.toString(); private static final String SCHEMA_TYPE = "schemaType"; private final AcknowledgementSetManager acknowledgementSetManager; - private final EncryptionType encryptionType; private static CachedSchemaRegistryClient schemaRegistryClient; @DataPrepperPluginConstructor @@ -126,7 +113,6 @@ public KafkaSource(final KafkaSourceConfig sourceConfig, this.pipelineName = pipelineDescription.getPipelineName(); this.kafkaWorkerThreadProcessingErrors = pluginMetrics.counter(KAFKA_WORKER_THREAD_PROCESSING_ERRORS); shutdownInProgress = new AtomicBoolean(false); - this.encryptionType = sourceConfig.getEncryptionType(); } @Override @@ -162,7 +148,7 @@ public void start(Buffer> buffer) { } else { LOG.error("Failed to setup the Kafka Source Plugin.", e); } - throw new RuntimeException(); + throw new RuntimeException(e); } LOG.info("Started Kafka source for topic " + topic.getName()); }); @@ -181,7 +167,7 @@ public void stop() { } } catch (InterruptedException e) { if (e.getCause() instanceof InterruptedException) { - LOG.error("Interrupted during consumer shutdown, exiting uncleanly..."); + LOG.error("Interrupted during consumer shutdown, exiting uncleanly...", e); executorService.shutdownNow(); Thread.currentThread().interrupt(); } @@ -203,106 +189,13 @@ private long calculateLongestThreadWaitingTime() { orElse(1L); } - public String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig 
awsConfig) { - AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create(); - if (awsIamAuthConfig == AwsIamAuthConfig.ROLE) { - String sessionName = "data-prepper-kafka-session" + UUID.randomUUID(); - StsClient stsClient = StsClient.builder() - .region(Region.of(awsConfig.getRegion())) - .credentialsProvider(credentialsProvider) - .build(); - credentialsProvider = StsAssumeRoleCredentialsProvider - .builder() - .stsClient(stsClient) - .refreshRequest( - AssumeRoleRequest - .builder() - .roleArn(awsConfig.getStsRoleArn()) - .roleSessionName(sessionName) - .build() - ).build(); - } else { - throw new RuntimeException("Unknown AWS IAM auth mode"); - } - final AwsConfig.AwsMskConfig awsMskConfig = awsConfig.getAwsMskConfig(); - KafkaClient kafkaClient = KafkaClient.builder() - .credentialsProvider(credentialsProvider) - .region(Region.of(awsConfig.getRegion())) - .build(); - final GetBootstrapBrokersRequest request = - GetBootstrapBrokersRequest - .builder() - .clusterArn(awsMskConfig.getArn()) - .build(); - - int numRetries = 0; - boolean retryable; - GetBootstrapBrokersResponse result = null; - do { - retryable = false; - try { - result = kafkaClient.getBootstrapBrokers(request); - } catch (InternalServerErrorException | ConflictException e) { - retryable = true; - } catch (Exception e) { - break; - } - } while (retryable && numRetries++ < MAX_KAFKA_CLIENT_RETRIES); - if (Objects.isNull(result)) { - LOG.info("Failed to get bootstrap server information from MSK, using user configured bootstrap servers"); - return sourceConfig.getBootStrapServers(); - } - switch (awsMskConfig.getBrokerConnectionType()) { - case PUBLIC: - return result.bootstrapBrokerStringPublicSaslIam(); - case MULTI_VPC: - return result.bootstrapBrokerStringVpcConnectivitySaslIam(); - default: - case SINGLE_VPC: - return result.bootstrapBrokerStringSaslIam(); - } - } private Properties getConsumerProperties(final TopicConfig topicConfig) { Properties properties = new Properties(); - AwsIamAuthConfig awsIamAuthConfig = null; - AwsConfig awsConfig = sourceConfig.getAwsConfig(); - if (sourceConfig.getAuthConfig() != null) { - AuthConfig.SaslAuthConfig saslAuthConfig = sourceConfig.getAuthConfig().getSaslAuthConfig(); - if (saslAuthConfig != null) { - awsIamAuthConfig = saslAuthConfig.getAwsIamAuthConfig(); - PlainTextAuthConfig plainTextAuthConfig = saslAuthConfig.getPlainTextAuthConfig(); - - if (awsIamAuthConfig != null) { - if (encryptionType == EncryptionType.PLAINTEXT) { - throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism"); - } - setAwsIamAuthProperties(properties, awsIamAuthConfig, awsConfig); - } else if (saslAuthConfig.getOAuthConfig() != null) { - KafkaSourceSecurityConfigurer.setOauthProperties(sourceConfig, properties); - } else if (plainTextAuthConfig != null) { - setPlainTextAuthProperties(properties, plainTextAuthConfig); - } else { - throw new RuntimeException("No SASL auth config specified"); - } - } else if (encryptionType == EncryptionType.SSL) { - properties.put("security.protocol", "SSL"); - if (sourceConfig.getAuthConfig().getInsecure()) { - properties.put("ssl.engine.factory.class", InsecureSslEngineFactory.class); - } - } - } - String bootstrapServers = sourceConfig.getBootStrapServers(); - if (Objects.nonNull(awsIamAuthConfig)) { - bootstrapServers = getBootStrapServersForMsk(awsIamAuthConfig, awsConfig); - } - if (Objects.isNull(bootstrapServers) || bootstrapServers.isEmpty()) { - throw new RuntimeException("Bootstrap servers are not 
specified"); - } - properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); - /* if (isKafkaClusterExists(sourceConfig.getBootStrapServers())) { - throw new RuntimeException("Can't be able to connect to the given Kafka brokers... "); - }*/ + KafkaSourceSecurityConfigurer.setAuthProperties(properties, sourceConfig); + /* if (isKafkaClusterExists(sourceConfig.getBootStrapServers())) { + throw new RuntimeException("Can't be able to connect to the given Kafka brokers... "); + }*/ if (StringUtils.isNotEmpty(sourceConfig.getClientDnsLookup())) { ClientDNSLookupType dnsLookupType = ClientDNSLookupType.getDnsLookupType(sourceConfig.getClientDnsLookup()); switch (dnsLookupType) { @@ -340,39 +233,6 @@ private String getSchemaRegistryUrl() { return sourceConfig.getSchemaConfig().getRegistryURL(); } - private void setAwsIamAuthProperties(Properties properties, final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) { - if (awsConfig == null) { - throw new RuntimeException("AWS Config is not specified"); - } - properties.put("security.protocol", "SASL_SSL"); - properties.put("sasl.mechanism", "AWS_MSK_IAM"); - properties.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); - if (awsIamAuthConfig == AwsIamAuthConfig.ROLE) { - properties.put("sasl.jaas.config", - "software.amazon.msk.auth.iam.IAMLoginModule required " + - "awsRoleArn=\"" + awsConfig.getStsRoleArn() + - "\" awsStsRegion=\"" + awsConfig.getRegion() + "\";"); - } else if (awsIamAuthConfig == AwsIamAuthConfig.DEFAULT) { - properties.put("sasl.jaas.config", - "software.amazon.msk.auth.iam.IAMLoginModule required;"); - } - } - - private void setPlainTextAuthProperties(Properties properties, final PlainTextAuthConfig plainTextAuthConfig) { - String username = plainTextAuthConfig.getUsername(); - String password = plainTextAuthConfig.getPassword(); - properties.put("sasl.mechanism", "PLAIN"); - properties.put("sasl.jaas.config", "org.apache.kafka.common.security.plain.PlainLoginModule required username=\"" + username + "\" password=\"" + password + "\";"); - if (encryptionType == EncryptionType.PLAINTEXT) { - properties.put("security.protocol", "SASL_PLAINTEXT"); - } else { // EncryptionType.SSL - properties.put("security.protocol", "SASL_SSL"); - } - if (sourceConfig.getAuthConfig().getInsecure()) { - properties.put("ssl.engine.factory.class", InsecureSslEngineFactory.class); - } - } - private static String getSchemaType(final String registryUrl, final String topicName, final int schemaVersion) { StringBuilder response = new StringBuilder(); String schemaType = MessageFormat.PLAINTEXT.toString(); @@ -433,18 +293,36 @@ private static String readErrorMessage(InputStream errorStream) throws IOExcepti return errorMessage.toString(); } - private void setSchemaRegistryProperties(Properties properties, TopicConfig topic) { + private void setSchemaRegistryProperties(Properties properties, TopicConfig topicConfig) { SchemaConfig schemaConfig = sourceConfig.getSchemaConfig(); - if (schemaConfig != null && StringUtils.isNotEmpty(schemaConfig.getRegistryURL())) { + if (Objects.isNull(schemaConfig)) { + setPropertiesForPlaintextAndJsonWithoutSchemaRegistry(properties, topicConfig); + return; + } + + if (schemaConfig.getType() == SchemaRegistryType.GLUE) { + setPropertiesForGlueSchemaRegistry(properties); + return; + } + + /* else schema registry type is Confluent */ + if (StringUtils.isNotEmpty(schemaConfig.getRegistryURL())) { 
setPropertiesForSchemaRegistryConnectivity(properties); - setPropertiesForSchemaType(properties, topic); - } else if (schemaConfig == null) { - setPropertiesForPlaintextAndJsonWithoutSchemaRegistry(properties); + setPropertiesForSchemaType(properties, topicConfig); + } else { + throw new RuntimeException("RegistryURL must be specified for confluent schema registry"); } } - private void setPropertiesForPlaintextAndJsonWithoutSchemaRegistry(Properties properties) { - MessageFormat dataFormat = MessageFormat.getByMessageFormatByName(sourceConfig.getSerdeFormat()); + private void setPropertiesForGlueSchemaRegistry(Properties properties) { + properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + properties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, GlueSchemaRegistryKafkaDeserializer.class.getName()); + properties.put(AWSSchemaRegistryConstants.AWS_REGION, sourceConfig.getAwsConfig().getRegion()); + properties.put(AWSSchemaRegistryConstants.AVRO_RECORD_TYPE, AvroRecordType.GENERIC_RECORD.getName()); + } + + private void setPropertiesForPlaintextAndJsonWithoutSchemaRegistry(Properties properties, final TopicConfig topicConfig) { + MessageFormat dataFormat = topicConfig.getSerdeFormat(); schemaType = dataFormat.toString(); properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); @@ -516,10 +394,7 @@ private void setPropertiesForSchemaRegistryConnectivity(Properties properties) { if (authConfig != null && authConfig.getSaslAuthConfig() != null) { PlainTextAuthConfig plainTextAuthConfig = authConfig.getSaslAuthConfig().getPlainTextAuthConfig(); OAuthConfig oAuthConfig = authConfig.getSaslAuthConfig().getOAuthConfig(); - if (plainTextAuthConfig != null) { - properties.put("sasl.mechanism", "PLAIN"); - properties.put("security.protocol", plainTextAuthConfig.getSecurityProtocol()); - } else if (oAuthConfig != null) { + if (oAuthConfig != null) { properties.put("sasl.mechanism", oAuthConfig.getOauthSaslMechanism()); properties.put("security.protocol", oAuthConfig.getOauthSecurityProtocol()); } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/InsecureSslEngineFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/InsecureSslEngineFactory.java similarity index 97% rename from data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/InsecureSslEngineFactory.java rename to data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/InsecureSslEngineFactory.java index 39415dc370..c02e563d62 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/source/InsecureSslEngineFactory.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/InsecureSslEngineFactory.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.dataprepper.plugins.kafka.source; +package org.opensearch.dataprepper.plugins.kafka.util; import org.apache.kafka.common.security.auth.SslEngineFactory; import javax.net.ssl.SSLContext; diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSourceSecurityConfigurer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSourceSecurityConfigurer.java index 9ffff45d84..7e99e336dc 100644 --- 
a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSourceSecurityConfigurer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/util/KafkaSourceSecurityConfigurer.java @@ -4,13 +4,34 @@ */ package org.opensearch.dataprepper.plugins.kafka.util; +import org.opensearch.dataprepper.plugins.kafka.configuration.AuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.AwsConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.AwsIamAuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaSourceConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.OAuthConfig; -import org.opensearch.dataprepper.plugins.kafka.configuration.AwsIamAuthConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.EncryptionType; +import org.opensearch.dataprepper.plugins.kafka.configuration.PlainTextAuthConfig; +import org.apache.kafka.clients.consumer.ConsumerConfig; + +import software.amazon.awssdk.services.kafka.KafkaClient; +import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest; +import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse; +import software.amazon.awssdk.services.kafka.model.InternalServerErrorException; +import software.amazon.awssdk.services.kafka.model.ConflictException; +import software.amazon.awssdk.services.kafka.model.ForbiddenException; +import software.amazon.awssdk.services.kafka.model.UnauthorizedException; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.services.sts.model.StsException; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.regions.Region; import java.util.Base64; +import java.util.Objects; import java.util.Properties; +import java.util.UUID; /** * * This is static property configure dedicated to authentication related information given in pipeline.yml @@ -20,11 +41,12 @@ public class KafkaSourceSecurityConfigurer { private static final String SASL_MECHANISM = "sasl.mechanism"; - private static final String SASL_SECURITY_PROTOCOL = "security.protocol"; + private static final String SECURITY_PROTOCOL = "security.protocol"; - private static final String SASL_JAS_CONFIG = "sasl.jaas.config"; + private static final String SASL_JAAS_CONFIG = "sasl.jaas.config"; private static final String SASL_CALLBACK_HANDLER_CLASS = "sasl.login.callback.handler.class"; + private static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS = "sasl.client.callback.handler.class"; private static final String SASL_JWKS_ENDPOINT_URL = "sasl.oauthbearer.jwks.endpoint.url"; @@ -50,6 +72,8 @@ public class KafkaSourceSecurityConfigurer { private static final String REGISTRY_BASIC_AUTH_USER_INFO = "schema.registry.basic.auth.user.info"; + private static final int MAX_KAFKA_CLIENT_RETRIES = 360; // for one hour every 10 seconds + /*public static void setSaslPlainTextProperties(final KafkaSourceConfig kafkaSourConfig, final Properties properties) { @@ -62,15 +86,27 @@ public class KafkaSourceSecurityConfigurer { if (saslAuthConfig!= null) { if (StringUtils.isNotEmpty(saslAuthConfig.getPlainTextAuthConfig().getPlaintext()) && 
PLAINTEXT_PROTOCOL.equalsIgnoreCase(saslAuthConfig.getAuthProtocolConfig().getPlaintext())) { - properties.put(SASL_SECURITY_PROTOCOL, PLAINTEXT_PROTOCOL); + properties.put(SECURITY_PROTOCOL, PLAINTEXT_PROTOCOL); } else if (StringUtils.isNotEmpty(saslAuthConfig.getAuthProtocolConfig().getPlaintext()) && SASL_PLAINTEXT_PROTOCOL.equalsIgnoreCase(saslAuthConfig.getAuthProtocolConfig().getPlaintext())) { - properties.put(SASL_SECURITY_PROTOCOL, SASL_PLAINTEXT_PROTOCOL); + properties.put(SECURITY_PROTOCOL, SASL_PLAINTEXT_PROTOCOL); } } - properties.put(SASL_JAS_CONFIG, String.format(PLAINTEXT_JAASCONFIG, username, password)); + properties.put(SASL_JAAS_CONFIG, String.format(PLAINTEXT_JAASCONFIG, username, password)); }*/ + private static void setPlainTextAuthProperties(Properties properties, final PlainTextAuthConfig plainTextAuthConfig, EncryptionType encryptionType) { + String username = plainTextAuthConfig.getUsername(); + String password = plainTextAuthConfig.getPassword(); + properties.put(SASL_MECHANISM, "PLAIN"); + properties.put(SASL_JAAS_CONFIG, "org.apache.kafka.common.security.plain.PlainLoginModule required username=\"" + username + "\" password=\"" + password + "\";"); + if (encryptionType == EncryptionType.NONE) { + properties.put(SECURITY_PROTOCOL, "SASL_PLAINTEXT"); + } else { // EncryptionType.SSL + properties.put(SECURITY_PROTOCOL, "SASL_SSL"); + } + } + public static void setOauthProperties(final KafkaSourceConfig kafkaSourConfig, final Properties properties) { final OAuthConfig oAuthConfig = kafkaSourConfig.getAuthConfig().getSaslAuthConfig().getOAuthConfig(); @@ -91,7 +127,7 @@ public static void setOauthProperties(final KafkaSourceConfig kafkaSourConfig, properties.put(SASL_MECHANISM, saslMechanism); - properties.put(SASL_SECURITY_PROTOCOL, securityProtocol); + properties.put(SECURITY_PROTOCOL, securityProtocol); properties.put(SASL_TOKEN_ENDPOINT_URL, tokenEndPointURL); properties.put(SASL_CALLBACK_HANDLER_CLASS, loginCallBackHandler); @@ -119,25 +155,130 @@ public static void setOauthProperties(final KafkaSourceConfig kafkaSourConfig, jass_config = jass_config.replace(";", " "); jass_config += String.format(extensionValue, extensionLogicalCluster, extensionIdentityPoolId); } - properties.put(SASL_JAS_CONFIG, jass_config); + properties.put(SASL_JAAS_CONFIG, jass_config); } - public static void setAwsIamAuthProperties(Properties properties, AwsIamAuthConfig awsIamAuthConfig, AwsConfig awsConfig) { - if (awsConfig == null) { - throw new RuntimeException("AWS Config is not specified"); - } - properties.put("security.protocol", "SASL_SSL"); - properties.put("sasl.mechanism", "AWS_MSK_IAM"); - properties.put("sasl.client.callback.handler.class", "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); + public static void setAwsIamAuthProperties(Properties properties, final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) { + properties.put(SECURITY_PROTOCOL, "SASL_SSL"); + properties.put(SASL_MECHANISM, "AWS_MSK_IAM"); + properties.put(SASL_CLIENT_CALLBACK_HANDLER_CLASS, "software.amazon.msk.auth.iam.IAMClientCallbackHandler"); if (awsIamAuthConfig == AwsIamAuthConfig.ROLE) { - properties.put("sasl.jaas.config", - "software.amazon.msk.auth.iam.IAMLoginModule required " + - "awsRoleArn=\"" + awsConfig.getStsRoleArn() + - "\" awsStsRegion=\"" + awsConfig.getRegion() + "\";"); + properties.put(SASL_JAAS_CONFIG, + "software.amazon.msk.auth.iam.IAMLoginModule required " + + "awsRoleArn=\"" + awsConfig.getStsRoleArn() + + "\" awsStsRegion=\"" + awsConfig.getRegion() + 
"\";"); } else if (awsIamAuthConfig == AwsIamAuthConfig.DEFAULT) { - properties.put("sasl.jaas.config", + properties.put(SASL_JAAS_CONFIG, "software.amazon.msk.auth.iam.IAMLoginModule required;"); } } + + public static String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuthConfig, final AwsConfig awsConfig) { + AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create(); + if (awsIamAuthConfig == AwsIamAuthConfig.ROLE) { + String sessionName = "data-prepper-kafka-session" + UUID.randomUUID(); + StsClient stsClient = StsClient.builder() + .region(Region.of(awsConfig.getRegion())) + .credentialsProvider(credentialsProvider) + .build(); + credentialsProvider = StsAssumeRoleCredentialsProvider + .builder() + .stsClient(stsClient) + .refreshRequest( + AssumeRoleRequest + .builder() + .roleArn(awsConfig.getStsRoleArn()) + .roleSessionName(sessionName) + .build() + ).build(); + } else if (awsIamAuthConfig != AwsIamAuthConfig.DEFAULT) { + throw new RuntimeException("Unknown AWS IAM auth mode"); + } + final AwsConfig.AwsMskConfig awsMskConfig = awsConfig.getAwsMskConfig(); + KafkaClient kafkaClient = KafkaClient.builder() + .credentialsProvider(credentialsProvider) + .region(Region.of(awsConfig.getRegion())) + .build(); + final GetBootstrapBrokersRequest request = + GetBootstrapBrokersRequest + .builder() + .clusterArn(awsMskConfig.getArn()) + .build(); + + int numRetries = 0; + boolean retryable; + GetBootstrapBrokersResponse result = null; + do { + retryable = false; + try { + result = kafkaClient.getBootstrapBrokers(request); + } catch (InternalServerErrorException | ConflictException | ForbiddenException | UnauthorizedException | StsException e) { + + retryable = true; + try { + Thread.sleep(10000); + } catch (InterruptedException exp) {} + } catch (Exception e) { + break; + } + } while (retryable && numRetries++ < MAX_KAFKA_CLIENT_RETRIES); + if (Objects.isNull(result)) { + throw new RuntimeException("Failed to get bootstrap server information from MSK, using user configured bootstrap servers"); + } + switch (awsMskConfig.getBrokerConnectionType()) { + case PUBLIC: + return result.bootstrapBrokerStringPublicSaslIam(); + case MULTI_VPC: + return result.bootstrapBrokerStringVpcConnectivitySaslIam(); + default: + case SINGLE_VPC: + return result.bootstrapBrokerStringSaslIam(); + } + } + + public static void setAuthProperties(Properties properties, final KafkaSourceConfig sourceConfig) { + final AwsConfig awsConfig = sourceConfig.getAwsConfig(); + final AuthConfig authConfig = sourceConfig.getAuthConfig(); + final EncryptionType encryptionType = sourceConfig.getEncryptionType(); + + String bootstrapServers = sourceConfig.getBootStrapServers(); + AwsIamAuthConfig awsIamAuthConfig = null; + if (Objects.nonNull(authConfig)) { + AuthConfig.SaslAuthConfig saslAuthConfig = authConfig.getSaslAuthConfig(); + if (Objects.nonNull(saslAuthConfig)) { + awsIamAuthConfig = saslAuthConfig.getAwsIamAuthConfig(); + PlainTextAuthConfig plainTextAuthConfig = saslAuthConfig.getPlainTextAuthConfig(); + + if (Objects.nonNull(awsIamAuthConfig)) { + if (encryptionType == EncryptionType.NONE) { + throw new RuntimeException("Encryption Config must be SSL to use IAM authentication mechanism"); + } + if (Objects.isNull(awsConfig)) { + throw new RuntimeException("AWS Config is not specified"); + } + setAwsIamAuthProperties(properties, awsIamAuthConfig, awsConfig); + bootstrapServers = getBootStrapServersForMsk(awsIamAuthConfig, awsConfig); + } else if 
(Objects.nonNull(saslAuthConfig.getOAuthConfig())) { + setOauthProperties(sourceConfig, properties); + } else if (Objects.nonNull(plainTextAuthConfig)) { + setPlainTextAuthProperties(properties, plainTextAuthConfig, encryptionType); + } else { + throw new RuntimeException("No SASL auth config specified"); + } + } + if (authConfig.getInsecure()) { + properties.put("ssl.engine.factory.class", InsecureSslEngineFactory.class); + } + } + if (Objects.isNull(authConfig) || Objects.isNull(authConfig.getSaslAuthConfig())) { + if (encryptionType == EncryptionType.SSL) { + properties.put(SECURITY_PROTOCOL, "SSL"); + } + } + if (Objects.isNull(bootstrapServers) || bootstrapServers.isEmpty()) { + throw new RuntimeException("Bootstrap servers are not specified"); + } + properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + } } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfigTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfigTest.java index e963e46e5f..a335ebd5c0 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfigTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/configuration/KafkaSourceConfigTest.java @@ -85,8 +85,8 @@ void test_setters() throws NoSuchFieldException, IllegalAccessException { assertEquals(true, kafkaSourceConfig.getAcknowledgementsEnabled()); assertEquals(testTimeout, kafkaSourceConfig.getAcknowledgementsTimeout()); assertEquals(EncryptionType.SSL, kafkaSourceConfig.getEncryptionType()); - setField(KafkaSourceConfig.class, kafkaSourceConfig, "encryptionType", EncryptionType.PLAINTEXT); - assertEquals(EncryptionType.PLAINTEXT, kafkaSourceConfig.getEncryptionType()); + setField(KafkaSourceConfig.class, kafkaSourceConfig, "encryptionType", EncryptionType.NONE); + assertEquals(EncryptionType.NONE, kafkaSourceConfig.getEncryptionType()); setField(KafkaSourceConfig.class, kafkaSourceConfig, "encryptionType", EncryptionType.SSL); assertEquals(EncryptionType.SSL, kafkaSourceConfig.getEncryptionType()); } diff --git a/data-prepper-plugins/log-generator-source/src/test/java/org/opensearch/dataprepper/plugins/source/loggenerator/LogGeneratorSourceTest.java b/data-prepper-plugins/log-generator-source/src/test/java/org/opensearch/dataprepper/plugins/source/loggenerator/LogGeneratorSourceTest.java index ad9bce333e..180fcae952 100644 --- a/data-prepper-plugins/log-generator-source/src/test/java/org/opensearch/dataprepper/plugins/source/loggenerator/LogGeneratorSourceTest.java +++ b/data-prepper-plugins/log-generator-source/src/test/java/org/opensearch/dataprepper/plugins/source/loggenerator/LogGeneratorSourceTest.java @@ -84,14 +84,15 @@ void GIVEN_logGeneratorSourceAndBlockingBuffer_WHEN_noLimit_THEN_keepsWritingToB BlockingBuffer> spyBuffer = spy(new BlockingBuffer>("SamplePipeline")); - lenient().when(sourceConfig.getInterval()).thenReturn(Duration.ofSeconds(1)); // interval of 1 second + Duration interval = Duration.ofMillis(100); + + lenient().when(sourceConfig.getInterval()).thenReturn(interval); lenient().when(sourceConfig.getCount()).thenReturn(INFINITE_LOG_COUNT); // no limit to log count logGeneratorSource.start(spyBuffer); - Thread.sleep(1500); - + Thread.sleep((long) (interval.toMillis() * 1.5)); verify(spyBuffer, atLeast(1)).write(any(Record.class), anyInt()); - 
Thread.sleep(700); + Thread.sleep((long) (interval.toMillis() * 0.7)); verify(spyBuffer, atLeast(2)).write(any(Record.class), anyInt()); } @@ -102,16 +103,18 @@ void GIVEN_logGeneratorSourceAndBlockingBuffer_WHEN_reachedLimit_THEN_stopsWriti BlockingBuffer> spyBuffer = spy(new BlockingBuffer>("SamplePipeline")); - lenient().when(sourceConfig.getInterval()).thenReturn(Duration.ofSeconds(1)); // interval of 1 second + Duration interval = Duration.ofMillis(100); + + lenient().when(sourceConfig.getInterval()).thenReturn(interval); lenient().when(sourceConfig.getCount()).thenReturn(1); // max log count of 1 in logGeneratorSource assertEquals(spyBuffer.isEmpty(), true); logGeneratorSource.start(spyBuffer); - Thread.sleep(1100); + Thread.sleep((long) (interval.toMillis() * 1.1)); verify(spyBuffer, times(1)).write(any(Record.class), anyInt()); - Thread.sleep(1000); + Thread.sleep(interval.toMillis()); verify(spyBuffer, times(1)).write(any(Record.class), anyInt()); } } \ No newline at end of file diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapper.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapper.java index ac644e23a6..6215d48582 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapper.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapper.java @@ -5,6 +5,7 @@ import org.opensearch.client.opensearch._types.OpenSearchException; import org.opensearch.client.opensearch.core.BulkRequest; import org.opensearch.client.opensearch.core.BulkResponse; +import org.opensearch.client.opensearch.core.bulk.BulkOperation; import org.opensearch.client.transport.JsonEndpoint; import org.opensearch.client.transport.endpoints.SimpleEndpoint; import org.opensearch.client.util.ApiTypeHelper; @@ -28,17 +29,14 @@ public BulkResponse bulk(BulkRequest request) throws IOException, OpenSearchExce return openSearchClient._transport().performRequest(request, endpoint, openSearchClient._transportOptions()); } - private JsonEndpoint es6BulkEndpoint(BulkRequest bulkRequest) { + private JsonEndpoint es6BulkEndpoint(final BulkRequest bulkRequest) { return new SimpleEndpoint<>( // Request method request -> HttpMethod.POST, // Request path request -> { - final String index = request.index(); - if (index == null) { - throw new IllegalArgumentException("Bulk request index cannot be missing"); - } + final String index = request.index() == null ? 
getFirstOperationIndex(request) : request.index(); StringBuilder buf = new StringBuilder(); buf.append("/"); SimpleEndpoint.pathEncode(index, buf); @@ -83,4 +81,19 @@ private JsonEndpoint es6BulkEndpoint(B }, SimpleEndpoint.emptyMap(), true, BulkResponse._DESERIALIZER); } + + private String getFirstOperationIndex(final BulkRequest bulkRequest) { + final BulkOperation firstBulkOperation = bulkRequest.operations().get(0); + if (firstBulkOperation.isIndex()) { + return firstBulkOperation.index().index(); + } else if (firstBulkOperation.isCreate()) { + return firstBulkOperation.create().index(); + } else if (firstBulkOperation.isUpdate()) { + return firstBulkOperation.update().index(); + } else if (firstBulkOperation.isDelete()) { + return firstBulkOperation.delete().index(); + } + throw new IllegalArgumentException(String.format("Unsupported bulk operation kind: %s", + firstBulkOperation._kind())); + } } diff --git a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapperTest.java b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapperTest.java index 1ca5bc06ed..9bce8a8ac1 100644 --- a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapperTest.java +++ b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/bulk/Es6BulkApiWrapperTest.java @@ -1,8 +1,12 @@ package org.opensearch.dataprepper.plugins.sink.opensearch.bulk; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -11,17 +15,24 @@ import org.opensearch.client.opensearch._types.ErrorResponse; import org.opensearch.client.opensearch.core.BulkRequest; import org.opensearch.client.opensearch.core.BulkResponse; +import org.opensearch.client.opensearch.core.bulk.BulkOperation; +import org.opensearch.client.opensearch.core.bulk.CreateOperation; +import org.opensearch.client.opensearch.core.bulk.DeleteOperation; +import org.opensearch.client.opensearch.core.bulk.IndexOperation; +import org.opensearch.client.opensearch.core.bulk.UpdateOperation; import org.opensearch.client.transport.JsonEndpoint; import org.opensearch.client.transport.OpenSearchTransport; import org.opensearch.client.transport.TransportOptions; import java.io.IOException; +import java.util.List; +import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -41,14 +52,39 @@ class Es6BulkApiWrapperTest { @Mock private BulkRequest bulkRequest; + @Mock + private BulkOperation bulkOperation; + + @Mock + private IndexOperation indexOperation; + + @Mock + private CreateOperation createOperation; + + @Mock + private UpdateOperation updateOperation; + + @Mock + private DeleteOperation deleteOperation; + @Captor private 
ArgumentCaptor> jsonEndpointArgumentCaptor; private Es6BulkApiWrapper objectUnderTest; + private String testIndex; @BeforeEach void setUp() { objectUnderTest = new Es6BulkApiWrapper(openSearchClient); + testIndex = RandomStringUtils.randomAlphabetic(5); + lenient().when(bulkOperation.index()).thenReturn(indexOperation); + lenient().when(bulkOperation.create()).thenReturn(createOperation); + lenient().when(bulkOperation.update()).thenReturn(updateOperation); + lenient().when(bulkOperation.delete()).thenReturn(deleteOperation); + lenient().when(indexOperation.index()).thenReturn(testIndex); + lenient().when(createOperation.index()).thenReturn(testIndex); + lenient().when(updateOperation.index()).thenReturn(testIndex); + lenient().when(deleteOperation.index()).thenReturn(testIndex); } @Test @@ -65,16 +101,33 @@ void testBulk() throws IOException { assertThat(endpoint.requestUrl(bulkRequest), equalTo(expectedURI)); } - @Test - void testBulkThrowsException_when_request_missing_index() throws IOException { + @ParameterizedTest + @MethodSource("getTypeFlags") + void testBulk_when_request_index_missing(final boolean isIndex, final boolean isCreate, + final boolean isUpdate, final boolean isDelete) throws IOException { when(openSearchClient._transport()).thenReturn(openSearchTransport); when(openSearchClient._transportOptions()).thenReturn(transportOptions); when(bulkRequest.index()).thenReturn(null); + when(bulkRequest.operations()).thenReturn(List.of(bulkOperation)); + lenient().when(bulkOperation.isIndex()).thenReturn(isIndex); + lenient().when(bulkOperation.isCreate()).thenReturn(isCreate); + lenient().when(bulkOperation.isUpdate()).thenReturn(isUpdate); + lenient().when(bulkOperation.isDelete()).thenReturn(isDelete); objectUnderTest.bulk(bulkRequest); verify(openSearchTransport).performRequest( any(BulkRequest.class), jsonEndpointArgumentCaptor.capture(), eq(transportOptions)); final JsonEndpoint endpoint = jsonEndpointArgumentCaptor.getValue(); - assertThrows(IllegalArgumentException.class, () -> endpoint.requestUrl(bulkRequest)); + final String expectedURI = String.format(ES6_URI_PATTERN, testIndex); + assertThat(endpoint.requestUrl(bulkRequest), equalTo(expectedURI)); + } + + private static Stream getTypeFlags() { + return Stream.of( + Arguments.of(true, false, false, false), + Arguments.of(false, true, false, false), + Arguments.of(false, false, true, false), + Arguments.of(false, false, false, true) + ); } } \ No newline at end of file diff --git a/data-prepper-plugins/rss-source/src/main/java/org/opensearch/dataprepper/plugins/source/rss/RSSSource.java b/data-prepper-plugins/rss-source/src/main/java/org/opensearch/dataprepper/plugins/source/rss/RSSSource.java index a727db5bc7..a3d1f9e0bb 100644 --- a/data-prepper-plugins/rss-source/src/main/java/org/opensearch/dataprepper/plugins/source/rss/RSSSource.java +++ b/data-prepper-plugins/rss-source/src/main/java/org/opensearch/dataprepper/plugins/source/rss/RSSSource.java @@ -45,7 +45,8 @@ public void start(final Buffer> buffer) { throw new IllegalStateException("Buffer is null"); } rssReaderTask = new RssReaderTask(rssReader, rssSourceConfig.getUrl(), buffer); - scheduledExecutorService.scheduleAtFixedRate(rssReaderTask, 0, 5, TimeUnit.SECONDS); + scheduledExecutorService.scheduleAtFixedRate(rssReaderTask, 0, + rssSourceConfig.getPollingFrequency().toMillis(), TimeUnit.MILLISECONDS); } @Override diff --git a/data-prepper-plugins/rss-source/src/test/java/org/opensearch/dataprepper/plugins/source/rss/RSSSourceTest.java 
b/data-prepper-plugins/rss-source/src/test/java/org/opensearch/dataprepper/plugins/source/rss/RSSSourceTest.java index ac3ad5ba66..a6ae918089 100644 --- a/data-prepper-plugins/rss-source/src/test/java/org/opensearch/dataprepper/plugins/source/rss/RSSSourceTest.java +++ b/data-prepper-plugins/rss-source/src/test/java/org/opensearch/dataprepper/plugins/source/rss/RSSSourceTest.java @@ -18,8 +18,10 @@ import java.time.Duration; +import static org.awaitility.Awaitility.await; import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.verify; @@ -42,12 +44,14 @@ class RSSSourceTest { private PluginMetrics pluginMetrics; private RSSSource rssSource; + private Duration pollingFrequency; @BeforeEach void setUp() { pluginMetrics = PluginMetrics.fromNames(PLUGIN_NAME, PIPELINE_NAME); + pollingFrequency = Duration.ofMillis(1800); lenient().when(rssSourceConfig.getUrl()).thenReturn(VALID_RSS_URL); - lenient().when(rssSourceConfig.getPollingFrequency()).thenReturn(Duration.ofSeconds(5)); + lenient().when(rssSourceConfig.getPollingFrequency()).thenReturn(pollingFrequency); rssSource = new RSSSource(pluginMetrics, rssSourceConfig); } @@ -59,9 +63,15 @@ public void tearDown() { @Test void test_ExecutorService_keep_writing_Events_to_Buffer() throws Exception { rssSource.start(buffer); - Thread.sleep(5000); - verify(buffer, atLeastOnce()).writeAll(anyCollection(), anyInt()); - Thread.sleep(5000); + await().atMost(pollingFrequency.multipliedBy(2)) + .untilAsserted(() -> { + verify(buffer, atLeastOnce()).writeAll(anyCollection(), anyInt()); + }); verify(buffer, atLeastOnce()).writeAll(anyCollection(), anyInt()); + await().atMost(pollingFrequency.multipliedBy(2)) + .untilAsserted(() -> { + verify(buffer, atLeast(2)).writeAll(anyCollection(), anyInt()); + }); + verify(buffer, atLeast(2)).writeAll(anyCollection(), anyInt()); } }
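Editor's note (not part of the patch): the new SchemaRegistryType enum (and the reworked EncryptionType) bind from the pipeline configuration through a @JsonCreator factory that lower-cases the incoming token before the map lookup, so the `type` field under `schema` accepts any casing; an unrecognised value simply misses the map instead of matching a constant. A small sketch of how Jackson exercises that factory, assuming only that the enum is on the classpath as added above:

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaRegistryType;

public class SchemaRegistryTypeSketch {
    public static void main(String[] args) throws Exception {
        final ObjectMapper mapper = new ObjectMapper();
        // fromTypeValue() lower-cases the token, so "GLUE", "Glue" and "glue" all resolve to the same constant
        final SchemaRegistryType type = mapper.readValue("\"GLUE\"", SchemaRegistryType.class);
        System.out.println(type); // GLUE
    }
}
```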
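Editor's note: when the schema registry type is GLUE, setPropertiesForGlueSchemaRegistry() points the consumer's value deserializer at the AWS Glue Schema Registry client. A hedged sketch of the consumer properties that branch produces, with the region hard-coded as a placeholder where the real code reads it from the source's aws config:

```java
import java.util.Properties;

import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.serialization.StringDeserializer;

import com.amazonaws.services.schemaregistry.deserializers.GlueSchemaRegistryKafkaDeserializer;
import com.amazonaws.services.schemaregistry.utils.AWSSchemaRegistryConstants;
import com.amazonaws.services.schemaregistry.utils.AvroRecordType;

public class GlueConsumerPropertiesSketch {
    public static Properties glueConsumerProperties() {
        final Properties properties = new Properties();
        // Keys stay plain strings; values are decoded by the Glue Schema Registry deserializer
        properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        properties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, GlueSchemaRegistryKafkaDeserializer.class.getName());
        properties.put(AWSSchemaRegistryConstants.AWS_REGION, "us-east-1"); // placeholder region
        // Avro payloads are materialised as org.apache.avro.generic.GenericRecord
        properties.put(AWSSchemaRegistryConstants.AVRO_RECORD_TYPE, AvroRecordType.GENERIC_RECORD.getName());
        return properties;
    }
}
```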
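Editor's note: with the Glue deserializer in place, getRecord() in KafkaSourceCustomConsumer now dispatches on the runtime type of the consumed value rather than only on the configured serde format. A minimal sketch of that conversion, assuming the same Jackson ObjectMapper the consumer already holds; the helper name is illustrative:

```java
import java.util.Map;

import com.amazonaws.services.schemaregistry.serializers.json.JsonDataWithSchema;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.avro.generic.GenericRecord;

public class KafkaValueConversionSketch {
    private static final ObjectMapper objectMapper = new ObjectMapper();

    static Object toEventValue(final Object value) throws Exception {
        if (value instanceof JsonDataWithSchema) {
            // Glue JSON-schema payloads carry the document as a raw JSON string
            return objectMapper.readValue(((JsonDataWithSchema) value).getPayload(), Map.class);
        }
        if (value instanceof GenericRecord) {
            // GenericRecord.toString() renders the Avro record as JSON, which is re-parsed into a Map
            return objectMapper.readValue(value.toString(), Map.class);
        }
        // Plaintext (and anything unrecognised) is passed through unchanged
        return value;
    }
}
```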
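Editor's note: the security wiring that used to live in KafkaSource is now consolidated in KafkaSourceSecurityConfigurer.setAuthProperties(). For username/password (SASL PLAIN) authentication the resulting client properties look roughly like the sketch below, where the new `encryption: none` option selects SASL_PLAINTEXT and the default SSL encryption selects SASL_SSL; the credentials are placeholders, and per the patch an `insecure: true` auth config additionally swaps in InsecureSslEngineFactory:

```java
import java.util.Properties;

public class SaslPlainPropertiesSketch {
    static Properties saslPlainProperties(final String username, final String password, final boolean sslEncryption) {
        final Properties properties = new Properties();
        properties.put("sasl.mechanism", "PLAIN");
        properties.put("sasl.jaas.config",
                "org.apache.kafka.common.security.plain.PlainLoginModule required username=\"" + username
                        + "\" password=\"" + password + "\";");
        // encryption: none  -> SASL_PLAINTEXT, encryption: ssl (the default) -> SASL_SSL
        properties.put("security.protocol", sslEncryption ? "SASL_SSL" : "SASL_PLAINTEXT");
        return properties;
    }
}
```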
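Editor's note: the RSS source test above, like the other tests touched in this change, replaces fixed Thread.sleep() waits with Awaitility polling so the assertions tolerate scheduling jitter while the overall timeouts shrink. A minimal sketch of the verification pattern used there, assuming Awaitility and Mockito are on the test classpath; the Buffer interface and method names below are placeholders, not the real Data Prepper types:

```java
import java.time.Duration;

import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;

class AwaitilityVerifySketch {
    interface Buffer { void write(String record); }   // stand-in for the real buffer type

    // buffer must be a Mockito mock that some background task is writing to
    static void assertEventuallyWrites(final Buffer buffer, final Duration pollingFrequency) {
        // Re-run the Mockito verification until it stops throwing, instead of sleeping a fixed time
        await().atMost(pollingFrequency.multipliedBy(2))
               .untilAsserted(() -> verify(buffer, atLeastOnce()).write(anyString()));
    }
}
```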