Skip to content

Commit

Permalink
Change AwsAuthenticationOptions to be optional
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Tse <[email protected]>
  • Loading branch information
ivan-tse committed Aug 9, 2024
1 parent 3656651 commit 09892e0
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;

final class ClientFactory {
Expand All @@ -22,7 +23,7 @@ static PersonalizeEventsClient createPersonalizeEventsClient(final PersonalizeSi
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return PersonalizeEventsClient.builder()
.region(personalizeSinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.region(getRegion(personalizeSinkConfig, awsCredentialsSupplier))
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(personalizeSinkConfig)).build();
}
Expand All @@ -35,11 +36,23 @@ private static ClientOverrideConfiguration createOverrideConfiguration(final Per
}

private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
if (awsAuthenticationOptions == null) {
return AwsCredentialsOptions.builder().build();
}
return AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationOptions.getAwsRegion())
.withRegion(awsAuthenticationOptions.getAwsRegion().orElse(null))
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
.withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId())
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
.build();
}

private static Region getRegion(final PersonalizeSinkConfiguration personalizeSinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
Region defaultRegion = awsCredentialsSupplier.getDefaultRegion().orElse(null);
if (personalizeSinkConfig.getAwsAuthenticationOptions() == null) {
return defaultRegion;
} else {
return personalizeSinkConfig.getAwsAuthenticationOptions().getAwsRegion().orElse(defaultRegion);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public class AwsAuthenticationOptions {

@AssertTrue(message = "sts_role_arn must be an IAM Role", groups = PersonalizeAdvancedValidation.class)
boolean isValidStsRoleArn() {
if (awsStsRoleArn == null) {
return true;
}
final Arn arn = getArn();
boolean status = true;
if (!AWS_IAM.equals(arn.service())) {
Expand All @@ -58,8 +61,9 @@ private Arn getArn() {
}
}

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
public Optional<Region> getAwsRegion() {
Region region = awsRegion != null ? Region.of(awsRegion) : null;
return Optional.ofNullable(region);
}

public String getAwsStsRoleArn() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public class PersonalizeSinkConfiguration {
private static final List<DatasetTypeOptions> DATASET_ARN_REQUIRED_LIST = List.of(DatasetTypeOptions.USERS, DatasetTypeOptions.ITEMS);

@JsonProperty("aws")
@NotNull
@Valid
private AwsAuthenticationOptions awsAuthenticationOptions;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClient;
import software.amazon.awssdk.services.personalizeevents.PersonalizeEventsClientBuilder;

import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
Expand All @@ -48,20 +51,54 @@ void setUp() {

@Test
void createPersonalizeEventsClient_with_real_PersonalizeEventsClient() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(Region.US_EAST_1));
final PersonalizeEventsClient personalizeEventsClient = ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier);

assertThat(personalizeEventsClient, notNullValue());
}

@Test
void createPersonalizeEventsClient_provides_correct_inputs_for_null_awsAuthenticationOptions() {
when(personalizeSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);

final PersonalizeEventsClientBuilder personalizeEventsClientBuilder = mock(PersonalizeEventsClientBuilder.class);
when(personalizeEventsClientBuilder.region(any())).thenReturn(personalizeEventsClientBuilder);
when(personalizeEventsClientBuilder.credentialsProvider(any())).thenReturn(personalizeEventsClientBuilder);
when(personalizeEventsClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(personalizeEventsClientBuilder);
try(final MockedStatic<PersonalizeEventsClient> personalizeEventsClientMockedStatic = mockStatic(PersonalizeEventsClient.class)) {
personalizeEventsClientMockedStatic.when(PersonalizeEventsClient::builder)
.thenReturn(personalizeEventsClientBuilder);
ClientFactory.createPersonalizeEventsClient(personalizeSinkConfig, awsCredentialsSupplier);
}

final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
verify(personalizeEventsClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());

final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();

assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));

final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
assertThat(actualCredentialsOptions, is(notNullValue()));
assertThat(actualCredentialsOptions.getRegion(), equalTo(null));
assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(null));
assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(null));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap()));
}

@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
void createPersonalizeEventsClient_provides_correct_inputs(final String regionString) {
final Region region = Region.of(regionString);
final String stsRoleArn = UUID.randomUUID().toString();
final String externalId = UUID.randomUUID().toString();
final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(region));
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -49,7 +50,7 @@ void setup() {
when(personalizeSinkConfig.getDatasetArn()).thenReturn(DATASET_ARN);
when(personalizeSinkConfig.getDatasetType()).thenReturn(DatasetTypeOptions.USERS);
when(personalizeSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(REGION));
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Optional.of(Region.of(REGION)));
when(pluginSetting.getName()).thenReturn(SINK_PLUGIN_NAME);
when(pluginSetting.getPipelineName()).thenReturn(SINK_PIPELINE_NAME);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
Expand All @@ -29,7 +30,7 @@ void setUp() {
@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
void getAwsRegion_returns_Region_of(final String regionString) {
final Region expectedRegionObject = Region.of(regionString);
final Optional<Region> expectedRegionObject = Optional.of(Region.of(regionString));
final Map<String, Object> jsonMap = Map.of("region", regionString);
final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class);
assertThat(objectUnderTest.getAwsRegion(), equalTo(expectedRegionObject));
Expand All @@ -39,7 +40,7 @@ void getAwsRegion_returns_Region_of(final String regionString) {
void getAwsRegion_returns_null_when_region_is_null() {
final Map<String, Object> jsonMap = Collections.emptyMap();
final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class);
assertThat(objectUnderTest.getAwsRegion(), nullValue());
assertThat(objectUnderTest.getAwsRegion(), equalTo(Optional.empty()));
}

@Test
Expand All @@ -65,6 +66,13 @@ void isValidStsRoleArn_returns_true_for_valid_IAM_role() {
assertTrue(objectUnderTest.isValidStsRoleArn());
}

@Test
void isValidStsRoleArn_returns_true_for_null() {
final Map<String, Object> jsonMap = Collections.emptyMap();
final AwsAuthenticationOptions objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationOptions.class);
assertTrue(objectUnderTest.isValidStsRoleArn());
}

@Test
void isValidStsRoleArn_returns_false_when_arn_service_is_not_IAM() {
final String stsRoleArn = "arn:aws:personalize::123456789012:role/test";
Expand Down

0 comments on commit 09892e0

Please sign in to comment.