Skip to content

Commit

Permalink
PersonalizeSink: add client and configuration classes (#4803)
Browse files Browse the repository at this point in the history
PersonalizeSink: add client and configuration classes

Signed-off-by: Ivan Tse <[email protected]>
  • Loading branch information
ivan-tse authored Aug 14, 2024
1 parent 1487973 commit 38fe2af
Show file tree
Hide file tree
Showing 14 changed files with 1,101 additions and 0 deletions.
48 changes: 48 additions & 0 deletions data-prepper-plugins/personalize-sink/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

// Libraries needed to compile, run and test the personalize-sink plugin.
dependencies {
// Other modules of this multi-project build that the plugin compiles against.
implementation project(':data-prepper-api')
implementation project(path: ':data-prepper-plugins:common')
implementation project(':data-prepper-plugins:aws-plugin-api')
// External libraries declared without a version in this file; their versions
// must be supplied elsewhere in the build (not visible here).
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
// The only dependency in this block that pins an explicit version.
implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
// AWS SDK modules: the Personalize Events client, STS, and ARN parsing.
implementation 'software.amazon.awssdk:personalizeevents'
implementation 'software.amazon.awssdk:sts'
implementation 'software.amazon.awssdk:arns'
// Test-only dependencies.
testImplementation project(':data-prepper-test-common')
testImplementation testLibs.slf4j.simple
}

// Declares a separate "integrationTest" source set rooted at src/integrationTest.
sourceSets {
integrationTest {
java {
// Integration tests compile and run against both the main and the unit-test outputs.
compileClasspath += main.output + test.output
runtimeClasspath += main.output + test.output
srcDir file('src/integrationTest/java')
}
// Resources for the integration tests live next to their Java sources.
resources.srcDir file('src/integrationTest/resources')
}
}

// Let the integrationTest source set inherit the dependencies declared for unit tests.
configurations {
    integrationTestImplementation.extendsFrom testImplementation
    // The Java plugin names the runtime bucket of a source set "<sourceSet>RuntimeOnly".
    // The legacy "testRuntime" configuration was removed in Gradle 7, so the previous
    // "integrationTestRuntime.extendsFrom testRuntime" did not wire up the configurations
    // that the integrationTest source set actually resolves.
    integrationTestRuntimeOnly.extendsFrom testRuntimeOnly
}

// Test task that runs the classes compiled from the integrationTest source set.
task integrationTest(type: Test) {
// Shown under the "verification" task group.
group = 'verification'
testClassesDirs = sourceSets.integrationTest.output.classesDirs

// Run the tests on the JUnit Platform.
useJUnitPlatform()

classpath = sourceSets.integrationTest.runtimeClasspath

// Restrict execution to tests matching the '*IT' pattern.
filter {
includeTestsMatching '*IT'
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;

/**
 * Non-instantiable factory that assembles the {@link PersonalizeEventsClient} for the sink
 * from the sink configuration and an {@link AwsCredentialsSupplier}.
 */
final class ClientFactory {
    private ClientFactory() { }

    /**
     * Builds a {@link PersonalizeEventsClient} configured with region, credentials and retry policy.
     *
     * @param personalizeSinkConfig  sink configuration supplying the AWS authentication options and max retries
     * @param awsCredentialsSupplier source of the credentials provider and of the default region
     * @return the built client
     */
    static PersonalizeEventsClient createPersonalizeEventsClient(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
        final AwsCredentialsProvider credentialsProvider = awsCredentialsSupplier.getProvider(
                convertToCredentialsOptions(personalizeSinkConfig.getAwsAuthenticationOptions()));
        final Region region = getRegion(personalizeSinkConfig, awsCredentialsSupplier);
        final ClientOverrideConfiguration overrideConfiguration = createOverrideConfiguration(personalizeSinkConfig);

        return PersonalizeEventsClient.builder()
                .region(region)
                .credentialsProvider(credentialsProvider)
                .overrideConfiguration(overrideConfiguration)
                .build();
    }

    /**
     * Creates the client override configuration whose retry policy uses the configured number of retries.
     */
    private static ClientOverrideConfiguration createOverrideConfiguration(final PersonalizeSinkConfiguration personalizeSinkConfig) {
        return ClientOverrideConfiguration.builder()
                .retryPolicy(RetryPolicy.builder()
                        .numRetries(personalizeSinkConfig.getMaxRetries())
                        .build())
                .build();
    }

    /**
     * Translates the plugin's authentication options into {@link AwsCredentialsOptions}.
     * A {@code null} argument yields options built with no values set.
     */
    private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
        if (awsAuthenticationOptions != null) {
            return AwsCredentialsOptions.builder()
                    .withRegion(awsAuthenticationOptions.getAwsRegion().orElse(null))
                    .withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
                    .withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId())
                    .withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
                    .build();
        }
        return AwsCredentialsOptions.builder().build();
    }

    /**
     * Picks the region for the client: the region from the authentication options when present,
     * otherwise the supplier's default region (which may be {@code null} when none is available).
     */
    private static Region getRegion(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
        final Region fallbackRegion = awsCredentialsSupplier.getDefaultRegion().orElse(null);
        final AwsAuthenticationOptions authOptions = personalizeSinkConfig.getAwsAuthenticationOptions();
        return authOptions == null
                ? fallbackRegion
                : authOptions.getAwsRegion().orElse(fallbackRegion);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.sink.AbstractSink;
import org.opensearch.dataprepper.model.sink.Sink;
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;

/**
 * Sink plugin registered under the name {@code aws_personalize}. It receives collections of
 * {@link Event} records and hands them to a {@link PersonalizeSinkService}.
 */
@DataPrepperPlugin(name = "aws_personalize", pluginType = Sink.class, pluginConfigurationType = PersonalizeSinkConfiguration.class)
public class PersonalizeSink extends AbstractSink<Record<Event>> {

    private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSink.class);

    private final PersonalizeSinkConfiguration personalizeSinkConfig;
    private final SinkContext sinkContext;
    private final PersonalizeSinkService personalizeSinkService;
    // Flipped to true by doInitialize() and reported by isReady().
    private volatile boolean sinkInitialized;

    /**
     * Creates the sink, builds the Personalize Events client and the service that handles output.
     *
     * @param pluginSetting          dp plugin settings.
     * @param personalizeSinkConfig  personalize sink configurations.
     * @param pluginFactory          dp plugin factory.
     * @param sinkContext            sink context
     * @param awsCredentialsSupplier aws credentials
     */
    @DataPrepperPluginConstructor
    public PersonalizeSink(final PluginSetting pluginSetting,
                           final PersonalizeSinkConfiguration personalizeSinkConfig,
                           final PluginFactory pluginFactory,
                           final SinkContext sinkContext,
                           final AwsCredentialsSupplier awsCredentialsSupplier) {
        super(pluginSetting);
        this.personalizeSinkConfig = personalizeSinkConfig;
        this.sinkContext = sinkContext;
        this.sinkInitialized = false;

        // The client is constructed here but is not handed to the service in this class.
        final PersonalizeEventsClient eventsClient =
                ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier);

        this.personalizeSinkService = new PersonalizeSinkService(personalizeSinkConfig, pluginMetrics);
    }

    /**
     * @return {@code true} once {@link #doInitialize()} has run
     */
    @Override
    public boolean isReady() {
        return sinkInitialized;
    }

    /**
     * Marks the sink as initialized.
     */
    @Override
    public void doInitialize() {
        sinkInitialized = true;
    }

    /**
     * Forwards the records to the sink service.
     *
     * @param records Records to be output
     */
    @Override
    public void doOutput(final Collection<Record<Event>> records) {
        personalizeSinkService.output(records);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Service behind the personalize sink. It holds the sink configuration, a lock, the configured
 * retry count and the plugin metrics; its {@code output} method currently only logs how many
 * records it was given.
 */
class PersonalizeSinkService {

    private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSinkService.class);

    // Metric names registered with PluginMetrics.
    public static final String RECORDS_SUCCEEDED = "personalizeRecordsSucceeded";
    public static final String RECORDS_FAILED = "personalizeRecordsFailed";
    public static final String RECORDS_INVALID = "personalizeRecordsInvalid";
    public static final String REQUESTS_THROTTLED = "personalizeRequestsThrottled";
    public static final String REQUEST_LATENCY = "personalizeRequestLatency";

    private final PersonalizeSinkConfiguration personalizeSinkConfig;
    private final Lock reentrantLock;
    private final int maxRetries;
    private final Counter recordsSucceededCounter;
    private final Counter recordsFailedCounter;
    private final Counter recordsInvalidCounter;
    private final Counter requestsThrottledCounter;
    private final Timer requestLatencyTimer;

    /**
     * Stores the configuration and registers the counters and timer used by this service.
     *
     * @param personalizeSinkConfig personalize sink related configuration.
     * @param pluginMetrics         metrics.
     */
    public PersonalizeSinkService(final PersonalizeSinkConfiguration personalizeSinkConfig,
                                  final PluginMetrics pluginMetrics) {
        this.personalizeSinkConfig = personalizeSinkConfig;
        this.reentrantLock = new ReentrantLock();

        this.maxRetries = personalizeSinkConfig.getMaxRetries();

        this.recordsSucceededCounter = pluginMetrics.counter(RECORDS_SUCCEEDED);
        this.recordsFailedCounter = pluginMetrics.counter(RECORDS_FAILED);
        this.recordsInvalidCounter = pluginMetrics.counter(RECORDS_INVALID);
        this.requestsThrottledCounter = pluginMetrics.counter(REQUESTS_THROTTLED);
        this.requestLatencyTimer = pluginMetrics.timer(REQUEST_LATENCY);
    }

    /**
     * Logs, at trace level, the number of records received.
     *
     * @param records the received records.
     */
    void output(Collection<Record<Event>> records) {
        LOG.trace("{} records received", records.size());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize.configuration;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.GroupSequence;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.arns.Arn;

import java.util.Map;
import java.util.Optional;

/**
 * AWS authentication settings of the personalize sink, bound from the configuration keys
 * {@code region}, {@code sts_role_arn}, {@code sts_external_id} and {@code sts_header_overrides}.
 *
 * <p>The group sequence runs the constraints declared directly on the fields first and the
 * {@link PersonalizeAdvancedValidation} group (the IAM-role ARN check) afterwards.
 */
@GroupSequence({AwsAuthenticationOptions.class, PersonalizeAdvancedValidation.class})
public class AwsAuthenticationOptions {
    private static final String AWS_IAM_ROLE = "role";
    private static final String AWS_IAM = "iam";

    @JsonProperty("region")
    @Size(min = 1, message = "Region cannot be empty string")
    private String awsRegion;

    // The message now states the same lower bound the constraint enforces (min = 20).
    @JsonProperty("sts_role_arn")
    @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 20 and 2048 characters")
    private String awsStsRoleArn;

    @JsonProperty("sts_external_id")
    @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
    private String awsStsExternalId;

    @JsonProperty("sts_header_overrides")
    @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
    private Map<String, String> awsStsHeaderOverrides;

    /**
     * Checks that the configured STS role ARN, when present, refers to an IAM role.
     *
     * @return {@code true} when no role ARN is configured, or when the ARN's service is
     *         {@code iam} and its resource type is {@code role}; {@code false} otherwise
     * @throws IllegalArgumentException if the configured value cannot be parsed as an ARN
     */
    @AssertTrue(message = "sts_role_arn must be an IAM Role", groups = PersonalizeAdvancedValidation.class)
    boolean isValidStsRoleArn() {
        if (awsStsRoleArn == null) {
            return true;
        }
        final Arn arn = getArn();
        final Optional<String> resourceType = arn.resource().resourceType();
        return AWS_IAM.equals(arn.service())
                && resourceType.isPresent()
                && AWS_IAM_ROLE.equals(resourceType.get());
    }

    /**
     * Parses {@code awsStsRoleArn} into an {@link Arn}.
     *
     * @throws IllegalArgumentException if parsing fails; the original failure is kept as the cause
     */
    private Arn getArn() {
        try {
            return Arn.fromString(awsStsRoleArn);
        } catch (final Exception e) {
            // Keep the parser's exception as the cause so the reason for the failure is not lost.
            throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn), e);
        }
    }

    /**
     * @return the configured region, or an empty {@link Optional} when none is configured
     */
    public Optional<Region> getAwsRegion() {
        Region region = awsRegion != null ? Region.of(awsRegion) : null;
        return Optional.ofNullable(region);
    }

    /**
     * @return the configured STS role ARN, or {@code null} when not set
     */
    public String getAwsStsRoleArn() {
        return awsStsRoleArn;
    }

    /**
     * @return the configured STS external id, or {@code null} when not set
     */
    public String getAwsStsExternalId() {
        return awsStsExternalId;
    }

    /**
     * @return the configured STS header overrides, or {@code null} when not set
     */
    public Map<String, String> getAwsStsHeaderOverrides() {
        return awsStsHeaderOverrides;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.opensearch.dataprepper.plugins.sink.personalize.configuration;

/**
 * Member-less marker interface used as a validation group: {@code AwsAuthenticationOptions}
 * lists it in its {@code @GroupSequence} and assigns its IAM-role ARN check to this group.
 */
interface PersonalizeAdvancedValidation { }
Loading

0 comments on commit 38fe2af

Please sign in to comment.