diff --git a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptions.java b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptions.java index f66e7cbc10..f0be710968 100644 --- a/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptions.java +++ b/data-prepper-plugins/aws-plugin-api/src/main/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptions.java @@ -15,6 +15,7 @@ * Provides a standard model for requesting AWS credentials. */ public class AwsCredentialsOptions { + private static final AwsCredentialsOptions DEFAULT_OPTIONS = new AwsCredentialsOptions(); private final String stsRoleArn; private final String stsExternalId; private final Region region; @@ -27,6 +28,13 @@ private AwsCredentialsOptions(final Builder builder) { this.stsHeaderOverrides = builder.stsHeaderOverrides != null ? new HashMap<>(builder.stsHeaderOverrides) : Collections.emptyMap(); } + private AwsCredentialsOptions() { + this.stsRoleArn = null; + this.stsExternalId = null; + this.region = null; + this.stsHeaderOverrides = Collections.emptyMap(); + } + /** * Constructs a new {@link Builder} to build the credentials * options. 
@@ -37,6 +45,10 @@ public static Builder builder() { return new Builder(); } + public static AwsCredentialsOptions defaultOptions() { + return DEFAULT_OPTIONS; + } + public String getStsRoleArn() { return stsRoleArn; } diff --git a/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptionsTest.java b/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptionsTest.java index d4894d0f8a..5f4200069e 100644 --- a/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptionsTest.java +++ b/data-prepper-plugins/aws-plugin-api/src/test/java/org/opensearch/dataprepper/aws/api/AwsCredentialsOptionsTest.java @@ -10,12 +10,14 @@ import org.junit.jupiter.params.provider.ValueSource; import software.amazon.awssdk.regions.Region; +import java.util.Collections; import java.util.Map; import java.util.UUID; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; class AwsCredentialsOptionsTest { @@ -131,4 +133,21 @@ void with_StsHeaderOverrides() { assertThat(awsCredentialsOptions.getStsHeaderOverrides().size(), equalTo(stsHeaderOverrides.size())); assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); } + + @Test + void defaultOptions_returns_with_null_or_empty_values() { + AwsCredentialsOptions defaultOptions = AwsCredentialsOptions.defaultOptions(); + + assertThat(defaultOptions, notNullValue()); + assertThat(defaultOptions.getRegion(), nullValue()); + assertThat(defaultOptions.getStsRoleArn(), nullValue()); + assertThat(defaultOptions.getStsExternalId(), nullValue()); + assertThat(defaultOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); + } + + @Test + void 
defaultOptions_returns_same_instance_on_multiple_calls() { + assertThat(AwsCredentialsOptions.defaultOptions(), + sameInstance(AwsCredentialsOptions.defaultOptions())); + } } \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/build.gradle b/data-prepper-plugins/kafka-plugins/build.gradle index 07ee89d36d..1a9a79e9b5 100644 --- a/data-prepper-plugins/kafka-plugins/build.gradle +++ b/data-prepper-plugins/kafka-plugins/build.gradle @@ -11,6 +11,7 @@ dependencies { implementation project(':data-prepper-api') implementation project(':data-prepper-plugins:buffer-common') implementation project(':data-prepper-plugins:blocking-buffer') + implementation project(':data-prepper-plugins:aws-plugin-api') implementation 'org.apache.kafka:kafka-clients:3.4.0' implementation libs.avro.core implementation 'com.fasterxml.jackson.core:jackson-databind' @@ -21,10 +22,11 @@ dependencies { implementation 'io.confluent:kafka-avro-serializer:7.3.3' implementation 'io.confluent:kafka-schema-registry-client:7.3.3' implementation 'io.confluent:kafka-schema-registry:7.3.3:tests' + implementation 'software.amazon.awssdk:sts' + implementation 'software.amazon.awssdk:auth' + implementation 'software.amazon.awssdk:kafka' + implementation 'software.amazon.awssdk:kms' implementation 'software.amazon.msk:aws-msk-iam-auth:1.1.6' - implementation 'software.amazon.awssdk:sts:2.20.103' - implementation 'software.amazon.awssdk:auth:2.20.103' - implementation 'software.amazon.awssdk:kafka:2.20.103' implementation 'software.amazon.glue:schema-registry-serde:1.1.15' implementation 'com.amazonaws:aws-java-sdk-glue:1.12.506' implementation 'io.confluent:kafka-json-schema-serializer:7.4.0' diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferIT.java index ab6fe38199..a9fac35ef9 
100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferIT.java @@ -21,7 +21,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import java.security.NoSuchAlgorithmException; import java.time.Duration; +import java.util.Base64; import java.util.Collection; import java.util.Collections; import java.util.Map; @@ -45,8 +49,11 @@ public class KafkaBufferIT { private PluginFactory pluginFactory; @Mock private AcknowledgementSetManager acknowledgementSetManager; + @Mock + private TopicConfig topicConfig; private PluginMetrics pluginMetrics; + private String bootstrapServersCommaDelimited; @BeforeEach void setUp() { @@ -57,7 +64,6 @@ void setUp() { MessageFormat messageFormat = MessageFormat.JSON; String topicName = "buffer-" + RandomStringUtils.randomAlphabetic(5); - TopicConfig topicConfig = mock(TopicConfig.class); when(topicConfig.getName()).thenReturn(topicName); when(topicConfig.getGroupId()).thenReturn("buffergroup-" + RandomStringUtils.randomAlphabetic(6)); when(topicConfig.isCreate()).thenReturn(true); @@ -76,16 +82,16 @@ void setUp() { EncryptionConfig encryptionConfig = mock(EncryptionConfig.class); - String bootstrapServers = System.getProperty("tests.kafka.bootstrap_servers"); + bootstrapServersCommaDelimited = System.getProperty("tests.kafka.bootstrap_servers"); - LOG.info("Using Kafka bootstrap servers: {}", bootstrapServers); + LOG.info("Using Kafka bootstrap servers: {}", bootstrapServersCommaDelimited); - when(kafkaBufferConfig.getBootstrapServers()).thenReturn(Collections.singletonList(bootstrapServers)); + when(kafkaBufferConfig.getBootstrapServers()).thenReturn(Collections.singletonList(bootstrapServersCommaDelimited)); 
when(kafkaBufferConfig.getEncryptionConfig()).thenReturn(encryptionConfig); } private KafkaBuffer<Record<Event>> createObjectUnderTest() { - return new KafkaBuffer<>(pluginSetting, kafkaBufferConfig, pluginFactory, acknowledgementSetManager, pluginMetrics); + return new KafkaBuffer<>(pluginSetting, kafkaBufferConfig, pluginFactory, acknowledgementSetManager, pluginMetrics, null); } @Test @@ -110,8 +116,40 @@ void write_and_read() throws TimeoutException { assertThat(onlyResult.getData().toMap(), equalTo(record.getData().toMap())); } + @Test + void write_and_read_encrypted() throws TimeoutException, NoSuchAlgorithmException { + when(topicConfig.getEncryptionKey()).thenReturn(createAesKey()); + + KafkaBuffer<Record<Event>> objectUnderTest = createObjectUnderTest(); + + Record<Event> record = createRecord(); + objectUnderTest.write(record, 1_000); + + Map.Entry<Collection<Record<Event>>, CheckpointState> readResult = objectUnderTest.read(10_000); + + assertThat(readResult, notNullValue()); + assertThat(readResult.getKey(), notNullValue()); + assertThat(readResult.getKey().size(), equalTo(1)); + + Record<Event> onlyResult = readResult.getKey().stream().iterator().next(); + + assertThat(onlyResult, notNullValue()); + assertThat(onlyResult.getData(), notNullValue()); + // TODO: The metadata is not included. It needs to be included in the Buffer, though not in the Sink. This may be something we make configurable in the consumer/producer - whether to serialize the metadata or not.
+ //assertThat(onlyResult.getData().getMetadata(), equalTo(record.getData().getMetadata())); + assertThat(onlyResult.getData().toMap(), equalTo(record.getData().toMap())); + } + private Record<Event> createRecord() { Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); return new Record<>(event); } + + private static String createAesKey() throws NoSuchAlgorithmException { + KeyGenerator aesKeyGenerator = KeyGenerator.getInstance("AES"); + aesKeyGenerator.init(256); + SecretKey secretKey = aesKeyGenerator.generateKey(); + byte[] base64Bytes = Base64.getEncoder().encode(secretKey.getEncoded()); + return new String(base64Bytes); + } } diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkAvroTypeIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkAvroTypeIT.java index 6830168e25..3b75384581 100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkAvroTypeIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkAvroTypeIT.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -95,6 +96,9 @@ public class KafkaSinkAvroTypeIT { @Mock private ExpressionEvaluator evaluator; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + private AuthConfig authConfig; private AuthConfig.SaslAuthConfig saslAuthConfig; private PlainTextAuthConfig plainTextAuthConfig; @@ -102,7 +106,7 @@ public class KafkaSinkAvroTypeIT { public KafkaSink
createObjectUnderTest() { - return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext); + return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext, awsCredentialsSupplier); } @BeforeEach diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkJsonTypeIT.java b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkJsonTypeIT.java index 30dec4a628..3d825925ad 100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkJsonTypeIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkJsonTypeIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -85,6 +86,9 @@ public class KafkaSinkJsonTypeIT { @Mock private ExpressionEvaluator evaluator; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + private PlainTextAuthConfig plainTextAuthConfig; private AuthConfig.SaslAuthConfig saslAuthConfig; private AuthConfig authConfig; @@ -93,7 +97,7 @@ public class KafkaSinkJsonTypeIT { public KafkaSink createObjectUnderTest() { - return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext); + return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext, awsCredentialsSupplier); } @BeforeEach diff --git a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkPlainTextTypeIT.java 
b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkPlainTextTypeIT.java index d016df3033..cd7ac9526f 100644 --- a/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkPlainTextTypeIT.java +++ b/data-prepper-plugins/kafka-plugins/src/integrationTest/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkPlainTextTypeIT.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -83,6 +84,9 @@ public class KafkaSinkPlainTextTypeIT { @Mock private ExpressionEvaluator evaluator; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + private PlainTextAuthConfig plainTextAuthConfig; private AuthConfig.SaslAuthConfig saslAuthConfig; private AuthConfig authConfig; @@ -91,7 +95,7 @@ public class KafkaSinkPlainTextTypeIT { public KafkaSink createObjectUnderTest() { - return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext); + return new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactory, evaluator, sinkContext, awsCredentialsSupplier); } @BeforeEach diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java index 89d754aa84..0c3e89f90b 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java @@ -1,5 +1,6 @@ 
package org.opensearch.dataprepper.plugins.kafka.buffer; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.CheckpointState; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; @@ -42,12 +43,13 @@ public class KafkaBuffer<T extends Record<?>> extends AbstractBuffer<T> { @DataPrepperPluginConstructor public KafkaBuffer(final PluginSetting pluginSetting, final KafkaBufferConfig kafkaBufferConfig, final PluginFactory pluginFactory, - final AcknowledgementSetManager acknowledgementSetManager, final PluginMetrics pluginMetrics){ + final AcknowledgementSetManager acknowledgementSetManager, final PluginMetrics pluginMetrics, + final AwsCredentialsSupplier awsCredentialsSupplier) { super(pluginSetting); SerializationFactory serializationFactory = new SerializationFactory(); - final KafkaCustomProducerFactory kafkaCustomProducerFactory = new KafkaCustomProducerFactory(serializationFactory); + final KafkaCustomProducerFactory kafkaCustomProducerFactory = new KafkaCustomProducerFactory(serializationFactory, awsCredentialsSupplier); producer = kafkaCustomProducerFactory.createProducer(kafkaBufferConfig, pluginFactory, pluginSetting, null, null); - final KafkaCustomConsumerFactory kafkaCustomConsumerFactory = new KafkaCustomConsumerFactory(serializationFactory); + final KafkaCustomConsumerFactory kafkaCustomConsumerFactory = new KafkaCustomConsumerFactory(serializationFactory, awsCredentialsSupplier); innerBuffer = new BlockingBuffer<>(INNER_BUFFER_CAPACITY, INNER_BUFFER_BATCH_SIZE, pluginSetting.getPipelineName()); final List<KafkaCustomConsumer> consumers = kafkaCustomConsumerFactory.createConsumersForTopic(kafkaBufferConfig, kafkaBufferConfig.getTopic(), innerBuffer, pluginMetrics, acknowledgementSetManager, new AtomicBoolean(false)); diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfig.java
b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfig.java index 0d5ed43398..7dd5dd2feb 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfig.java @@ -2,10 +2,13 @@ import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; +import java.util.function.Supplier; + /** * An interface representing important data for how the data going to or coming from * Kafka should be represented. */ public interface KafkaDataConfig { MessageFormat getSerdeFormat(); + Supplier<byte[]> getEncryptionKeySupplier(); } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapter.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapter.java new file mode 100644 index 0000000000..e7d38bf8ba --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapter.java @@ -0,0 +1,32 @@ +package org.opensearch.dataprepper.plugins.kafka.common; + +import org.opensearch.dataprepper.plugins.kafka.common.key.KeyFactory; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; +import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; + +import java.util.function.Supplier; + +/** + * Adapts a {@link TopicConfig} to a {@link KafkaDataConfig}.
+ */ +public class KafkaDataConfigAdapter implements KafkaDataConfig { + private final KeyFactory keyFactory; + private final TopicConfig topicConfig; + + public KafkaDataConfigAdapter(KeyFactory keyFactory, TopicConfig topicConfig) { + this.keyFactory = keyFactory; + this.topicConfig = topicConfig; + } + + @Override + public MessageFormat getSerdeFormat() { + return topicConfig.getSerdeFormat(); + } + + @Override + public Supplier<byte[]> getEncryptionKeySupplier() { + if(topicConfig.getEncryptionKey() == null) + return null; + return keyFactory.getKeySupplier(topicConfig); + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfig.java index 951306d748..26a6239823 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfig.java @@ -2,7 +2,15 @@ import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; +import java.util.function.Supplier; + public class PlaintextKafkaDataConfig implements KafkaDataConfig { + private final KafkaDataConfig dataConfig; + + private PlaintextKafkaDataConfig(KafkaDataConfig dataConfig) { + this.dataConfig = dataConfig; + } + /** * Gets similar {@link KafkaDataConfig} as the given one, but uses {@link MessageFormat#PLAINTEXT} for * the serialization/deserialization format. @@ -11,11 +19,16 @@ public class PlaintextKafkaDataConfig { * @return A {@link KafkaDataConfig} with the PLAINTEXT message format.
*/ public static KafkaDataConfig plaintextDataConfig(final KafkaDataConfig dataConfig) { - return new PlaintextKafkaDataConfig(); + return new PlaintextKafkaDataConfig(dataConfig); } @Override public MessageFormat getSerdeFormat() { return MessageFormat.PLAINTEXT; } + + @Override + public Supplier<byte[]> getEncryptionKeySupplier() { + return dataConfig.getEncryptionKeySupplier(); + } } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContext.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContext.java new file mode 100644 index 0000000000..ac848cc564 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContext.java @@ -0,0 +1,45 @@ +package org.opensearch.dataprepper.plugins.kafka.common.aws; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.kafka.configuration.AwsConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaConnectionConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.function.Supplier; + +/** + * Standard implementation in Kafka plugins to get an {@link AwsCredentialsProvider}. + * The key interface this implements is {@link Supplier}, supplying an {@link AwsCredentialsProvider}. + * In general, you can provide the {@link Supplier} into class; just use this class when + * constructing.
+ */ +public class AwsContext implements Supplier<AwsCredentialsProvider> { + private final AwsConfig awsConfig; + private final AwsCredentialsSupplier awsCredentialsSupplier; + + public AwsContext(KafkaConnectionConfig connectionConfig, AwsCredentialsSupplier awsCredentialsSupplier) { + awsConfig = connectionConfig.getAwsConfig(); + this.awsCredentialsSupplier = awsCredentialsSupplier; + } + + @Override + public AwsCredentialsProvider get() { + final AwsCredentialsOptions credentialsOptions; + if(awsConfig != null) { + credentialsOptions = awsConfig.toCredentialsOptions(); + } else { + credentialsOptions = AwsCredentialsOptions.defaultOptions(); + } + + return awsCredentialsSupplier.getProvider(credentialsOptions); + } + + public Region getRegion() { + if(awsConfig != null && awsConfig.getRegion() != null) { + return Region.of(awsConfig.getRegion()); + } + return null; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/InnerKeyProvider.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/InnerKeyProvider.java new file mode 100644 index 0000000000..88fd91ec83 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/InnerKeyProvider.java @@ -0,0 +1,9 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; + +import java.util.function.Function; + +interface InnerKeyProvider extends Function<TopicConfig, byte[]> { + boolean supportsConfiguration(TopicConfig topicConfig); +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactory.java new file mode 100644 index 0000000000..96cfa82c76 --- /dev/null +++
b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactory.java @@ -0,0 +1,35 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; + +import java.util.List; +import java.util.function.Supplier; + +public class KeyFactory { + private final List<InnerKeyProvider> orderedKeyProviders; + + public KeyFactory(AwsContext awsContext) { + this(List.of( + new KmsKeyProvider(awsContext), + new UnencryptedKeyProvider() + )); + } + + KeyFactory(List<InnerKeyProvider> orderedKeyProviders) { + this.orderedKeyProviders = orderedKeyProviders; + } + + public Supplier<byte[]> getKeySupplier(TopicConfig topicConfig) { + if (topicConfig.getEncryptionKey() == null) + return null; + + InnerKeyProvider keyProvider = orderedKeyProviders + .stream() + .filter(innerKeyProvider -> innerKeyProvider.supportsConfiguration(topicConfig)) + .findFirst() + .orElseThrow(() -> new RuntimeException("Unable to find an inner key provider.
This is a programming error - UnencryptedKeyProvider should always work.")); + + return () -> keyProvider.apply(topicConfig); + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java new file mode 100644 index 0000000000..3abe62ffb2 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java @@ -0,0 +1,43 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.DecryptResponse; + +import java.util.Base64; + +class KmsKeyProvider implements InnerKeyProvider { + private final AwsContext awsContext; + + public KmsKeyProvider(AwsContext awsContext) { + this.awsContext = awsContext; + } + + @Override + public boolean supportsConfiguration(TopicConfig topicConfig) { + return topicConfig.getKmsKeyId() != null; + } + + @Override + public byte[] apply(TopicConfig topicConfig) { + String kmsKeyId = topicConfig.getKmsKeyId(); + + AwsCredentialsProvider awsCredentialsProvider = awsContext.get(); + + KmsClient kmsClient = KmsClient.builder() + .credentialsProvider(awsCredentialsProvider) + .region(awsContext.getRegion()) + .build(); + + byte[] decodedEncryptionKey = Base64.getDecoder().decode(topicConfig.getEncryptionKey()); + DecryptResponse decryptResponse = kmsClient.decrypt(builder -> builder + .keyId(kmsKeyId) + .ciphertextBlob(SdkBytes.fromByteArray(decodedEncryptionKey)) + ); + + return 
decryptResponse.plaintext().asByteArray(); + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProvider.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProvider.java new file mode 100644 index 0000000000..19309697fd --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProvider.java @@ -0,0 +1,17 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; + +import java.util.Base64; + +class UnencryptedKeyProvider implements InnerKeyProvider { + @Override + public boolean supportsConfiguration(TopicConfig topicConfig) { + return true; + } + + @Override + public byte[] apply(TopicConfig topicConfig) { + return Base64.getDecoder().decode(topicConfig.getEncryptionKey()); + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializer.java new file mode 100644 index 0000000000..c66cf58dbb --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializer.java @@ -0,0 +1,48 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Deserializer; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; + +/** + * Implementation of Kafka's {@link Deserializer} which decrypts the message 
+ * before deserializing it. + * + * @param <T> Type to be deserialized into + * @see EncryptionSerializer + */ +class DecryptionDeserializer<T> implements Deserializer<T> { + private final Deserializer<T> innerDeserializer; + private final Cipher cipher; + private final EncryptionContext encryptionContext; + + DecryptionDeserializer(Deserializer<T> innerDeserializer, EncryptionContext encryptionContext) throws InvalidKeyException, NoSuchPaddingException, NoSuchAlgorithmException { + this.innerDeserializer = innerDeserializer; + cipher = encryptionContext.createDecryptionCipher(); + this.encryptionContext = encryptionContext; + } + + @Override + public T deserialize(String topic, byte[] data) { + byte[] unencryptedBytes; + try { + unencryptedBytes = cipher.doFinal(data); + } catch (IllegalBlockSizeException | BadPaddingException e) { + throw new RuntimeException(e); + } + return innerDeserializer.deserialize(topic, unencryptedBytes); + } + + EncryptionContext getEncryptionContext() { + return encryptionContext; + } + + Deserializer<T> getInnerDeserializer() { + return innerDeserializer; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContext.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContext.java new file mode 100644 index 0000000000..a53aa2deb8 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContext.java @@ -0,0 +1,41 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import javax.crypto.Cipher; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.spec.SecretKeySpec; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; + +class EncryptionContext { + private static final String AES_ALGORITHM = "AES"; + private
final Key encryptionKey; + + EncryptionContext(Key encryptionKey) { + this.encryptionKey = encryptionKey; + } + + static EncryptionContext fromEncryptionKey(byte[] encryptionKey) { + SecretKeySpec secretKeySpec = new SecretKeySpec(encryptionKey, AES_ALGORITHM); + + return new EncryptionContext(secretKeySpec); + } + + Cipher createEncryptionCipher() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + return createCipher(Cipher.ENCRYPT_MODE); + } + + Cipher createDecryptionCipher() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + return createCipher(Cipher.DECRYPT_MODE); + } + + Key getEncryptionKey() { + return encryptionKey; + } + + private Cipher createCipher(int encryptMode) throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException { + Cipher cipher = Cipher.getInstance(AES_ALGORITHM); + cipher.init(encryptMode, encryptionKey); + return cipher; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactory.java new file mode 100644 index 0000000000..7d83045b51 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactory.java @@ -0,0 +1,37 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfig; + +import javax.crypto.NoSuchPaddingException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; + +class EncryptionSerializationFactory { + <T> Deserializer<T> getDeserializer(KafkaDataConfig dataConfig,
Deserializer<T> innerDeserializer) { + if(dataConfig.getEncryptionKeySupplier() == null) + return innerDeserializer; + + EncryptionContext encryptionContext = EncryptionContext.fromEncryptionKey(dataConfig.getEncryptionKeySupplier().get()); + + try { + return new DecryptionDeserializer<>(innerDeserializer, encryptionContext); + } catch (InvalidKeyException | NoSuchPaddingException | NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + <T> Serializer<T> getSerializer(KafkaDataConfig dataConfig, Serializer<T> innerSerializer) { + if(dataConfig.getEncryptionKeySupplier() == null) + return innerSerializer; + + EncryptionContext encryptionContext = EncryptionContext.fromEncryptionKey(dataConfig.getEncryptionKeySupplier().get()); + + try { + return new EncryptionSerializer<>(innerSerializer, encryptionContext); + } catch (InvalidKeyException | NoSuchPaddingException | NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializer.java new file mode 100644 index 0000000000..7710b78c8a --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializer.java @@ -0,0 +1,46 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Serializer; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; + +/** + * Implementation of Kafka's {@link Serializer} which encrypts data after serializing it.
+ * + * @param - Type to be serialized from + * @see DecryptionDeserializer + */ +class EncryptionSerializer implements Serializer { + private final Serializer innerSerializer; + private final Cipher cipher; + private final EncryptionContext encryptionContext; + + EncryptionSerializer(Serializer innerSerializer, EncryptionContext encryptionContext) throws InvalidKeyException, NoSuchPaddingException, NoSuchAlgorithmException { + this.innerSerializer = innerSerializer; + cipher = encryptionContext.createEncryptionCipher(); + this.encryptionContext = encryptionContext; + } + + @Override + public byte[] serialize(String topic, T data) { + byte[] unencryptedBytes = innerSerializer.serialize(topic, data); + try { + return cipher.doFinal(unencryptedBytes); + } catch (IllegalBlockSizeException | BadPaddingException e) { + throw new RuntimeException(e); + } + } + + public EncryptionContext getEncryptionContext() { + return encryptionContext; + } + + public Serializer getInnerSerializer() { + return innerSerializer; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactory.java index b871af5ee6..b7495bf4cc 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactory.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactory.java @@ -6,25 +6,30 @@ public class SerializationFactory { private final MessageFormatSerializationFactory messageFormatSerializationFactory; + private final EncryptionSerializationFactory encryptionSerializationFactory; public SerializationFactory() { - this(new MessageFormatSerializationFactory()); + this(new MessageFormatSerializationFactory(), new 
EncryptionSerializationFactory()); } /** * Testing constructor. * * @param messageFormatSerializationFactory The {@link MessageFormatSerializationFactory} + * @param encryptionSerializationFactory */ - SerializationFactory(MessageFormatSerializationFactory messageFormatSerializationFactory) { + SerializationFactory(MessageFormatSerializationFactory messageFormatSerializationFactory, EncryptionSerializationFactory encryptionSerializationFactory) { this.messageFormatSerializationFactory = messageFormatSerializationFactory; + this.encryptionSerializationFactory = encryptionSerializationFactory; } public Deserializer getDeserializer(KafkaDataConfig dataConfig) { - return messageFormatSerializationFactory.getDeserializer(dataConfig.getSerdeFormat()); + Deserializer deserializer = messageFormatSerializationFactory.getDeserializer(dataConfig.getSerdeFormat()); + return encryptionSerializationFactory.getDeserializer(dataConfig, deserializer); } public Serializer getSerializer(KafkaDataConfig dataConfig) { - return messageFormatSerializationFactory.getSerializer(dataConfig.getSerdeFormat()); + Serializer serializer = messageFormatSerializationFactory.getSerializer(dataConfig.getSerdeFormat()); + return encryptionSerializationFactory.getSerializer(dataConfig, serializer); } } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java index af055ab89a..6257ecb12d 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/AwsConfig.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.validation.constraints.Size; import jakarta.validation.Valid; +import 
org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; public class AwsConfig { @@ -60,4 +61,11 @@ public String getStsRoleArn() { public String getStsRoleSessionName() { return stsRoleSessionName; } + + public AwsCredentialsOptions toCredentialsOptions() { + return AwsCredentialsOptions.builder() + .withRegion(region) + .withStsRoleArn(stsRoleArn) + .build(); + } } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java index 659d4b490a..f0441d7114 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java @@ -10,16 +10,15 @@ import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; import org.opensearch.dataprepper.model.types.ByteCount; - -import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfig; import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; import java.time.Duration; + /** * * A helper class that helps to read consumer configuration values from * pipelines.yaml */ -public class TopicConfig implements KafkaDataConfig { +public class TopicConfig { static final boolean DEFAULT_AUTO_COMMIT = false; static final Duration DEFAULT_COMMIT_INTERVAL = Duration.ofSeconds(5); static final Duration DEFAULT_SESSION_TIMEOUT = Duration.ofSeconds(45); @@ -126,6 +125,12 @@ public class TopicConfig implements KafkaDataConfig { @JsonProperty("retention_period") private Long retentionPeriod=DEFAULT_RETENTION_PERIOD; + @JsonProperty("encryption_key") + private String encryptionKey; + + @JsonProperty("kms_key_id") + private String kmsKeyId; + public Long getRetentionPeriod() { return retentionPeriod; } @@ -142,6 
+147,14 @@ public MessageFormat getSerdeFormat() { return serdeFormat; } + public String getEncryptionKey() { + return encryptionKey; + } + + public String getKmsKeyId() { + return kmsKeyId; + } + public Boolean getAutoCommit() { return autoCommit; } @@ -269,5 +282,4 @@ public Integer getNumberOfPartions() { public Short getReplicationFactor() { return replicationFactor; } - } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerFactory.java index 5497dc35da..a3094011c1 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerFactory.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerFactory.java @@ -12,12 +12,17 @@ import org.apache.kafka.common.errors.BrokerNotAvailableException; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.StringDeserializer; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfig; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfigAdapter; import org.opensearch.dataprepper.plugins.kafka.common.PlaintextKafkaDataConfig; +import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.common.key.KeyFactory; import 
org.opensearch.dataprepper.plugins.kafka.common.serialization.SerializationFactory; import org.opensearch.dataprepper.plugins.kafka.configuration.AuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaConsumerConfig; @@ -48,11 +53,12 @@ public class KafkaCustomConsumerFactory { private final StringDeserializer stringDeserializer = new StringDeserializer(); private final SerializationFactory serializationFactory; + private final AwsCredentialsSupplier awsCredentialsSupplier; private String schemaType = MessageFormat.PLAINTEXT.toString(); - public KafkaCustomConsumerFactory(SerializationFactory serializationFactory) { - + public KafkaCustomConsumerFactory(SerializationFactory serializationFactory, AwsCredentialsSupplier awsCredentialsSupplier) { this.serializationFactory = serializationFactory; + this.awsCredentialsSupplier = awsCredentialsSupplier; } public List createConsumersForTopic(final KafkaConsumerConfig kafkaConsumerConfig, final TopicConfig topic, @@ -67,16 +73,20 @@ public List createConsumersForTopic(final KafkaConsumerConf final List consumers = new ArrayList<>(); + AwsContext awsContext = new AwsContext(kafkaConsumerConfig, awsCredentialsSupplier); + KeyFactory keyFactory = new KeyFactory(awsContext); + try { final int numWorkers = topic.getWorkers(); IntStream.range(0, numWorkers).forEach(index -> { - Deserializer keyDeserializer = (Deserializer) serializationFactory.getDeserializer(PlaintextKafkaDataConfig.plaintextDataConfig(topic)); + KafkaDataConfig dataConfig = new KafkaDataConfigAdapter(keyFactory, topic); + Deserializer keyDeserializer = (Deserializer) serializationFactory.getDeserializer(PlaintextKafkaDataConfig.plaintextDataConfig(dataConfig)); Deserializer valueDeserializer = null; if(schema == MessageFormat.PLAINTEXT) { valueDeserializer = KafkaSecurityConfigurer.getGlueSerializer(kafkaConsumerConfig); } if(valueDeserializer == null) { - valueDeserializer = (Deserializer) serializationFactory.getDeserializer(topic); + 
valueDeserializer = (Deserializer) serializationFactory.getDeserializer(dataConfig); } final KafkaConsumer kafkaConsumer = new KafkaConsumer<>(consumerProperties, keyDeserializer, valueDeserializer); diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/producer/KafkaCustomProducerFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/producer/KafkaCustomProducerFactory.java index 27593fddf0..97c1794658 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/producer/KafkaCustomProducerFactory.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/producer/KafkaCustomProducerFactory.java @@ -2,11 +2,16 @@ import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.common.serialization.Serializer; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfig; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfigAdapter; import org.opensearch.dataprepper.plugins.kafka.common.PlaintextKafkaDataConfig; +import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.common.key.KeyFactory; import org.opensearch.dataprepper.plugins.kafka.common.serialization.SerializationFactory; import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaProducerConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.SchemaConfig; @@ -27,21 +32,25 @@ public class KafkaCustomProducerFactory { private static final Logger LOG = 
LoggerFactory.getLogger(KafkaCustomConsumerFactory.class); private final SerializationFactory serializationFactory; + private final AwsCredentialsSupplier awsCredentialsSupplier; - public KafkaCustomProducerFactory(final SerializationFactory serializationFactory) { - + public KafkaCustomProducerFactory(final SerializationFactory serializationFactory, AwsCredentialsSupplier awsCredentialsSupplier) { this.serializationFactory = serializationFactory; + this.awsCredentialsSupplier = awsCredentialsSupplier; } public KafkaCustomProducer createProducer(final KafkaProducerConfig kafkaProducerConfig, final PluginFactory pluginFactory, final PluginSetting pluginSetting, final ExpressionEvaluator expressionEvaluator, final SinkContext sinkContext) { + AwsContext awsContext = new AwsContext(kafkaProducerConfig, awsCredentialsSupplier); + KeyFactory keyFactory = new KeyFactory(awsContext); prepareTopicAndSchema(kafkaProducerConfig); Properties properties = SinkPropertyConfigurer.getProducerProperties(kafkaProducerConfig); KafkaSecurityConfigurer.setAuthProperties(properties, kafkaProducerConfig, LOG); properties = Objects.requireNonNull(properties); TopicConfig topic = kafkaProducerConfig.getTopic(); - Serializer keyDeserializer = (Serializer) serializationFactory.getSerializer(PlaintextKafkaDataConfig.plaintextDataConfig(topic)); - Serializer valueSerializer = (Serializer) serializationFactory.getSerializer(topic); + KafkaDataConfig dataConfig = new KafkaDataConfigAdapter(keyFactory, topic); + Serializer keyDeserializer = (Serializer) serializationFactory.getSerializer(PlaintextKafkaDataConfig.plaintextDataConfig(dataConfig)); + Serializer valueSerializer = (Serializer) serializationFactory.getSerializer(dataConfig); final KafkaProducer producer = new KafkaProducer<>(properties, keyDeserializer, valueSerializer); final DLQSink dlqSink = new DLQSink(pluginFactory, kafkaProducerConfig, pluginSetting); return new KafkaCustomProducer(producer, diff --git 
a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSink.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSink.java index 8d855693bf..ea273c370d 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSink.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSink.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.plugins.kafka.sink; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; @@ -69,7 +70,8 @@ public class KafkaSink extends AbstractSink> { @DataPrepperPluginConstructor public KafkaSink(final PluginSetting pluginSetting, final KafkaSinkConfig kafkaSinkConfig, final PluginFactory pluginFactory, - final ExpressionEvaluator expressionEvaluator, final SinkContext sinkContext) { + final ExpressionEvaluator expressionEvaluator, final SinkContext sinkContext, + AwsCredentialsSupplier awsCredentialsSupplier) { super(pluginSetting); this.pluginSetting = pluginSetting; this.kafkaSinkConfig = kafkaSinkConfig; @@ -79,7 +81,7 @@ public KafkaSink(final PluginSetting pluginSetting, final KafkaSinkConfig kafkaS this.sinkContext = sinkContext; SerializationFactory serializationFactory = new SerializationFactory(); - kafkaCustomProducerFactory = new KafkaCustomProducerFactory(serializationFactory); + kafkaCustomProducerFactory = new KafkaCustomProducerFactory(serializationFactory, awsCredentialsSupplier); } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java 
b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java index 37a9c6e8bf..ba923c77f3 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java @@ -9,6 +9,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.CheckpointState; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; @@ -101,6 +102,9 @@ class KafkaBufferTest { @Mock BlockingBuffer> blockingBuffer; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + public KafkaBuffer> createObjectUnderTest() { try ( @@ -116,7 +120,7 @@ public KafkaBuffer> createObjectUnderTest() { })) { executorsMockedStatic.when(() -> Executors.newFixedThreadPool(anyInt())).thenReturn(executorService); - return new KafkaBuffer>(pluginSetting, bufferConfig, pluginFactory, acknowledgementSetManager, pluginMetrics); + return new KafkaBuffer>(pluginSetting, bufferConfig, pluginFactory, acknowledgementSetManager, pluginMetrics, awsCredentialsSupplier); } } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapterTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapterTest.java new file mode 100644 index 0000000000..7e2aa9ef49 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaDataConfigAdapterTest.java @@ -0,0 +1,57 @@ +package 
org.opensearch.dataprepper.plugins.kafka.common; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.common.key.KeyFactory; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; +import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; + +import java.util.UUID; +import java.util.function.Supplier; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KafkaDataConfigAdapterTest { + @Mock + private KeyFactory keyFactory; + @Mock + private TopicConfig topicConfig; + + private KafkaDataConfigAdapter createObjectUnderTest() { + return new KafkaDataConfigAdapter(keyFactory, topicConfig); + } + + @ParameterizedTest + @EnumSource(MessageFormat.class) + void getSerdeFormat_returns_TopicConfig_getSerdeFormat(MessageFormat serdeFormat) { + when(topicConfig.getSerdeFormat()).thenReturn(serdeFormat); + assertThat(createObjectUnderTest().getSerdeFormat(), + equalTo(serdeFormat)); + } + + @Test + void getEncryptionKeySupplier_returns_null_if_encryptionKey_is_null() { + assertThat(createObjectUnderTest().getEncryptionKeySupplier(), + nullValue()); + } + + @Test + void getEncryptionKeySupplier_returns_keyFactory_getKeySupplier_if_encryptionKey_is_present() { + String encryptionKey = UUID.randomUUID().toString(); + when(topicConfig.getEncryptionKey()).thenReturn(encryptionKey); + Supplier keySupplier = mock(Supplier.class); + when(keyFactory.getKeySupplier(topicConfig)).thenReturn(keySupplier); + + 
assertThat(createObjectUnderTest().getEncryptionKeySupplier(), + equalTo(keySupplier)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfigTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfigTest.java index 5c7c6059bc..ef8e3bd4c6 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfigTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/PlaintextKafkaDataConfigTest.java @@ -1,9 +1,12 @@ package org.opensearch.dataprepper.plugins.kafka.common; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; +import java.util.function.Supplier; + import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.MatcherAssert.assertThat; @@ -23,4 +26,16 @@ void plaintextDataConfig_returns_KafkaDataConfig_with_getSerdeFormat_returning_P assertThat(outputDataConfig.getSerdeFormat(), equalTo(MessageFormat.PLAINTEXT)); } + + @Test + void plaintextDataConfig_returns_KafkaDataConfig_with_getEncryptionKeySupplier_returning_from_inner_dataConfig() { + KafkaDataConfig inputDataConfig = mock(KafkaDataConfig.class); + Supplier keySupplier = mock(Supplier.class); + when(inputDataConfig.getEncryptionKeySupplier()).thenReturn(keySupplier); + + KafkaDataConfig outputDataConfig = PlaintextKafkaDataConfig.plaintextDataConfig(inputDataConfig); + + assertThat(outputDataConfig, notNullValue()); + assertThat(outputDataConfig.getEncryptionKeySupplier(), equalTo(keySupplier)); + } } \ No newline at end of file diff --git 
a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContextTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContextTest.java new file mode 100644 index 0000000000..0c35434be0 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/aws/AwsContextTest.java @@ -0,0 +1,86 @@ +package org.opensearch.dataprepper.plugins.kafka.common.aws; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.kafka.configuration.AwsConfig; +import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaConnectionConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class AwsContextTest { + @Mock + private KafkaConnectionConfig connectionConfig; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + private AwsContext createObjectUnderTest() { + return new AwsContext(connectionConfig, awsCredentialsSupplier); + } + + @Test + void get_uses_defaultOptions_when_awsConfig_is_null() { + AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); + 
when(awsCredentialsSupplier.getProvider(AwsCredentialsOptions.defaultOptions())) + .thenReturn(awsCredentialsProvider); + + assertThat(createObjectUnderTest().get(), equalTo(awsCredentialsProvider)); + } + + @Nested + class WithAwsConfig { + + @Mock + private AwsConfig awsConfig; + @Mock + private AwsCredentialsOptions awsCredentialsOptions; + + @BeforeEach + void setUp() { + when(connectionConfig.getAwsConfig()).thenReturn(awsConfig); + } + + @Test + void get_uses_uses_awsConfig_to() { + when(awsConfig.toCredentialsOptions()).thenReturn(awsCredentialsOptions); + AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(awsCredentialsOptions)) + .thenReturn(awsCredentialsProvider); + + assertThat(createObjectUnderTest().get(), equalTo(awsCredentialsProvider)); + } + + @Test + void getRegion_returns_null_if_AwsConfig_region_is_null() { + assertThat(createObjectUnderTest().getRegion(), nullValue()); + } + + @ParameterizedTest + @ValueSource(strings = {"us-east-2", "eu-west-3", "ap-northeast-1"}) + void getRegion_returns_Region_of(String regionString) { + when(awsConfig.getRegion()).thenReturn(regionString); + + Region region = Region.of(regionString); + assertThat(createObjectUnderTest().getRegion(), equalTo(region)); + } + } + + @Test + void getRegion_returns_null_if_no_AwsConfig() { + assertThat(createObjectUnderTest().getRegion(), nullValue()); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactoryTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactoryTest.java new file mode 100644 index 0000000000..d4d2f506b5 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KeyFactoryTest.java @@ -0,0 +1,65 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + 
+import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; + +import java.util.List; +import java.util.UUID; +import java.util.function.Supplier; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KeyFactoryTest { + private List innerKeyProviders; + @Mock + private TopicConfig topicConfig; + + @BeforeEach + void setUp() { + innerKeyProviders = List.of( + mock(InnerKeyProvider.class), + mock(InnerKeyProvider.class), + mock(InnerKeyProvider.class) + ); + } + + private KeyFactory createObjectUnderTest() { + return new KeyFactory(innerKeyProviders); + } + + @Test + void getKeySupplier_returns_null_if_encryptionKey_is_null() { + assertThat(createObjectUnderTest().getKeySupplier(topicConfig), + nullValue()); + } + + @Test + void getKeySupplier_returns_using_first_InnerKeyFactory_that_supports_the_TopicConfig() { + when(topicConfig.getEncryptionKey()).thenReturn(UUID.randomUUID().toString()); + + InnerKeyProvider middleKeyProvider = innerKeyProviders.get(1); + when(middleKeyProvider.supportsConfiguration(topicConfig)).thenReturn(true); + + Supplier keySupplier = createObjectUnderTest().getKeySupplier(topicConfig); + + assertThat(keySupplier, notNullValue()); + + byte[] expectedBytes = UUID.randomUUID().toString().getBytes(); + when(middleKeyProvider.apply(topicConfig)).thenReturn(expectedBytes); + assertThat(keySupplier.get(), equalTo(expectedBytes)); + + InnerKeyProvider lastKeyProvider = 
innerKeyProviders.get(2); + verifyNoInteractions(lastKeyProvider); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java new file mode 100644 index 0000000000..4fd241f65f --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java @@ -0,0 +1,132 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.KmsClientBuilder; +import software.amazon.awssdk.services.kms.model.DecryptRequest; +import software.amazon.awssdk.services.kms.model.DecryptResponse; + +import java.util.Base64; +import java.util.UUID; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + 
+@ExtendWith(MockitoExtension.class) +class KmsKeyProviderTest { + @Mock + private AwsContext awsContext; + @Mock + private AwsCredentialsProvider awsCredentialsProvider; + @Mock + private TopicConfig topicConfig; + + private KmsKeyProvider createObjectUnderTest() { + return new KmsKeyProvider(awsContext); + } + + @Test + void supportsConfiguration_returns_false_if_kmsKeyId_is_null() { + assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(false)); + } + + @Test + void supportsConfiguration_returns_true_if_kmsKeyId_is_present() { + when(topicConfig.getKmsKeyId()).thenReturn(UUID.randomUUID().toString()); + assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(true)); + } + + @Nested + class WithKmsKey { + private String kmsKeyId; + + private Region region; + + private String encryptionKey; + private KmsClientBuilder kmsClientBuilder; + private KmsClient kmsClient; + private byte[] decryptedBytes; + + @BeforeEach + void setUp() { + kmsKeyId = UUID.randomUUID().toString(); + region = mock(Region.class); + + when(awsContext.get()).thenReturn(awsCredentialsProvider); + when(awsContext.getRegion()).thenReturn(region); + + encryptionKey = UUID.randomUUID().toString(); + String base64EncryptionKey = Base64.getEncoder().encodeToString(encryptionKey.getBytes()); + when(topicConfig.getEncryptionKey()).thenReturn(base64EncryptionKey); + when(topicConfig.getKmsKeyId()).thenReturn(kmsKeyId); + + kmsClient = mock(KmsClient.class); + DecryptResponse decryptResponse = mock(DecryptResponse.class); + decryptedBytes = UUID.randomUUID().toString().getBytes(); + when(decryptResponse.plaintext()).thenReturn(SdkBytes.fromByteArray(decryptedBytes)); + when(kmsClient.decrypt(any(Consumer.class))).thenReturn(decryptResponse); + + kmsClientBuilder = mock(KmsClientBuilder.class); + when(kmsClientBuilder.credentialsProvider(awsCredentialsProvider)).thenReturn(kmsClientBuilder); + 
when(kmsClientBuilder.region(region)).thenReturn(kmsClientBuilder); + when(kmsClientBuilder.build()).thenReturn(kmsClient); + } + + @Test + void apply_returns_plaintext_from_decrypt_request() { + byte[] actualBytes; + KmsKeyProvider objectUnderTest = createObjectUnderTest(); + try (MockedStatic kmsClientMockedStatic = mockStatic(KmsClient.class)) { + kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder); + actualBytes = objectUnderTest.apply(topicConfig); + } + + assertThat(actualBytes, equalTo(decryptedBytes)); + } + + @Test + void apply_calls_decrypt_with_correct_values() { + KmsKeyProvider objectUnderTest = createObjectUnderTest(); + try (MockedStatic kmsClientMockedStatic = mockStatic(KmsClient.class)) { + kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder); + objectUnderTest.apply(topicConfig); + } + + ArgumentCaptor> consumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(kmsClient).decrypt(consumerArgumentCaptor.capture()); + + Consumer actualConsumer = consumerArgumentCaptor.getValue(); + + DecryptRequest.Builder builder = mock(DecryptRequest.Builder.class); + when(builder.keyId(anyString())).thenReturn(builder); + when(builder.ciphertextBlob(any())).thenReturn(builder); + actualConsumer.accept(builder); + + verify(builder).keyId(kmsKeyId); + ArgumentCaptor actualBytesCaptor = ArgumentCaptor.forClass(SdkBytes.class); + verify(builder).ciphertextBlob(actualBytesCaptor.capture()); + + SdkBytes actualSdkBytes = actualBytesCaptor.getValue(); + assertThat(actualSdkBytes.asByteArray(), equalTo(encryptionKey.getBytes())); + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProviderTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProviderTest.java new file mode 100644 index 0000000000..4ad3f7b779 
--- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/UnencryptedKeyProviderTest.java @@ -0,0 +1,42 @@ +package org.opensearch.dataprepper.plugins.kafka.common.key; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; + +import java.util.Base64; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class UnencryptedKeyProviderTest { + @Mock + private TopicConfig topicConfig; + + private UnencryptedKeyProvider createObjectUnderTest() { + return new UnencryptedKeyProvider(); + } + + @Test + void supportsConfiguration_returns_true() { + assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(true)); + } + + @Test + void apply_returns_base64_decoded_encryptionKey() { + String unencodedInput = UUID.randomUUID().toString(); + String base64InputString = Base64.getEncoder().encodeToString(unencodedInput.getBytes()); + + when(topicConfig.getEncryptionKey()).thenReturn(base64InputString); + + byte[] actualBytes = createObjectUnderTest().apply(topicConfig); + assertThat(actualBytes, notNullValue()); + assertThat(actualBytes, equalTo(unencodedInput.getBytes())); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializerTest.java new file mode 100644 index 0000000000..8a33332af8 --- /dev/null +++ 
b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/DecryptionDeserializerTest.java @@ -0,0 +1,57 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Deserializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class DecryptionDeserializerTest { + @Mock + private Deserializer innerDeserializer; + @Mock + private EncryptionContext encryptionContext; + @Mock + private Cipher cipher; + private String topicName; + + @BeforeEach + void setUp() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + topicName = UUID.randomUUID().toString(); + + when(encryptionContext.createDecryptionCipher()).thenReturn(cipher); + } + + private DecryptionDeserializer createObjectUnderTest() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + return new DecryptionDeserializer<>(innerDeserializer, encryptionContext); + } + + @Test + void deserialize_calls_innerDeserializer_on_cipher_doFinal() throws IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + byte[] decryptedData = UUID.randomUUID().toString().getBytes(); + Object deserializedContent = UUID.randomUUID().toString(); + + when(innerDeserializer.deserialize(topicName, 
decryptedData)) + .thenReturn(deserializedContent); + + byte[] inputData = UUID.randomUUID().toString().getBytes(); + when(cipher.doFinal(inputData)).thenReturn(decryptedData); + + assertThat(createObjectUnderTest().deserialize(topicName, inputData), + equalTo(deserializedContent)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContextTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContextTest.java new file mode 100644 index 0000000000..0f8572e6b9 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionContextTest.java @@ -0,0 +1,65 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.KeyGenerator; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +class EncryptionContextTest { + private SecretKey aesKey; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException { + aesKey = createAesKey(); + } + + private EncryptionContext createObjectUnderTest() { + return new EncryptionContext(aesKey); + } + + @Test + void encryption_and_decryption_are_symmetric() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException { + EncryptionContext objectUnderTest = 
createObjectUnderTest(); + + Cipher encryptionCipher = objectUnderTest.createEncryptionCipher(); + + Cipher decryptionCipher = objectUnderTest.createDecryptionCipher(); + + byte[] inputBytes = UUID.randomUUID().toString().getBytes(); + + byte[] encryptedBytes = encryptionCipher.doFinal(inputBytes); + + byte[] decryptedBytes = decryptionCipher.doFinal(encryptedBytes); + + assertThat(decryptedBytes, equalTo(inputBytes)); + } + + @Test + void fromEncryptionKey_includes_correct_Key() { + byte[] key = aesKey.getEncoded(); + + EncryptionContext encryptionContext = EncryptionContext.fromEncryptionKey(key); + + assertThat(encryptionContext.getEncryptionKey(), notNullValue()); + assertThat(encryptionContext.getEncryptionKey().getEncoded(), equalTo(aesKey.getEncoded())); + assertThat(encryptionContext.getEncryptionKey().getAlgorithm(), equalTo("AES")); + } + + private static SecretKey createAesKey() throws NoSuchAlgorithmException { + KeyGenerator aesKeyGenerator = KeyGenerator.getInstance("AES"); + aesKeyGenerator.init(256); + return aesKeyGenerator.generateKey(); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactoryTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactoryTest.java new file mode 100644 index 0000000000..2b98a7087c --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializationFactoryTest.java @@ -0,0 +1,103 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import 
org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaDataConfig; + +import java.security.Key; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class EncryptionSerializationFactoryTest { + @Mock + private KafkaDataConfig dataConfig; + + private EncryptionSerializationFactory createObjectUnderTest() { + return new EncryptionSerializationFactory(); + } + + @Test + void getDeserializer_returns_innerDeserializer_if_encryptionKey_is_null() { + Deserializer innerDeserializer = mock(Deserializer.class); + + assertThat(createObjectUnderTest().getDeserializer(dataConfig, innerDeserializer), + equalTo(innerDeserializer)); + } + + @Test + void getDeserializer_returns_DecryptionDeserializer_if_encryptionKey_is_present() { + Deserializer innerDeserializer = mock(Deserializer.class); + + byte[] encryptionKey = UUID.randomUUID().toString().getBytes(); + when(dataConfig.getEncryptionKeySupplier()).thenReturn(() -> encryptionKey); + EncryptionContext encryptionContext = mock(EncryptionContext.class); + Key key = mock(Key.class); + when(encryptionContext.getEncryptionKey()).thenReturn(key); + + EncryptionSerializationFactory objectUnderTest = createObjectUnderTest(); + + Deserializer actualDeserializer; + + try(MockedStatic encryptionContextMockedStatic = mockStatic(EncryptionContext.class)) { + encryptionContextMockedStatic.when(() -> EncryptionContext.fromEncryptionKey(encryptionKey)) + .thenReturn(encryptionContext); + actualDeserializer = objectUnderTest.getDeserializer(dataConfig, innerDeserializer); + } + + assertThat(actualDeserializer, 
instanceOf(DecryptionDeserializer.class)); + DecryptionDeserializer decryptionDeserializer = (DecryptionDeserializer) actualDeserializer; + + assertThat(decryptionDeserializer.getEncryptionContext(), notNullValue()); + assertThat(decryptionDeserializer.getEncryptionContext().getEncryptionKey(), equalTo(key)); + assertThat(decryptionDeserializer.getInnerDeserializer(), equalTo(innerDeserializer)); + } + + @Test + void getSerializer_returns_innerSerializer_if_encryptionKey_is_null() { + Serializer innerSerializer = mock(Serializer.class); + + assertThat(createObjectUnderTest().getSerializer(dataConfig, innerSerializer), + equalTo(innerSerializer)); + } + + @Test + void getSerializer_returns_EncryptionSerializer_if_encryptionKey_is_present() { + Serializer innerSerializer = mock(Serializer.class); + + byte[] encryptionKey = UUID.randomUUID().toString().getBytes(); + when(dataConfig.getEncryptionKeySupplier()).thenReturn(() -> encryptionKey); + EncryptionContext encryptionContext = mock(EncryptionContext.class); + Key key = mock(Key.class); + when(encryptionContext.getEncryptionKey()).thenReturn(key); + + EncryptionSerializationFactory objectUnderTest = createObjectUnderTest(); + + Serializer actualSerializer; + + try(MockedStatic encryptionContextMockedStatic = mockStatic(EncryptionContext.class)) { + encryptionContextMockedStatic.when(() -> EncryptionContext.fromEncryptionKey(encryptionKey)) + .thenReturn(encryptionContext); + actualSerializer = objectUnderTest.getSerializer(dataConfig, innerSerializer); + } + + assertThat(actualSerializer, instanceOf(EncryptionSerializer.class)); + EncryptionSerializer encryptionSerializer = (EncryptionSerializer) actualSerializer; + + assertThat(encryptionSerializer.getEncryptionContext(), notNullValue()); + assertThat(encryptionSerializer.getEncryptionContext().getEncryptionKey(), equalTo(key)); + assertThat(encryptionSerializer.getInnerSerializer(), equalTo(innerSerializer)); + } +} \ No newline at end of file diff --git 
a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializerTest.java new file mode 100644 index 0000000000..941b27e1ea --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/EncryptionSerializerTest.java @@ -0,0 +1,57 @@ +package org.opensearch.dataprepper.plugins.kafka.common.serialization; + +import org.apache.kafka.common.serialization.Serializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class EncryptionSerializerTest { + @Mock + private Serializer innerSerializer; + @Mock + private EncryptionContext encryptionContext; + @Mock + private Cipher cipher; + private String topicName; + + @BeforeEach + void setUp() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + topicName = UUID.randomUUID().toString(); + + when(encryptionContext.createEncryptionCipher()).thenReturn(cipher); + } + + private EncryptionSerializer createObjectUnderTest() throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + return new EncryptionSerializer<>(innerSerializer, encryptionContext); + } + + @Test + void 
serialize_performs_cipher_encryption_on_serialized_data() throws IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException { + byte[] plaintextData = UUID.randomUUID().toString().getBytes(); + Object input = UUID.randomUUID().toString(); + + when(innerSerializer.serialize(topicName, input)) + .thenReturn(plaintextData); + + byte[] encryptedData = UUID.randomUUID().toString().getBytes(); + when(cipher.doFinal(plaintextData)).thenReturn(encryptedData); + + assertThat(createObjectUnderTest().serialize(topicName, input), + equalTo(encryptedData)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactoryTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactoryTest.java index a2ea18030b..6a691efa38 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactoryTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/serialization/SerializationFactoryTest.java @@ -19,21 +19,26 @@ class SerializationFactoryTest { @Mock private MessageFormatSerializationFactory messageFormatSerializationFactory; + @Mock + private EncryptionSerializationFactory encryptionSerializationFactory; @Mock private KafkaDataConfig kafkaDataConfig; private SerializationFactory createObjectUnderTest() { - return new SerializationFactory(messageFormatSerializationFactory); + return new SerializationFactory(messageFormatSerializationFactory, encryptionSerializationFactory); } @ParameterizedTest @EnumSource(MessageFormat.class) - void getDeserializer_returns_result_of_getDeserializer(MessageFormat serdeFormat) { + void 
getDeserializer_returns_result_of_EncryptionSerializationFactory_from_MessageFormatSerializationFactory(MessageFormat serdeFormat) { + Deserializer innerDeserializer = mock(Deserializer.class); Deserializer deserializer = mock(Deserializer.class); when(kafkaDataConfig.getSerdeFormat()).thenReturn(serdeFormat); when(messageFormatSerializationFactory.getDeserializer(serdeFormat)) + .thenReturn(innerDeserializer); + when(encryptionSerializationFactory.getDeserializer(kafkaDataConfig, innerDeserializer)) .thenReturn(deserializer); assertThat(createObjectUnderTest().getDeserializer(kafkaDataConfig), @@ -42,11 +47,14 @@ void getDeserializer_returns_result_of_getDeserializer(MessageFormat serdeFormat @ParameterizedTest @EnumSource(MessageFormat.class) - void getSerializer_returns_result_of_getSerializer(MessageFormat serdeFormat) { + void getSerializer_returns_result_of_EncryptionSerializationFactory_from_MessageFormatSerializationFactory(MessageFormat serdeFormat) { + Serializer innerSerializer = mock(Serializer.class); Serializer serializer = mock(Serializer.class); when(kafkaDataConfig.getSerdeFormat()).thenReturn(serdeFormat); when(messageFormatSerializationFactory.getSerializer(serdeFormat)) + .thenReturn(innerSerializer); + when(encryptionSerializationFactory.getSerializer(kafkaDataConfig, innerSerializer)) .thenReturn(serializer); assertThat(createObjectUnderTest().getSerializer(kafkaDataConfig), diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkTest.java index 06690f081a..8758b13881 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/sink/KafkaSinkTest.java @@ -18,6 +18,7 @@ import 
org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -91,6 +92,9 @@ public class KafkaSinkTest { @Mock SinkContext sinkContext; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @BeforeEach void setUp() throws Exception { @@ -112,7 +116,7 @@ void setUp() throws Exception { when(pluginSetting.getPipelineName()).thenReturn("Kafka-sink"); event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); when(sinkContext.getTagsTargetKey()).thenReturn("tag"); - kafkaSink = new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactoryMock, mock(ExpressionEvaluator.class), sinkContext); + kafkaSink = new KafkaSink(pluginSetting, kafkaSinkConfig, pluginFactoryMock, mock(ExpressionEvaluator.class), sinkContext, awsCredentialsSupplier); spySink = spy(kafkaSink); executorsMockedStatic = mockStatic(Executors.class); props = new Properties();