Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PersonalizeSink: add client and configuration classes #4803

Merged
merged 5 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions data-prepper-plugins/personalize-sink/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

dependencies {
implementation project(':data-prepper-api')
implementation project(path: ':data-prepper-plugins:common')
implementation project(':data-prepper-plugins:aws-plugin-api')
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
implementation 'software.amazon.awssdk:personalizeevents'
implementation 'software.amazon.awssdk:sts'
implementation 'software.amazon.awssdk:arns'
testImplementation project(':data-prepper-test-common')
testImplementation testLibs.slf4j.simple
}

sourceSets {
integrationTest {
java {
compileClasspath += main.output + test.output
runtimeClasspath += main.output + test.output
srcDir file('src/integrationTest/java')
}
resources.srcDir file('src/integrationTest/resources')
}
}

configurations {
integrationTestImplementation.extendsFrom testImplementation
integrationTestRuntime.extendsFrom testRuntime
}

task integrationTest(type: Test) {
group = 'verification'
testClassesDirs = sourceSets.integrationTest.output.classesDirs

useJUnitPlatform()

classpath = sourceSets.integrationTest.runtimeClasspath

filter {
includeTestsMatching '*IT'
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;

final class ClientFactory {
private ClientFactory() { }

static PersonalizeEventsClient createPersonalizeEventsClient(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(personalizeSinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return PersonalizeEventsClient.builder()
.region(getRegion(personalizeSinkConfig, awsCredentialsSupplier))
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(personalizeSinkConfig)).build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final PersonalizeSinkConfiguration personalizeSinkConfig) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(personalizeSinkConfig.getMaxRetries()).build();
return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
}

private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
if (awsAuthenticationOptions == null) {
return AwsCredentialsOptions.builder().build();
}
return AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationOptions.getAwsRegion().orElse(null))
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
.withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId())
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
.build();
}

