From 38fe2afb91670d5d8ba5a278dd96de0274ad8f58 Mon Sep 17 00:00:00 2001 From: Ivan Tse <115105835+ivan-tse@users.noreply.github.com> Date: Wed, 14 Aug 2024 09:36:12 -0700 Subject: [PATCH] PersonalizeSink: add client and configuration classes (#4803) PersonalizeSink: add client and configuration classes Signed-off-by: Ivan Tse --- .../personalize-sink/build.gradle | 48 ++++ .../sink/personalize/ClientFactory.java | 58 +++++ .../sink/personalize/PersonalizeSink.java | 80 +++++++ .../personalize/PersonalizeSinkService.java | 68 ++++++ .../AwsAuthenticationOptions.java | 80 +++++++ .../PersonalizeAdvancedValidation.java | 4 + .../PersonalizeSinkConfiguration.java | 137 ++++++++++++ .../dataset/DatasetTypeOptions.java | 33 +++ .../sink/personalize/ClientFactoryTest.java | 135 ++++++++++++ .../sink/personalize/PersonalizeSinkTest.java | 85 ++++++++ .../AwsAuthenticationOptionsTest.java | 129 +++++++++++ .../PersonalizeSinkConfigurationTest.java | 205 ++++++++++++++++++ .../dataset/DatasetTypeOptionsTest.java | 38 ++++ settings.gradle | 1 + 14 files changed, 1101 insertions(+) create mode 100644 data-prepper-plugins/personalize-sink/build.gradle create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactory.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSink.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkService.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptions.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeAdvancedValidation.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfiguration.java create mode 100644 data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptions.java create mode 100644 data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactoryTest.java create mode 100644 data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkTest.java create mode 100644 data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptionsTest.java create mode 100644 data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfigurationTest.java create mode 100644 data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptionsTest.java diff --git a/data-prepper-plugins/personalize-sink/build.gradle b/data-prepper-plugins/personalize-sink/build.gradle new file mode 100644 index 0000000000..bf408a04b8 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/build.gradle @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +dependencies { + implementation project(':data-prepper-api') + implementation project(path: ':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:aws-plugin-api') + implementation 'io.micrometer:micrometer-core' + implementation 'com.fasterxml.jackson.core:jackson-core' + implementation 'com.fasterxml.jackson.core:jackson-databind' + implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final' + implementation 'software.amazon.awssdk:personalizeevents' + implementation 'software.amazon.awssdk:sts' + implementation 'software.amazon.awssdk:arns' + testImplementation project(':data-prepper-test-common') + testImplementation testLibs.slf4j.simple +} + +sourceSets { + integrationTest { + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/integrationTest/java') + } + resources.srcDir file('src/integrationTest/resources') + } +} + +configurations { + integrationTestImplementation.extendsFrom testImplementation + integrationTestRuntime.extendsFrom testRuntime +} + +task integrationTest(type: Test) { + group = 'verification' + testClassesDirs = sourceSets.integrationTest.output.classesDirs + + useJUnitPlatform() + + classpath = sourceSets.integrationTest.runtimeClasspath + + filter { + includeTestsMatching '*IT' + } +} diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactory.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactory.java new file mode 100644 index 0000000000..2c93fc991b --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.personalize; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient; + +final class ClientFactory { + private ClientFactory() { } + + static PersonalizeEventsClient createPersonalizeEventsClient(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) { + final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(personalizeSinkConfig.getAwsAuthenticationOptions()); + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions); + + return PersonalizeEventsClient.builder() + .region(getRegion(personalizeSinkConfig, awsCredentialsSupplier)) + .credentialsProvider(awsCredentialsProvider) + .overrideConfiguration(createOverrideConfiguration(personalizeSinkConfig)).build(); + } + + private static ClientOverrideConfiguration createOverrideConfiguration(final PersonalizeSinkConfiguration personalizeSinkConfig) { + final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(personalizeSinkConfig.getMaxRetries()).build(); + return ClientOverrideConfiguration.builder() + .retryPolicy(retryPolicy) + .build(); + } + + private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) { + if (awsAuthenticationOptions == null) { + return AwsCredentialsOptions.builder().build(); + } + return AwsCredentialsOptions.builder() + .withRegion(awsAuthenticationOptions.getAwsRegion().orElse(null)) + .withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn()) + .withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId()) + .withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides()) + .build(); + } + + private static Region getRegion(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) { + Region defaultRegion = awsCredentialsSupplier.getDefaultRegion().orElse(null); + if (personalizeSinkConfig.getAwsAuthenticationOptions() == null) { + return defaultRegion; + } else { + return personalizeSinkConfig.getAwsAuthenticationOptions().getAwsRegion().orElse(defaultRegion); + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSink.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSink.java new file mode 100644 index 0000000000..a93e58875c --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSink.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.personalize; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; +import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.AbstractSink; +import org.opensearch.dataprepper.model.sink.Sink; +import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration; +import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; + +/** + * Implementation class of personalize-sink plugin. It is responsible for receiving the collection of + * {@link Event} and uploading to amazon personalize. + */ +@DataPrepperPlugin(name = "aws_personalize", pluginType = Sink.class, pluginConfigurationType = PersonalizeSinkConfiguration.class) +public class PersonalizeSink extends AbstractSink> { + + private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSink.class); + + private final PersonalizeSinkConfiguration personalizeSinkConfig; + private volatile boolean sinkInitialized; + private final PersonalizeSinkService personalizeSinkService; + private final SinkContext sinkContext; + + /** + * @param pluginSetting dp plugin settings. + * @param personalizeSinkConfig personalize sink configurations. + * @param sinkContext sink context + * @param awsCredentialsSupplier aws credentials + * @param pluginFactory dp plugin factory. + */ + @DataPrepperPluginConstructor + public PersonalizeSink(final PluginSetting pluginSetting, + final PersonalizeSinkConfiguration personalizeSinkConfig, + final PluginFactory pluginFactory, + final SinkContext sinkContext, + final AwsCredentialsSupplier awsCredentialsSupplier) { + super(pluginSetting); + this.personalizeSinkConfig = personalizeSinkConfig; + this.sinkContext = sinkContext; + + sinkInitialized = false; + + final PersonalizeEventsClient personalizeEventsClient = ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier); + + personalizeSinkService = new PersonalizeSinkService(personalizeSinkConfig, pluginMetrics); + } + + @Override + public boolean isReady() { + return sinkInitialized; + } + + @Override + public void doInitialize() { + sinkInitialized = true; + } + + /** + * @param records Records to be output + */ + @Override + public void doOutput(final Collection> records) { + personalizeSinkService.output(records); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkService.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkService.java new file mode 100644 index 0000000000..80ea94bcf1 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkService.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.personalize; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Class responsible for creating PersonalizeEventsClient object, check thresholds, + * get new buffer and write records into buffer. + */ +class PersonalizeSinkService { + + private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSinkService.class); + public static final String RECORDS_SUCCEEDED = "personalizeRecordsSucceeded"; + public static final String RECORDS_FAILED = "personalizeRecordsFailed"; + public static final String RECORDS_INVALID = "personalizeRecordsInvalid"; + public static final String REQUESTS_THROTTLED = "personalizeRequestsThrottled"; + public static final String REQUEST_LATENCY = "personalizeRequestLatency"; + + private final PersonalizeSinkConfiguration personalizeSinkConfig; + private final Lock reentrantLock; + private final int maxRetries; + private final Counter recordsSucceededCounter; + private final Counter recordsFailedCounter; + private final Counter recordsInvalidCounter; + private final Counter requestsThrottledCounter; + private final Timer requestLatencyTimer; + + /** + * @param personalizeSinkConfig personalize sink related configuration. + * @param pluginMetrics metrics. + */ + public PersonalizeSinkService(final PersonalizeSinkConfiguration personalizeSinkConfig, + final PluginMetrics pluginMetrics) { + this.personalizeSinkConfig = personalizeSinkConfig; + reentrantLock = new ReentrantLock(); + + maxRetries = personalizeSinkConfig.getMaxRetries(); + + recordsSucceededCounter = pluginMetrics.counter(RECORDS_SUCCEEDED); + recordsFailedCounter = pluginMetrics.counter(RECORDS_FAILED); + recordsInvalidCounter = pluginMetrics.counter(RECORDS_INVALID); + requestsThrottledCounter = pluginMetrics.counter(REQUESTS_THROTTLED); + requestLatencyTimer = pluginMetrics.timer(REQUEST_LATENCY); + } + + /** + * @param records received records and add into buffer. + */ + void output(Collection> records) { + LOG.trace("{} records received", records.size()); + return; + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptions.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptions.java new file mode 100644 index 0000000000..ba7e96d43d --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptions.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.personalize.configuration; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.GroupSequence; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.Size; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.arns.Arn; + +import java.util.Map; +import java.util.Optional; + +@GroupSequence({AwsAuthenticationOptions.class, PersonalizeAdvancedValidation.class}) +public class AwsAuthenticationOptions { + private static final String AWS_IAM_ROLE = "role"; + private static final String AWS_IAM = "iam"; + + @JsonProperty("region") + @Size(min = 1, message = "Region cannot be empty string") + private String awsRegion; + + @JsonProperty("sts_role_arn") + @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters") + private String awsStsRoleArn; + + @JsonProperty("sts_external_id") + @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters") + private String awsStsExternalId; + + @JsonProperty("sts_header_overrides") + @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") + private Map awsStsHeaderOverrides; + + @AssertTrue(message = "sts_role_arn must be an IAM Role", groups = PersonalizeAdvancedValidation.class) + boolean isValidStsRoleArn() { + if (awsStsRoleArn == null) { + return true; + } + final Arn arn = getArn(); + boolean status = true; + if (!AWS_IAM.equals(arn.service())) { + status = false; + } + final Optional resourceType = arn.resource().resourceType(); + if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) { + status = false; + } + return status; + } + + private Arn getArn() { + try { + return Arn.fromString(awsStsRoleArn); + } catch (final Exception e) { + throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn)); + } + } + + public Optional getAwsRegion() { + Region region = awsRegion != null ? Region.of(awsRegion) : null; + return Optional.ofNullable(region); + } + + public String getAwsStsRoleArn() { + return awsStsRoleArn; + } + + public String getAwsStsExternalId() { + return awsStsExternalId; + } + + public Map getAwsStsHeaderOverrides() { + return awsStsHeaderOverrides; + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeAdvancedValidation.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeAdvancedValidation.java new file mode 100644 index 0000000000..f48c1d9466 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeAdvancedValidation.java @@ -0,0 +1,4 @@ +package org.opensearch.dataprepper.plugins.sink.personalize.configuration; + +interface PersonalizeAdvancedValidation { +} diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfiguration.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfiguration.java new file mode 100644 index 0000000000..95c9f1d5c9 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfiguration.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.dataprepper.plugins.sink.personalize.configuration; + +import jakarta.validation.GroupSequence; +import org.opensearch.dataprepper.plugins.sink.personalize.dataset.DatasetTypeOptions; +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.AssertTrue; +import software.amazon.awssdk.arns.Arn; + +import java.util.List; +import java.util.Optional; + +/** + * personalize sink configuration class contains properties, used to read yaml configuration. + */ +@GroupSequence({PersonalizeSinkConfiguration.class, PersonalizeAdvancedValidation.class}) +public class PersonalizeSinkConfiguration { + private static final int DEFAULT_RETRIES = 10; + private static final String AWS_PERSONALIZE = "personalize"; + private static final String AWS_PERSONALIZE_DATASET = "dataset"; + private static final List DATASET_ARN_REQUIRED_LIST = List.of(DatasetTypeOptions.USERS, DatasetTypeOptions.ITEMS); + + @JsonProperty("aws") + @Valid + private AwsAuthenticationOptions awsAuthenticationOptions; + + @JsonProperty("dataset_type") + @NotNull + @Valid + private DatasetTypeOptions datasetType; + + @JsonProperty("dataset_arn") + private String datasetArn; + + @JsonProperty("tracking_id") + private String trackingId; + + @JsonProperty("document_root_key") + private String documentRootKey; + + @JsonProperty("max_retries") + private int maxRetries = DEFAULT_RETRIES; + + @AssertTrue(message = "A dataset arn is required for items and users datasets.", groups = PersonalizeAdvancedValidation.class) + boolean isDatasetArnProvidedWhenNeeded() { + if (DATASET_ARN_REQUIRED_LIST.contains(datasetType)) { + return datasetArn != null; + } + return true; + } + + @AssertTrue(message = "dataset_arn must be a Personalize Dataset arn", groups = PersonalizeAdvancedValidation.class) + boolean isValidDatasetArn() { + if (datasetArn == null) { + return true; + } + final Arn arn = getArn(); + boolean status = true; + if (!AWS_PERSONALIZE.equals(arn.service())) { + status = false; + } + final Optional resourceType = arn.resource().resourceType(); + if (resourceType.isEmpty() || !resourceType.get().equals(AWS_PERSONALIZE_DATASET)) { + status = false; + } + return status; + } + + private Arn getArn() { + try { + return Arn.fromString(datasetArn); + } catch (final Exception e) { + throw new IllegalArgumentException(String.format("Invalid ARN format for datasetArn. Check the format of %s", datasetArn), e); + } + } + + @AssertTrue(message = "A tracking id is required for interactions dataset.", groups = PersonalizeAdvancedValidation.class) + boolean isTrackingIdProvidedWhenNeeded() { + if (DatasetTypeOptions.INTERACTIONS.equals(datasetType)) { + return trackingId != null; + } + return true; + } + + /** + * Aws Authentication configuration Options. + * @return aws authentication options. + */ + public AwsAuthenticationOptions getAwsAuthenticationOptions() { + return awsAuthenticationOptions; + } + + /** + * Dataset type configuration Options. + * @return dataset type option object. + */ + public DatasetTypeOptions getDatasetType() { + return datasetType; + } + + /** + * Dataset arn for Personalize Dataset. + * @return dataset arn string. + */ + public String getDatasetArn() { + return datasetArn; + } + + /** + * Tracking id for Personalize Event Tracker. + * @return tracking id string. + */ + public String getTrackingId() { + return trackingId; + } + + /** + * Tracking id for Personalize Event Tracker. + * @return document root key string. + */ + public String getDocumentRootKey() { + return documentRootKey; + } + + /** + * Personalize client retries configuration Options. + * @return maximum retries value. + */ + public int getMaxRetries() { + return maxRetries; + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptions.java b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptions.java new file mode 100644 index 0000000000..cc6791f0a6 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptions.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.personalize.dataset; + +import com.fasterxml.jackson.annotation.JsonCreator; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Defines all the dataset types enumerations. + */ +public enum DatasetTypeOptions { + USERS("users"), + ITEMS("items"), + INTERACTIONS("interactions"); + + private final String option; + private static final Map OPTIONS_MAP = Arrays.stream(DatasetTypeOptions.values()) + .collect(Collectors.toMap(value -> value.option, value -> value)); + + DatasetTypeOptions(final String option) { + this.option = option.toLowerCase(); + } + + @JsonCreator + static DatasetTypeOptions fromOptionValue(final String option) { + return OPTIONS_MAP.get(option); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactoryTest.java b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactoryTest.java new file mode 100644 index 0000000000..6b1ad7f80a --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/ClientFactoryTest.java @@ -0,0 +1,135 @@ +package org.opensearch.dataprepper.plugins.sink.personalize; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.AwsAuthenticationOptions; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient; +import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClientBuilder; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class ClientFactoryTest { + @Mock + private PersonalizeSinkConfiguration personalizeSinkConfig; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + @BeforeEach + void setUp() { + when(personalizeSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + } + + @Test + void createPersonalizeEventsClient_with_real_PersonalizeEventsClient() { + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(Region.US_EAST_1)); + final PersonalizeEventsClient personalizeEventsClient = ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier); + + assertThat(personalizeEventsClient, notNullValue()); + } + + @Test + void createPersonalizeEventsClient_provides_correct_inputs_for_null_awsAuthenticationOptions() { + when(personalizeSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider); + + final PersonalizeEventsClientBuilder personalizeEventsClientBuilder = mock(PersonalizeEventsClientBuilder.class); + when(personalizeEventsClientBuilder.region(any())).thenReturn(personalizeEventsClientBuilder); + when(personalizeEventsClientBuilder.credentialsProvider(any())).thenReturn(personalizeEventsClientBuilder); + when(personalizeEventsClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(personalizeEventsClientBuilder); + try(final MockedStatic personalizeEventsClientMockedStatic = mockStatic(PersonalizeEventsClient.class)) { + personalizeEventsClientMockedStatic.when(PersonalizeEventsClient::builder) + .thenReturn(personalizeEventsClientBuilder); + ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier); + } + + final ArgumentCaptor credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class); + verify(personalizeEventsClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture()); + + final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue(); + + assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider)); + + final ArgumentCaptor optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture()); + + final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue(); + assertThat(actualCredentialsOptions, is(notNullValue())); + assertThat(actualCredentialsOptions.getRegion(), equalTo(null)); + assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(null)); + assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(null)); + assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); + } + + @ParameterizedTest + @ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"}) + void createPersonalizeEventsClient_provides_correct_inputs(final String regionString) { + final Region region = Region.of(regionString); + final String stsRoleArn = UUID.randomUUID().toString(); + final String externalId = UUID.randomUUID().toString(); + final Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(region)); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); + + final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider); + + final PersonalizeEventsClientBuilder personalizeEventsClientBuilder = mock(PersonalizeEventsClientBuilder.class); + when(personalizeEventsClientBuilder.region(region)).thenReturn(personalizeEventsClientBuilder); + when(personalizeEventsClientBuilder.credentialsProvider(any())).thenReturn(personalizeEventsClientBuilder); + when(personalizeEventsClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(personalizeEventsClientBuilder); + try(final MockedStatic personalizeEventsClientMockedStatic = mockStatic(PersonalizeEventsClient.class)) { + personalizeEventsClientMockedStatic.when(PersonalizeEventsClient::builder) + .thenReturn(personalizeEventsClientBuilder); + ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier); + } + + final ArgumentCaptor credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class); + verify(personalizeEventsClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture()); + + final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue(); + + assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider)); + + final ArgumentCaptor optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture()); + + final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue(); + assertThat(actualCredentialsOptions.getRegion(), equalTo(region)); + assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(externalId)); + assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkTest.java b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkTest.java new file mode 100644 index 0000000000..852e75630f --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/PersonalizeSinkTest.java @@ -0,0 +1,85 @@ +package org.opensearch.dataprepper.plugins.sink.personalize; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration; +import org.opensearch.dataprepper.plugins.sink.personalize.dataset.DatasetTypeOptions; +import software.amazon.awssdk.regions.Region; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class PersonalizeSinkTest { + public static final int MAX_RETRIES = 10; + public static final String REGION = "us-east-1"; + public static final String SINK_PLUGIN_NAME = "personalize"; + public static final String SINK_PIPELINE_NAME = "personalize-sink-pipeline"; + public static final String DATASET_ARN = "arn:aws:iam::123456789012:dataset/test"; + public static final String TRACKING_ID = "1233513241"; + private PersonalizeSinkConfiguration personalizeSinkConfig; + private PersonalizeSink personalizeSink; + private PluginSetting pluginSetting; + private PluginFactory pluginFactory; + private AwsCredentialsSupplier awsCredentialsSupplier; + private SinkContext sinkContext; + + @BeforeEach + void setup() { + personalizeSinkConfig = mock(PersonalizeSinkConfiguration.class); + sinkContext = mock(SinkContext.class); + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + pluginSetting = mock(PluginSetting.class); + pluginFactory = mock(PluginFactory.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + + when(personalizeSinkConfig.getMaxRetries()).thenReturn(MAX_RETRIES); + when(personalizeSinkConfig.getDatasetArn()).thenReturn(DATASET_ARN); + when(personalizeSinkConfig.getDatasetType()).thenReturn(DatasetTypeOptions.USERS); + when(personalizeSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(Region.of(REGION))); + when(pluginSetting.getName()).thenReturn(SINK_PLUGIN_NAME); + when(pluginSetting.getPipelineName()).thenReturn(SINK_PIPELINE_NAME); + } + + private PersonalizeSink createObjectUnderTest() { + return new PersonalizeSink(pluginSetting, personalizeSinkConfig, pluginFactory, sinkContext, awsCredentialsSupplier); + } + + @Test + void test_personalize_sink_plugin_isReady_positive() { + personalizeSink = createObjectUnderTest(); + Assertions.assertNotNull(personalizeSink); + personalizeSink.doInitialize(); + assertTrue(personalizeSink.isReady(), "Expected the personalize sink to be ready, but it is reporting it is not ready."); + } + + @Test + void test_personalize_Sink_plugin_isReady_negative() { + personalizeSink = createObjectUnderTest(); + Assertions.assertNotNull(personalizeSink); + assertFalse(personalizeSink.isReady(), "Expected the personalize sink to report that it is not ready, but it is reporting it is ready."); + } + + @Test + void test_doOutput_with_empty_records() { + personalizeSink = createObjectUnderTest(); + Assertions.assertNotNull(personalizeSink); + personalizeSink.doInitialize(); + Collection> records = new ArrayList<>(); + personalizeSink.doOutput(records); + } +} diff --git a/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptionsTest.java b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptionsTest.java new file mode 100644 index 0000000000..29be309622 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/AwsAuthenticationOptionsTest.java @@ -0,0 +1,129 @@ +package org.opensearch.dataprepper.plugins.sink.personalize.configuration; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import software.amazon.awssdk.regions.Region; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +class AwsAuthenticationOptionsTest { + private ObjectMapper objectMapper; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + } + + @ParameterizedTest + @ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"}) + void getAwsRegion_returns_Region_of(final String regionString) { + final Optional expectedRegionObject = Optional.of(Region.of(regionString)); + final Map jsonMap = Map.of("region", regionString); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsRegion(), equalTo(expectedRegionObject)); + } + + @Test + void getAwsRegion_returns_null_when_region_is_null() { + final Map jsonMap = Collections.emptyMap(); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsRegion(), equalTo(Optional.empty())); + } + + @Test + void getAwsStsRoleArn_returns_value_from_deserialized_JSON() { + final String stsRoleArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo(stsRoleArn)); + } + + @Test + void getAwsStsRoleArn_returns_null_if_not_in_JSON() { + final Map jsonMap = Collections.emptyMap(); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsRoleArn(), nullValue()); + } + + @Test + void isValidStsRoleArn_returns_true_for_valid_IAM_role() { + final String stsRoleArn = "arn:aws:iam::123456789012:role/test"; + final Map jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertTrue(objectUnderTest.isValidStsRoleArn()); + } + + @Test + void isValidStsRoleArn_returns_true_for_null() { + final Map jsonMap = Collections.emptyMap(); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertTrue(objectUnderTest.isValidStsRoleArn()); + } + + @Test + void isValidStsRoleArn_returns_false_when_arn_service_is_not_IAM() { + final String stsRoleArn = "arn:aws:personalize::123456789012:role/test"; + final Map jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertFalse(objectUnderTest.isValidStsRoleArn()); + } + + @Test + void isValidStsRoleArn_returns_false_when_arn_resource_is_not_role() { + final String stsRoleArn = "arn:aws:iam::123456789012:dataset/test"; + final Map jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertFalse(objectUnderTest.isValidStsRoleArn()); + } + + @Test + void isValidStsRoleArn_invalid_arn_throws_IllegalArgumentException() { + final String stsRoleArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThrows(IllegalArgumentException.class, () -> objectUnderTest.isValidStsRoleArn()); + } + + @Test + void getAwsStsExternalId_returns_value_from_deserialized_JSON() { + final String stsExternalId = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("sts_external_id", stsExternalId); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsExternalId(), equalTo(stsExternalId)); + } + + @Test + void getAwsStsExternalId_returns_null_if_not_in_JSON() { + final Map jsonMap = Collections.emptyMap(); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsExternalId(), nullValue()); + } + + @Test + void getAwsStsHeaderOverrides_returns_value_from_deserialized_JSON() { + final Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + final Map jsonMap = Map.of("sts_header_overrides", stsHeaderOverrides); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsHeaderOverrides(), equalTo(stsHeaderOverrides)); + } + + @Test + void getAwsStsHeaderOverrides_returns_null_if_not_in_JSON() { + final Map jsonMap = Collections.emptyMap(); + final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class); + assertThat(objectUnderTest.getAwsStsHeaderOverrides(), nullValue()); + } +} diff --git a/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfigurationTest.java b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfigurationTest.java new file mode 100644 index 0000000000..67bc690623 --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/configuration/PersonalizeSinkConfigurationTest.java @@ -0,0 +1,205 @@ +package org.opensearch.dataprepper.plugins.sink.personalize.configuration; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.plugins.sink.personalize.dataset.DatasetTypeOptions; + +import java.util.Collections; +import java.util.Map; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class PersonalizeSinkConfigurationTest { + private static final int DEFAULT_RETRIES = 10; + private ObjectMapper objectMapper; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + void getDatasetType_returns_value_from_deserialized_JSON() { + final String datasetType = "users"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getDatasetType(), equalTo(DatasetTypeOptions.USERS)); + } + + @Test + void getDatasetArn_returns_null_when_datasetArn_is_null() { + final Map jsonMap = Collections.emptyMap(); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getDatasetArn(), nullValue()); + } + + @Test + void getDatasetArn_returns_value_from_deserialized_JSON() { + final String datasetArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getDatasetArn(), equalTo(datasetArn)); + } + + @Test + void isDatasetArnProvidedWhenNeeded_returns_true_when_datasetType_is_interactions_and_datasetArn_is_null() { + final String datasetType = "interactions"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isDatasetArnProvidedWhenNeeded()); + } + + @Test + void isDatasetArnProvidedWhenNeeded_returns_true_when_datasetType_is_users_and_datasetArn_is_provided() { + final String datasetType = "users"; + final String datasetArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("dataset_type", datasetType, "dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isDatasetArnProvidedWhenNeeded()); + } + + @Test + void isDatasetArnProvidedWhenNeeded_returns_false_when_datasetType_is_users_and_datasetArn_is_not_provided() { + final String datasetType = "users"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertFalse(objectUnderTest.isDatasetArnProvidedWhenNeeded()); + } + + @Test + void isDatasetArnProvidedWhenNeeded_returns_true_when_datasetType_is_items_and_datasetArn_is_provided() { + final String datasetType = "items"; + final String datasetArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("dataset_type", datasetType, "dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isDatasetArnProvidedWhenNeeded()); + } + + @Test + void isDatasetArnProvidedWhenNeeded_returns_false_when_datasetType_is_items_and_datasetArn_is_not_provided() { + final String datasetType = "items"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertFalse(objectUnderTest.isDatasetArnProvidedWhenNeeded()); + } + + @Test + void isValidDatasetArn_returns_true_for_valid_dataset_arn() { + final String datasetArn = "arn:aws:personalize::123456789012:dataset/test"; + final Map jsonMap = Map.of("dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isValidDatasetArn()); + } + + @Test + void isValidDatasetArn_returns_false_when_arn_service_is_not_personalize() { + final String datasetArn = "arn:aws:iam::123456789012:dataset/test"; + final Map jsonMap = Map.of("dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertFalse(objectUnderTest.isValidDatasetArn()); + } + + @Test + void isValidDatasetArn_returns_false_when_arn_resource_is_not_dataset() { + final String datasetArn = "arn:aws:personalize::123456789012:role/test"; + final Map jsonMap = Map.of("dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertFalse(objectUnderTest.isValidDatasetArn()); + } + + @Test + void isValidStsRoleArn_invalid_arn_throws_IllegalArgumentException() { + final String datasetArn = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("dataset_arn", datasetArn); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThrows(IllegalArgumentException.class, () -> objectUnderTest.isValidDatasetArn()); + } + + + + @Test + void getTrackingId_returns_null_when_trackingId_is_null() { + final Map jsonMap = Collections.emptyMap(); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getTrackingId(), nullValue()); + } + + @Test + void getTrackingId_returns_value_from_deserialized_JSON() { + final String trackingId = UUID.randomUUID().toString();; + final Map jsonMap = Map.of("tracking_id", trackingId); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getTrackingId(), equalTo(trackingId)); + } + + @Test + void isTrackingIdProvidedWhenNeeded_returns_false_when_datasetType_is_interactions_and_trackingId_is_not_provided() { + final String datasetType = "interactions"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertFalse(objectUnderTest.isTrackingIdProvidedWhenNeeded()); + } + + @Test + void isTrackingIdProvidedWhenNeeded_returns_true_when_datasetType_is_interactions_and_trackingId_is_provided() { + final String datasetType = "interactions"; + final String trackingId = UUID.randomUUID().toString(); + final Map jsonMap = Map.of("dataset_type", datasetType, "tracking_id", trackingId); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isTrackingIdProvidedWhenNeeded()); + } + + @Test + void isTrackingIdProvidedWhenNeeded_returns_true_when_datasetType_is_users_and_trackingId_is_not_provided() { + final String datasetType = "users"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isTrackingIdProvidedWhenNeeded()); + } + + @Test + void isTrackingIdProvidedWhenNeeded_returns_true_when_datasetType_is_items_and_trackingId_is_not_provided() { + final String datasetType = "items"; + final Map jsonMap = Map.of("dataset_type", datasetType); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertTrue(objectUnderTest.isTrackingIdProvidedWhenNeeded()); + } + + + @Test + void getDocumentRootKey_returns_null_when_documentRootKey_is_null() { + final Map jsonMap = Collections.emptyMap(); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getDocumentRootKey(), nullValue()); + } + + @Test + void getDocumentRootKey_returns_value_from_deserialized_JSON() { + final String documentRootKey = UUID.randomUUID().toString();; + final Map jsonMap = Map.of("document_root_key", documentRootKey); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getDocumentRootKey(), equalTo(documentRootKey)); + } + + @Test + void getMaxRetries_returns_default_when_maxRetries_is_null() { + final Map jsonMap = Collections.emptyMap(); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getMaxRetries(), equalTo(DEFAULT_RETRIES)); + } + + @Test + void getMaxRetries_returns_value_from_deserialized_JSON() { + final int maxRetries = 3; + final Map jsonMap = Map.of("max_retries", maxRetries); + final PersonalizeSinkConfiguration objectUnderTest = objectMapper.convertValue(jsonMap, PersonalizeSinkConfiguration.class); + assertThat(objectUnderTest.getMaxRetries(), equalTo(maxRetries)); + } +} diff --git a/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptionsTest.java b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptionsTest.java new file mode 100644 index 0000000000..40b1821d0a --- /dev/null +++ b/data-prepper-plugins/personalize-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/personalize/dataset/DatasetTypeOptionsTest.java @@ -0,0 +1,38 @@ +package org.opensearch.dataprepper.plugins.sink.personalize.dataset; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +@ExtendWith(MockitoExtension.class) +class DatasetTypeOptionsTest { + @Test + void notNull_test() { + assertNotNull(DatasetTypeOptions.ITEMS); + } + + @Test + void fromOptionValue_users_test() { + DatasetTypeOptions datasetTypeOptions = DatasetTypeOptions.fromOptionValue("users"); + assertNotNull(datasetTypeOptions); + assertThat(datasetTypeOptions.toString(), equalTo("USERS")); + } + + @Test + void fromOptionValue_items_test() { + DatasetTypeOptions datasetTypeOptions = DatasetTypeOptions.fromOptionValue("items"); + assertNotNull(datasetTypeOptions); + assertThat(datasetTypeOptions.toString(), equalTo("ITEMS")); + } + + @Test + void fromOptionValue_interactions_test() { + DatasetTypeOptions datasetTypeOptions = DatasetTypeOptions.fromOptionValue("interactions"); + assertNotNull(datasetTypeOptions); + assertThat(datasetTypeOptions.toString(), equalTo("INTERACTIONS")); + } +} diff --git a/settings.gradle b/settings.gradle index 18ccd4dc7b..63e7ad2a9f 100644 --- a/settings.gradle +++ b/settings.gradle @@ -170,6 +170,7 @@ include 'data-prepper-plugins:buffer-common' //include 'data-prepper-plugins:http-sink' //include 'data-prepper-plugins:sns-sink' //include 'data-prepper-plugins:prometheus-sink' +include 'data-prepper-plugins:personalize-sink' include 'data-prepper-plugins:dissect-processor' include 'data-prepper-plugins:dynamodb-source' include 'data-prepper-plugins:decompress-processor'