Skip to content

Commit

Permalink
Use async client to delete scroll and pit for OpenSearch as workaroun…
Browse files Browse the repository at this point in the history
…d for bug in client

Signed-off-by: Taylor Gray <[email protected]>
  • Loading branch information
graytaylor0 committed Sep 14, 2023
1 parent 90575b1 commit af2e4b1
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 23 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/opensearch-source/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies {
implementation project(':data-prepper-plugins:buffer-common')
implementation project(':data-prepper-plugins:aws-plugin-api')
implementation 'software.amazon.awssdk:apache-client'
implementation 'software.amazon.awssdk:netty-nio-client'
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.15.2'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ public class OpenSearchAccessor implements SearchAccessor, ClusterClientFactory
static final String SCROLL_RESOURCE_LIMIT_EXCEPTION_MESSAGE = "Trying to create too many scroll contexts";

private final OpenSearchClient openSearchClient;
private final OpenSearchClient openSearchAsyncClient;
private final SearchContextType searchContextType;

public OpenSearchAccessor(final OpenSearchClient openSearchClient, final SearchContextType searchContextType) {
public OpenSearchAccessor(final OpenSearchClient openSearchClient,
final OpenSearchClient asyncOpenSearchClient,
final SearchContextType searchContextType) {
this.openSearchClient = openSearchClient;
this.searchContextType = searchContextType;
this.openSearchAsyncClient = asyncOpenSearchClient;
}

@Override
Expand Down Expand Up @@ -126,7 +130,7 @@ public SearchWithSearchAfterResults searchWithPit(final SearchPointInTimeRequest
@Override
public void deletePit(final DeletePointInTimeRequest deletePointInTimeRequest) {
try {
final DeletePitResponse deletePitResponse = openSearchClient.deletePit(DeletePitRequest.of(builder -> builder.pitId(Collections.singletonList(deletePointInTimeRequest.getPitId()))));
final DeletePitResponse deletePitResponse = openSearchAsyncClient.deletePit(DeletePitRequest.of(builder -> builder.pitId(Collections.singletonList(deletePointInTimeRequest.getPitId()))));
if (isPitDeletedSuccessfully(deletePitResponse)) {
LOG.debug("Successfully deleted point in time id {}", deletePointInTimeRequest.getPitId());
} else {
Expand Down Expand Up @@ -193,7 +197,7 @@ public SearchScrollResponse searchWithScroll(final SearchScrollRequest searchScr
@Override
public void deleteScroll(final DeleteScrollRequest deleteScrollRequest) {
try {
final ClearScrollResponse clearScrollResponse = openSearchClient.clearScroll(ClearScrollRequest.of(request -> request.scrollId(deleteScrollRequest.getScrollId())));
final ClearScrollResponse clearScrollResponse = openSearchAsyncClient.clearScroll(ClearScrollRequest.of(request -> request.scrollId(deleteScrollRequest.getScrollId())));
if (clearScrollResponse.succeeded()) {
LOG.debug("Successfully deleted scroll context with id {}", deleteScrollRequest.getScrollId());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
Expand Down Expand Up @@ -74,7 +76,18 @@ private OpenSearchClientFactory(final AwsCredentialsSupplier awsCredentialsSuppl
public OpenSearchClient provideOpenSearchClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
OpenSearchTransport transport;
if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) {
transport = createOpenSearchTransportForAws(openSearchSourceConfiguration);
transport = createOpenSearchTransportForAws(openSearchSourceConfiguration, false);
} else {
final RestClient restClient = createOpenSearchRestClient(openSearchSourceConfiguration);
transport = createOpenSearchTransport(restClient);
}
return new OpenSearchClient(transport);
}

public OpenSearchClient provideOpenSearchAsyncClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
OpenSearchTransport transport;
if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) {
transport = createOpenSearchTransportForAws(openSearchSourceConfiguration, true);
} else {
final RestClient restClient = createOpenSearchRestClient(openSearchSourceConfiguration);
transport = createOpenSearchTransport(restClient);
Expand All @@ -92,7 +105,7 @@ private OpenSearchTransport createOpenSearchTransport(final RestClient restClien
return new RestClientTransport(restClient, new JacksonJsonpMapper());
}

private OpenSearchTransport createOpenSearchTransportForAws(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
private OpenSearchTransport createOpenSearchTransportForAws(final OpenSearchSourceConfiguration openSearchSourceConfiguration, final boolean async) {
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder()
.withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion())
.withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn())
Expand All @@ -103,14 +116,26 @@ private OpenSearchTransport createOpenSearchTransportForAws(final OpenSearchSour
final boolean isServerlessCollection = Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions()) &&
openSearchSourceConfiguration.getAwsAuthenticationOptions().isServerlessCollection();

return new AwsSdk2Transport(createSdkHttpClient(openSearchSourceConfiguration),
HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(),
isServerlessCollection ? AOSS_SERVICE_NAME : AOS_SERVICE_NAME,
openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(),
AwsSdk2TransportOptions.builder()
.setCredentials(awsCredentialsProvider)
.setMapper(new JacksonJsonpMapper())
.build());
if (!async) {
return new AwsSdk2Transport(createSdkHttpClient(openSearchSourceConfiguration),
HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(),
isServerlessCollection ? AOSS_SERVICE_NAME : AOS_SERVICE_NAME,
openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(),
AwsSdk2TransportOptions.builder()
.setCredentials(awsCredentialsProvider)
.setMapper(new JacksonJsonpMapper())
.build());
} else {
return new AwsSdk2Transport(createSdkAsyncHttpClient(openSearchSourceConfiguration),
HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(),
isServerlessCollection ? AOSS_SERVICE_NAME : AOS_SERVICE_NAME,
openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(),
AwsSdk2TransportOptions.builder()
.setCredentials(awsCredentialsProvider)
.setMapper(new JacksonJsonpMapper())
.build());
}

}

private SdkHttpClient createSdkHttpClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
Expand All @@ -129,6 +154,18 @@ private SdkHttpClient createSdkHttpClient(final OpenSearchSourceConfiguration op
return apacheHttpClientBuilder.build();
}

public SdkAsyncHttpClient createSdkAsyncHttpClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
final NettyNioAsyncHttpClient.Builder builder = NettyNioAsyncHttpClient.builder();

if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) {
builder.connectionTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout());
}

attachSSLContext(builder, openSearchSourceConfiguration);

return builder.build();
}

private RestClient createOpenSearchRestClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
final List<String> hosts = openSearchSourceConfiguration.getHosts();
final HttpHost[] httpHosts = new HttpHost[hosts.size()];
Expand Down Expand Up @@ -274,6 +311,11 @@ private void attachSSLContext(final ApacheHttpClient.Builder apacheHttpClientBui
apacheHttpClientBuilder.tlsTrustManagersProvider(() -> trustManagers);
}

private void attachSSLContext(final NettyNioAsyncHttpClient.Builder asyncClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) {
TrustManager[] trustManagers = createTrustManagers(openSearchSourceConfiguration.getConnectionConfiguration().getCertPath());
asyncClientBuilder.tlsTrustManagersProvider(() -> trustManagers);
}

private void attachSSLContext(final HttpAsyncClientBuilder httpClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) {

final ConnectionConfiguration connectionConfiguration = openSearchSourceConfiguration.getConnectionConfiguration();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ public SearchAccessor getSearchAccessor() {
}

if (Objects.isNull(elasticsearchClient)) {
return new OpenSearchAccessor(openSearchClient, searchContextType);
return new OpenSearchAccessor(openSearchClient,
openSearchClientFactory.provideOpenSearchAsyncClient(openSearchSourceConfiguration),
searchContextType);
}

return new ElasticsearchAccessor(elasticsearchClient, searchContextType);
Expand All @@ -110,14 +112,18 @@ public SearchAccessor getSearchAccessor() {
private SearchAccessor createSearchAccessorForServerlessCollection(final OpenSearchClient openSearchClient) {
if (Objects.isNull(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType())) {
LOG.info("Configured with AOS serverless flag as true, defaulting to search_context_type as 'none', which uses search_after");
return new OpenSearchAccessor(openSearchClient, SearchContextType.NONE);
return new OpenSearchAccessor(openSearchClient,
openSearchClientFactory.provideOpenSearchAsyncClient(openSearchSourceConfiguration),
SearchContextType.NONE);
} else {
if (SearchContextType.POINT_IN_TIME.equals(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType())) {
throw new InvalidPluginConfigurationException("A search_context_type of point_in_time is not supported for serverless collections");
}

LOG.info("Using search_context_type set in the config: '{}'", openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType().toString().toLowerCase());
return new OpenSearchAccessor(openSearchClient, openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType());
return new OpenSearchAccessor(openSearchClient,
openSearchClientFactory.provideOpenSearchAsyncClient(openSearchSourceConfiguration),
openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ public class OpenSearchAccessorTest {
@Mock
private OpenSearchClient openSearchClient;

@Mock
private OpenSearchClient asyncOpenSearchClient;

private SearchAccessor createObjectUnderTest() {
return new OpenSearchAccessor(openSearchClient, SearchContextType.POINT_IN_TIME);
return new OpenSearchAccessor(openSearchClient, asyncOpenSearchClient, SearchContextType.POINT_IN_TIME);
}

@Test
Expand Down Expand Up @@ -349,7 +352,7 @@ void delete_pit_with_no_exception_does_not_throw(final boolean successful) throw
when(deletePitRecord.successful()).thenReturn(successful);
when(deletePitResponse.pits()).thenReturn(Collections.singletonList(deletePitRecord));

when(openSearchClient.deletePit(any(DeletePitRequest.class))).thenReturn(deletePitResponse);
when(asyncOpenSearchClient.deletePit(any(DeletePitRequest.class))).thenReturn(deletePitResponse);

createObjectUnderTest().deletePit(deletePointInTimeRequest);
}
Expand All @@ -366,7 +369,7 @@ void delete_scroll_with_no_exception_does_not_throw(final boolean successful) th
when(clearScrollResponse.succeeded()).thenReturn(successful);


when(openSearchClient.clearScroll(any(ClearScrollRequest.class))).thenReturn(clearScrollResponse);
when(asyncOpenSearchClient.clearScroll(any(ClearScrollRequest.class))).thenReturn(clearScrollResponse);

createObjectUnderTest().deleteScroll(deleteScrollRequest);
}
Expand All @@ -378,7 +381,7 @@ void delete_pit_does_not_throw_during_opensearch_exception() throws IOException
final DeletePointInTimeRequest deletePointInTimeRequest = mock(DeletePointInTimeRequest.class);
when(deletePointInTimeRequest.getPitId()).thenReturn(pitId);

when(openSearchClient.deletePit(any(DeletePitRequest.class))).thenThrow(OpenSearchException.class);
when(asyncOpenSearchClient.deletePit(any(DeletePitRequest.class))).thenThrow(OpenSearchException.class);

createObjectUnderTest().deletePit(deletePointInTimeRequest);
}
Expand All @@ -391,7 +394,7 @@ void delete_scroll_does_not_throw_during_opensearch_exception() throws IOExcepti
when(deleteScrollRequest.getScrollId()).thenReturn(scrollId);


when(openSearchClient.clearScroll(any(ClearScrollRequest.class))).thenThrow(OpenSearchException.class);
when(asyncOpenSearchClient.clearScroll(any(ClearScrollRequest.class))).thenThrow(OpenSearchException.class);

createObjectUnderTest().deleteScroll(deleteScrollRequest);
}
Expand All @@ -403,7 +406,7 @@ void delete_pit_does_not_throw_exception_when_client_throws_IOException() throws
final DeletePointInTimeRequest deletePointInTimeRequest = mock(DeletePointInTimeRequest.class);
when(deletePointInTimeRequest.getPitId()).thenReturn(pitId);

when(openSearchClient.deletePit(any(DeletePitRequest.class))).thenThrow(IOException.class);
when(asyncOpenSearchClient.deletePit(any(DeletePitRequest.class))).thenThrow(IOException.class);

createObjectUnderTest().deletePit(deletePointInTimeRequest);
}
Expand All @@ -416,7 +419,7 @@ void delete_scroll_does_not_throw_during_IO_exception() throws IOException {
when(deleteScrollRequest.getScrollId()).thenReturn(scrollId);


when(openSearchClient.clearScroll(any(ClearScrollRequest.class))).thenThrow(IOException.class);
when(asyncOpenSearchClient.clearScroll(any(ClearScrollRequest.class))).thenThrow(IOException.class);

createObjectUnderTest().deleteScroll(deleteScrollRequest);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ void provideOpenSearchClient_with_username_and_password() {

}

@Test
void provideAsyncOpenSearchClient_with_username_and_password() {
final String username = UUID.randomUUID().toString();
final String password = UUID.randomUUID().toString();
when(openSearchSourceConfiguration.getUsername()).thenReturn(username);
when(openSearchSourceConfiguration.getPassword()).thenReturn(password);

when(connectionConfiguration.getCertPath()).thenReturn(null);
when(connectionConfiguration.getSocketTimeout()).thenReturn(null);
when(connectionConfiguration.getConnectTimeout()).thenReturn(null);
when(connectionConfiguration.isInsecure()).thenReturn(true);

when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(null);

final OpenSearchClient openSearchClient = createObjectUnderTest().provideOpenSearchAsyncClient(openSearchSourceConfiguration);
assertThat(openSearchClient, notNullValue());

verifyNoInteractions(awsCredentialsSupplier);

}

@Test
void provideElasticSearchClient_with_username_and_password() {
final String username = UUID.randomUUID().toString();
Expand Down Expand Up @@ -150,6 +171,33 @@ void provideOpenSearchClient_with_aws_auth() {
assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
}

@Test
void provideAsyncOpenSearchClient_with_aws_auth() {
when(connectionConfiguration.getCertPath()).thenReturn(null);
when(connectionConfiguration.getConnectTimeout()).thenReturn(null);

final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class);
when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1);
final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role";
when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap());
when(awsAuthenticationConfiguration.isServerlessCollection()).thenReturn(false);
when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration);

final ArgumentCaptor<AwsCredentialsOptions> awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider);

final OpenSearchClient openSearchClient = createObjectUnderTest().provideOpenSearchAsyncClient(openSearchSourceConfiguration);
assertThat(openSearchClient, notNullValue());

final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue();
assertThat(awsCredentialsOptions, notNullValue());
assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1));
assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap()));
assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
}

@Test
void provideElasticSearchClient_with_auth_disabled() {
when(openSearchSourceConfiguration.isAuthenticationDisabled()).thenReturn(true);
Expand Down
Loading

0 comments on commit af2e4b1

Please sign in to comment.