private static Region getRegion(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
Region defaultRegion = awsCredentialsSupplier.getDefaultRegion().orElse(null);
if (personalizeSinkConfig.getAwsAuthenticationOptions() == null) {
return defaultRegion;
} else {
return personalizeSinkConfig.getAwsAuthenticationOptions().getAwsRegion().orElse(defaultRegion);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.sink.AbstractSink;
import org.opensearch.dataprepper.model.sink.Sink;
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;

/**
* Implementation class of personalize-sink plugin. It is responsible for receiving the collection of
* {@link Event} and uploading to amazon personalize.
*/
@DataPrepperPlugin(name = "aws_personalize", pluginType = Sink.class, pluginConfigurationType = PersonalizeSinkConfiguration.class)
public class PersonalizeSink extends AbstractSink<Record<Event>> {

private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSink.class);

private final PersonalizeSinkConfiguration personalizeSinkConfig;
private volatile boolean sinkInitialized;
private final PersonalizeSinkService personalizeSinkService;
private final SinkContext sinkContext;

/**
* @param pluginSetting dp plugin settings.
* @param personalizeSinkConfig personalize sink configurations.
* @param sinkContext sink context
* @param awsCredentialsSupplier aws credentials
* @param pluginFactory dp plugin factory.
*/
@DataPrepperPluginConstructor
public PersonalizeSink(final PluginSetting pluginSetting,
final PersonalizeSinkConfiguration personalizeSinkConfig,
final PluginFactory pluginFactory,
final SinkContext sinkContext,
final AwsCredentialsSupplier awsCredentialsSupplier) {
super(pluginSetting);
this.personalizeSinkConfig = personalizeSinkConfig;
this.sinkContext = sinkContext;

sinkInitialized = false;

final PersonalizeEventsClient personalizeEventsClient = ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier);

personalizeSinkService = new PersonalizeSinkService(personalizeSinkConfig, pluginMetrics);
}

@Override
public boolean isReady() {
return sinkInitialized;
}

@Override
public void doInitialize() {
sinkInitialized = true;
}

/**
* @param records Records to be output
*/
@Override
public void doOutput(final Collection<Record<Event>> records) {
personalizeSinkService.output(records);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.opensearch.dataprepper.plugins.sink.personalize.configuration.PersonalizeSinkConfiguration;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
* Class responsible for creating PersonalizeEventsClient object, check thresholds,
* get new buffer and write records into buffer.
*/
class PersonalizeSinkService {

private static final Logger LOG = LoggerFactory.getLogger(PersonalizeSinkService.class);
public static final String RECORDS_SUCCEEDED = "personalizeRecordsSucceeded";
public static final String RECORDS_FAILED = "personalizeRecordsFailed";
public static final String RECORDS_INVALID = "personalizeRecordsInvalid";
public static final String REQUESTS_THROTTLED = "personalizeRequestsThrottled";
public static final String REQUEST_LATENCY = "personalizeRequestLatency";

private final PersonalizeSinkConfiguration personalizeSinkConfig;
private final Lock reentrantLock;
private final int maxRetries;
private final Counter recordsSucceededCounter;
private final Counter recordsFailedCounter;
private final Counter recordsInvalidCounter;
private final Counter requestsThrottledCounter;
private final Timer requestLatencyTimer;

/**
* @param personalizeSinkConfig personalize sink related configuration.
* @param pluginMetrics metrics.
*/
public PersonalizeSinkService(final PersonalizeSinkConfiguration personalizeSinkConfig,
final PluginMetrics pluginMetrics) {
this.personalizeSinkConfig = personalizeSinkConfig;
reentrantLock = new ReentrantLock();

maxRetries = personalizeSinkConfig.getMaxRetries();

recordsSucceededCounter = pluginMetrics.counter(RECORDS_SUCCEEDED);
recordsFailedCounter = pluginMetrics.counter(RECORDS_FAILED);
recordsInvalidCounter = pluginMetrics.counter(RECORDS_INVALID);
requestsThrottledCounter = pluginMetrics.counter(REQUESTS_THROTTLED);
requestLatencyTimer = pluginMetrics.timer(REQUEST_LATENCY);
}

/**
* @param records received records and add into buffer.
*/
void output(Collection<Record<Event>> records) {
LOG.trace("{} records received", records.size());
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.personalize.configuration;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.GroupSequence;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.arns.Arn;

import java.util.Map;
import java.util.Optional;

@GroupSequence({AwsAuthenticationOptions.class, PersonalizeAdvancedValidation.class})
public class AwsAuthenticationOptions {
private static final String AWS_IAM_ROLE = "role";
private static final String AWS_IAM = "iam";

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;

@JsonProperty("sts_role_arn")
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
private String awsStsRoleArn;

@JsonProperty("sts_external_id")
@Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
private String awsStsExternalId;

@JsonProperty("sts_header_overrides")
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

@AssertTrue(message = "sts_role_arn must be an IAM Role", groups = PersonalizeAdvancedValidation.class)
boolean isValidStsRoleArn() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't be necessary. I don't think we need to add it to configuration validation

if (awsStsRoleArn == null) {
return true;
}
final Arn arn = getArn();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you will want to return false here if this method throws an IllegalArgumentException. Have you tested with invalid arn format and observed the exception?

Also this role should be optional in this configuration, so adding a null check to return true in the method should be enough. The reason it is optional here is because users can configure a default role in the data-prepper-config.yaml (#4559)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I have tested with invalid arn format.

I didn't know about the default in data-prepper-config.yaml. I'll make this optional

boolean status = true;
if (!AWS_IAM.equals(arn.service())) {
status = false;
}
final Optional<String> resourceType = arn.resource().resourceType();
if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
status = false;
}
return status;
}

private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
}

public Optional<Region> getAwsRegion() {
Region region = awsRegion != null ? Region.of(awsRegion) : null;
return Optional.ofNullable(region);
}

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

public String getAwsStsExternalId() {
return awsStsExternalId;
}

public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.opensearch.dataprepper.plugins.sink.personalize.configuration;

interface PersonalizeAdvancedValidation {
}
Loading
Loading