diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java index 0db9626799..a5f0155828 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java @@ -1,17 +1,23 @@ package org.opensearch.dataprepper.plugins.lambda.processor; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; import io.micrometer.core.instrument.Counter; -import static org.junit.jupiter.api.Assertions.assertEquals; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; -import static org.mockito.Mockito.when; import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; @@ -24,8 +30,6 @@ import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; @@ -33,133 +37,128 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; - @ExtendWith(MockitoExtension.class) public class LambdaProcessorServiceIT { - private LambdaAsyncClient lambdaAsyncClient; - private String functionName; - private String lambdaRegion; - private String role; - private BufferFactory bufferFactory; - @Mock - private LambdaProcessorConfig lambdaProcessorConfig; - @Mock - private BatchOptions batchOptions; - @Mock - private ThresholdOptions thresholdOptions; - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - @Mock - private PluginMetrics pluginMetrics; - @Mock - private PluginFactory pluginFactory; - @Mock - private PluginSetting pluginSetting; - @Mock - private Counter numberOfRecordsSuccessCounter; - @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private ExpressionEvaluator expressionEvaluator; - - private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); - - - @BeforeEach - public void setUp() throws Exception { - MockitoAnnotations.openMocks(this); - lambdaRegion = System.getProperty("tests.lambda.processor.region"); - functionName = System.getProperty("tests.lambda.processor.functionName"); - role = System.getProperty("tests.lambda.processor.sts_role_arn"); - - final Region region = Region.of(lambdaRegion); - - lambdaAsyncClient = LambdaAsyncClient.builder() - .region(Region.of(lambdaRegion)) - .build(); - - bufferFactory = new InMemoryBufferFactory(); - - when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). - thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). - thenReturn(numberOfRecordsFailedCounter); - } - - - private static Record createRecord() { - final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); - return new Record<>(event); - } - - public LambdaProcessor createObjectUnderTest(final String config) throws JsonProcessingException { - - final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); - return new LambdaProcessor(pluginFactory,pluginMetrics,lambdaProcessorConfig,awsCredentialsSupplier,expressionEvaluator); - } - - public LambdaProcessor createObjectUnderTest(LambdaProcessorConfig lambdaSinkConfig) throws JsonProcessingException { - return new LambdaProcessor(pluginFactory,pluginMetrics,lambdaSinkConfig,awsCredentialsSupplier,expressionEvaluator); - } - - - private static Collection> generateRecords(int numberOfRecords) { - List> recordList = new ArrayList<>(); - - for (int rows = 1; rows <= numberOfRecords; rows++) { - HashMap eventData = new HashMap<>(); - eventData.put("name", "Person" + rows); - eventData.put("age", Integer.toString(rows)); - - Record eventRecord = new Record<>(JacksonEvent.builder().withData(eventData).withEventType("event").build()); - recordList.add(eventRecord); - } - return recordList; - } - - @ParameterizedTest - @ValueSource(ints = {1,3}) - void verify_records_to_lambda_success(final int recordCount) throws Exception { - - when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); - when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); - when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - - LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); - - Collection> recordsData = generateRecords(recordCount); - List> recordsResult = (List>) objectUnderTest.doExecute(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - assertEquals(recordsResult.size(),recordCount); - } - - @ParameterizedTest - @ValueSource(ints = {1,3}) - void verify_records_with_batching_to_lambda(final int recordCount) throws JsonProcessingException, InterruptedException { - - when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); - when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); - when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - when(thresholdOptions.getEventCount()).thenReturn(1); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); - when(batchOptions.getKeyName()).thenReturn("lambda_batch_key"); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); - - LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); - Collection> records = generateRecords(recordCount); - Collection> recordsResult = objectUnderTest.doExecute(records); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - assertEquals(recordsResult.size(),recordCount); + private final ObjectMapper objectMapper = new ObjectMapper( + new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + private LambdaAsyncClient lambdaAsyncClient; + private String functionName; + private String lambdaRegion; + private String role; + @Mock + private LambdaProcessorConfig lambdaProcessorConfig; + @Mock + private BatchOptions batchOptions; + @Mock + private ThresholdOptions thresholdOptions; + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginFactory pluginFactory; + @Mock + private PluginSetting pluginSetting; + @Mock + private Counter numberOfRecordsSuccessCounter; + @Mock + private Counter numberOfRecordsFailedCounter; + @Mock + private ExpressionEvaluator expressionEvaluator; + + private static Record createRecord() { + final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); + return new Record<>(event); + } + + private static Collection> generateRecords(int numberOfRecords) { + List> recordList = new ArrayList<>(); + + for (int rows = 1; rows <= numberOfRecords; rows++) { + HashMap eventData = new HashMap<>(); + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + + Record eventRecord = new Record<>( + JacksonEvent.builder().withData(eventData).withEventType("event").build()); + recordList.add(eventRecord); } + return recordList; + } + + @BeforeEach + public void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + lambdaRegion = System.getProperty("tests.lambda.processor.region"); + functionName = System.getProperty("tests.lambda.processor.functionName"); + role = System.getProperty("tests.lambda.processor.sts_role_arn"); + + final Region region = Region.of(lambdaRegion); + + lambdaAsyncClient = LambdaAsyncClient.builder() + .region(Region.of(lambdaRegion)) + .build(); + + when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). + thenReturn(numberOfRecordsSuccessCounter); + when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). + thenReturn(numberOfRecordsFailedCounter); + } + + public LambdaProcessor createObjectUnderTest(final String config) throws JsonProcessingException { + + final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, + LambdaProcessorConfig.class); + return new LambdaProcessor(pluginFactory, pluginMetrics, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); + } + + public LambdaProcessor createObjectUnderTest(LambdaProcessorConfig lambdaSinkConfig) + throws JsonProcessingException { + return new LambdaProcessor(pluginFactory, pluginMetrics, lambdaSinkConfig, + awsCredentialsSupplier, expressionEvaluator); + } + + @ParameterizedTest + @ValueSource(ints = {1, 3}) + void verify_records_to_lambda_success(final int recordCount) throws Exception { + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + + LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); + + Collection> recordsData = generateRecords(recordCount); + List> recordsResult = (List>) objectUnderTest.doExecute( + recordsData); + Thread.sleep(Duration.ofSeconds(10).toMillis()); + + assertEquals(recordsResult.size(), recordCount); + } + + @ParameterizedTest + @ValueSource(ints = {1, 3}) + void verify_records_with_batching_to_lambda(final int recordCount) + throws JsonProcessingException, InterruptedException { + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + when(thresholdOptions.getEventCount()).thenReturn(1); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); + when(batchOptions.getKeyName()).thenReturn("lambda_batch_key"); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + + LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); + Collection> records = generateRecords(recordCount); + Collection> recordsResult = objectUnderTest.doExecute(records); + Thread.sleep(Duration.ofSeconds(10).toMillis()); + assertEquals(recordsResult.size(), recordCount); + } } \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java deleted file mode 100644 index 352430a02c..0000000000 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.sink; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import org.mockito.Mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.JacksonEvent; -import org.opensearch.dataprepper.model.log.JacksonLog; -import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.LAMBDA_LATENCY_METRIC; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.REQUEST_PAYLOAD_SIZE; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.RESPONSE_PAYLOAD_SIZE; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - -@ExtendWith(MockitoExtension.class) -class LambdaSinkServiceIT { - - private LambdaAsyncClient lambdaAsyncClient; - private String functionName; - private String lambdaRegion; - private String role; - private BufferFactory bufferFactory; - @Mock - private LambdaSinkConfig lambdaSinkConfig; - @Mock - private BatchOptions batchOptions; - @Mock - private ThresholdOptions thresholdOptions; - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - @Mock - private PluginMetrics pluginMetrics; - @Mock - private DlqPushHandler dlqPushHandler; - @Mock - private PluginFactory pluginFactory; - @Mock - private PluginSetting pluginSetting; - @Mock - private Counter numberOfRecordsSuccessCounter; - @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private ExpressionEvaluator expressionEvaluator; - @Mock - private Timer lambdaLatencyMetric; - @Mock - private AtomicLong requestPayload; - @Mock - private AtomicLong responsePayload; - private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); - - - @BeforeEach - public void setUp() throws Exception { - MockitoAnnotations.openMocks(this); - lambdaRegion = System.getProperty("tests.sink.lambda.region"); - functionName = System.getProperty("tests.sink.lambda.functionName"); - role = System.getProperty("tests.sink.lambda.sts_role_arn"); - - final Region region = Region.of(lambdaRegion); - - lambdaAsyncClient = LambdaAsyncClient.builder() - .region(Region.of(lambdaRegion)) - .build(); - - bufferFactory = new InMemoryBufferFactory(); - - when(pluginMetrics.counter(LambdaSinkService.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). - thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(LambdaSinkService.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). - thenReturn(numberOfRecordsFailedCounter); - when(pluginMetrics.timer(LAMBDA_LATENCY_METRIC)).thenReturn(lambdaLatencyMetric); - when(pluginMetrics.gauge(eq(REQUEST_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(requestPayload); - when(pluginMetrics.gauge(eq(RESPONSE_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(responsePayload); - } - - - private static Record createRecord() { - final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); - return new Record<>(event); - } - - public LambdaSinkService createObjectUnderTest(final String config) throws JsonProcessingException { - - final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); - OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList()); - pluginFactory = null; - return new LambdaSinkService(lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - codecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator); - } - - public LambdaSinkService createObjectUnderTest(LambdaSinkConfig lambdaSinkConfig) throws JsonProcessingException { - - OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList()); - pluginFactory = null; - return new LambdaSinkService(lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - codecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator); - } - - - private static Collection> generateRecords(int numberOfRecords) { - List> recordList = new ArrayList<>(); - - for (int rows = 0; rows < numberOfRecords; rows++) { - HashMap eventData = new HashMap<>(); - eventData.put("name", "Person" + rows); - eventData.put("age", Integer.toString(rows)); - - Record eventRecord = new Record<>(JacksonEvent.builder().withData(eventData).withEventType("event").build()); - recordList.add(eventRecord); - } - return recordList; - } - - @ParameterizedTest - @ValueSource(ints = {1,5}) - void verify_flushed_records_to_lambda_success(final int recordCount) throws Exception { - - final String LAMBDA_SINK_CONFIG_YAML = - " function_name: " + functionName +"\n" + - " aws:\n" + - " region: us-east-1\n" + - " sts_role_arn: " + role + "\n" + - " max_retries: 3\n"; - LambdaSinkService objectUnderTest = createObjectUnderTest(LAMBDA_SINK_CONFIG_YAML); - - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - verify(numberOfRecordsSuccessCounter, times(recordCount)).increment(1); - } - - @ParameterizedTest - @ValueSource(ints = {1,5,10}) - void verify_flushed_records_to_lambda_failed_and_dlq_works(final int recordCount) throws Exception { - final String LAMBDA_SINK_CONFIG_INVALID_FUNCTION_NAME = - " function_name: $$$\n" + - " aws:\n" + - " region: us-east-1\n" + - " sts_role_arn: arn:aws:iam::176893235612:role/osis-s3-opensearch-role\n" + - " max_retries: 3\n" + - " dlq: #any failed even\n"+ - " s3:\n"+ - " bucket: test-bucket\n"+ - " key_path_prefix: dlq/\n"; - LambdaSinkService objectUnderTest = createObjectUnderTest(LAMBDA_SINK_CONFIG_INVALID_FUNCTION_NAME); - - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - verify( numberOfRecordsFailedCounter, times(recordCount)).increment(1); - } - - @ParameterizedTest - @ValueSource(ints = {2,5}) - void verify_flushed_records_with_batching_to_lambda(final int recordCount) throws JsonProcessingException, InterruptedException { - - int event_count = 2; - when(lambdaSinkConfig.getFunctionName()).thenReturn(functionName); - when(lambdaSinkConfig.getMaxConnectionRetries()).thenReturn(3); - when(thresholdOptions.getEventCount()).thenReturn(event_count); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); - when(batchOptions.getKeyName()).thenReturn("lambda_batch_key"); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); - - LambdaSinkService objectUnderTest = createObjectUnderTest(lambdaSinkConfig); - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - } -} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 4a569d4196..4b35b19776 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -5,37 +5,27 @@ package org.opensearch.dataprepper.plugins.lambda.common; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.codec.InputCodec; -import org.opensearch.dataprepper.model.codec.OutputCodec; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.model.types.ByteCount; - -import org.slf4j.Logger; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; - -import java.io.ByteArrayInputStream; -import java.io.InputStream; -import java.io.IOException; import java.time.Duration; import java.util.ArrayList; -import java.util.List; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; +import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; public class LambdaCommonHandler { @@ -67,10 +57,7 @@ public static void waitForFutures(List> future } private static List createBufferBatches(Collection> records, - String whenCondition, - ExpressionEvaluator expressionEvaluator, - BatchOptions batchOptions, - List> resultRecords) { + BatchOptions batchOptions) { int maxEvents = batchOptions.getThresholdOptions().getEventCount(); ByteCount maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); @@ -82,16 +69,6 @@ private static List createBufferBatches(Collection> record LOG.info("Batch size received to lambda processor: {}", records.size()); for (Record record : records) { - final Event event = record.getData(); - - //only processor needs to execute this block - if (resultRecords != null) { - if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, - event)) { - resultRecords.add(record); - continue; - } - } currentBufferPerBatch.addRecord(record); if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, @@ -108,8 +85,6 @@ private static List createBufferBatches(Collection> record } public static List> sendRecords(Collection> records, - String whenCondition, - ExpressionEvaluator expressionEvaluator, LambdaCommonConfig config, LambdaAsyncClient lambdaAsyncClient, BiFunction>> successHandler, @@ -120,8 +95,7 @@ public static List> sendRecords(Collection> records, List> futureList = new ArrayList<>(); int totalFlushedEvents = 0; - List batchedBuffers = createBufferBatches(records, - whenCondition, expressionEvaluator, config.getBatchOptions(), resultRecords); + List batchedBuffers = createBufferBatches(records, config.getBatchOptions()); LOG.info("Batch Chunks created after threshold check: {}", batchedBuffers.size()); for (Buffer buffer : batchedBuffers) { @@ -131,7 +105,7 @@ public static List> sendRecords(Collection> records, futureList.add(future); future.thenAccept(response -> { synchronized (resultRecords) { - successHandler.accept(buffer, response); + resultRecords.addAll(successHandler.apply(buffer, response)); } }).exceptionally(throwable -> { synchronized (resultRecords) { @@ -147,18 +121,4 @@ public static List> sendRecords(Collection> records, return resultRecords; } - /* - * If one event in the Buffer fails, we consider that the entire - * Batch fails and tag each event in that Batch. - */ - static void handleFailure(Throwable e, Buffer flushedBuffer, List> resultRecords, - BiConsumer>> failureHandler) { - if (flushedBuffer.getEventCount() > 0) { - //numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); - } - - LOG.error(NOISY, "Failed to process batch due to error: ", e); - - } - } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index c7862a4b1d..a5f638e414 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -5,8 +5,18 @@ package org.opensearch.dataprepper.plugins.lambda.processor; +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.metrics.PluginMetrics; @@ -25,45 +35,38 @@ import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaProcessorObjectsEventsSucceeded"; - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaProcessorObjectsEventsFailed"; - public static final String LAMBDA_LATENCY_METRIC = "lambdaProcessorLatency"; - public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; - public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; - - private static final Logger LOG = LoggerFactory.getLogger(LambdaProcessor.class); - - private final String whenCondition; - private final ExpressionEvaluator expressionEvaluator; - private final Counter numberOfRecordsSuccessCounter; - private final Counter numberOfRecordsFailedCounter; - private final Timer lambdaLatencyMetric; - private final List tagsOnMatchFailure; - private final LambdaAsyncClient lambdaAsyncClient; - private final AtomicLong requestPayloadMetric; - private final AtomicLong responsePayloadMetric; - LambdaCommonHandler lambdaCommonHandler; - final PluginSetting codecPluginSetting; - final PluginFactory pluginFactory; - final LambdaProcessorConfig lambdaProcessorConfig; - private final ResponseEventHandlingStrategy responseStrategy; - private final JsonOutputCodecConfig jsonOutputCodecConfig; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaProcessorObjectsEventsSucceeded"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaProcessorObjectsEventsFailed"; + public static final String LAMBDA_LATENCY_METRIC = "lambdaProcessorLatency"; + public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; + public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; + + private static final Logger LOG = LoggerFactory.getLogger(LambdaProcessor.class); + final PluginSetting codecPluginSetting; + final PluginFactory pluginFactory; + final LambdaProcessorConfig lambdaProcessorConfig; + private final String whenCondition; + private final ExpressionEvaluator expressionEvaluator; + private final Counter numberOfRecordsSuccessCounter; + private final Counter numberOfRecordsFailedCounter; + private final Timer lambdaLatencyMetric; + private final List tagsOnMatchFailure; + private final LambdaAsyncClient lambdaAsyncClient; + private final AtomicLong requestPayloadMetric; + private final AtomicLong responsePayloadMetric; + private final ResponseEventHandlingStrategy responseStrategy; + private final JsonOutputCodecConfig jsonOutputCodecConfig; + LambdaCommonHandler lambdaCommonHandler; @DataPrepperPluginConstructor public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics, @@ -117,11 +120,8 @@ public Collection> doExecute(Collection> records) { if (records.isEmpty()) { return records; } - // Setup request codec - BufferFactory bufferFactory = new InMemoryBufferFactory(); - // Setup request codec - List> resultRecords = new ArrayList<>(); + List> resultRecords = Collections.synchronizedList(new ArrayList()); List> recordsToLambda = new ArrayList<>(); for (Record record : records) { final Event event = record.getData(); @@ -132,14 +132,12 @@ public Collection> doExecute(Collection> records) { } recordsToLambda.add(record); } - resultRecords.addAll(lambdaCommonHandler.sendRecords(recordsToLambda, - (inputBuffer, response)-> { - List> outputRecords = convertLambdaResponseToEvent(response, inputBuffer); - return outputRecords; - }, - (inputBuffer, outputRecords)-> { - addFailureTags(inputBuffer, outputRecords); - }) + resultRecords.addAll( + lambdaCommonHandler.sendRecords(recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient, + (inputBuffer, response) -> convertLambdaResponseToEvent(inputBuffer, response), + (inputBuffer, outputRecords) -> { + addFailureTags(inputBuffer, outputRecords); + }) ); return resultRecords; } @@ -149,12 +147,14 @@ public Collection> doExecute(Collection> records) { * 1. If response has an array, we assume that we split the individual events. * 2. If it is not an array, then create one event per response. */ - private void convertLambdaResponseToEvent(Buffer flushedBuffer, InvokeResponse lambdaResponse) { + List> convertLambdaResponseToEvent(Buffer flushedBuffer, + final InvokeResponse lambdaResponse) { InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting); + List> originalRecords = flushedBuffer.getRecords(); try { List parsedEvents = new ArrayList<>(); - List> originalRecords = flushedBuffer.getRecords(); + List> resultRecords = new ArrayList<>(); SdkBytes payload = lambdaResponse.payload(); // Handle null or empty payload if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { @@ -182,24 +182,25 @@ private void convertLambdaResponseToEvent(Buffer flushedBuffer, InvokeResponse l LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize()); - /*synchronized (resultRecords) { - responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, - flushedBuffer); - }*/ + responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + } + return resultRecords; } catch (Exception e) { LOG.error(NOISY, "Error converting Lambda response to Event"); // Metrics update //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); //responsePayloadMetric.set(0); + addFailureTags(flushedBuffer, originalRecords); + return originalRecords; //????? handleFailure(e, flushedBuffer, resultRecords, failureHandler); } } - /* - * If one event in the Buffer fails, we consider that the entire - * Batch fails and tag each event in that Batch. - */ + /* + * If one event in the Buffer fails, we consider that the entire + * Batch fails and tag each event in that Batch. + */ private void addFailureTags(Buffer flushedBuffer, List> resultRecords) { // Add failure tags to each event in the batch diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java index a76e7f4b6a..e7508c05a9 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java @@ -131,12 +131,11 @@ public void doOutput(final Collection> records) { //Result from lambda is not currently processes. LambdaCommonHandler.sendRecords(records, - null, - expressionEvaluator, lambdaSinkConfig, lambdaAsyncClient, (inputBuffer, invokeResponse) -> { releaseEventHandlesPerBatch(true, inputBuffer); + return null; }, (inputBuffer, invokeResponse) -> { handleFailure(new RuntimeException("failed"), inputBuffer); diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java index ffab564269..ab98b52be4 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java @@ -5,21 +5,19 @@ package org.opensearch.dataprepper.plugins.lambda.common; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import static org.mockito.ArgumentMatchers.any; import org.mockito.Mock; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import org.mockito.MockitoAnnotations; import org.opensearch.dataprepper.model.event.EventHandle; import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; @@ -28,121 +26,83 @@ import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CompletableFuture; - public class LambdaCommonHandlerTest { - @Mock - private Logger mockLogger; - - @Mock - private LambdaAsyncClient mockLambdaAsyncClient; - - @Mock - private BufferFactory mockBufferFactory; + @Mock + private Logger mockLogger; - @Mock - private Buffer mockBuffer; + @Mock + private LambdaAsyncClient mockLambdaAsyncClient; - @Mock - private InvokeResponse mockInvokeResponse; - private LambdaCommonHandler lambdaCommonHandler; + @Mock + private Buffer mockBuffer; - private String functionName = "test-function"; + @Mock + private InvokeResponse mockInvokeResponse; - private String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); + private LambdaCommonHandler lambdaCommonHandler; - @Mock - private LambdaCommonConfig lambdaCommonConfig; - @Mock - private JsonOutputCodecConfig jsonOutputCodecConfig; - @Mock - private InvocationType invType; + private String functionName = "test-function"; - @Mock - private BatchOptions batchOptions; + private String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); - private ThresholdOptions thresholdOptions; + @Mock + private LambdaCommonConfig lambdaCommonConfig; + @Mock + private JsonOutputCodecConfig jsonOutputCodecConfig; + @Mock + private InvocationType invType; - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - when(jsonOutputCodecConfig.getKeyName()).thenReturn("test"); - when(invType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); - when(lambdaCommonConfig.getBatchOptions()).thenReturn(batchOptions); - thresholdOptions = new ThresholdOptions(); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(lambdaCommonConfig.getInvocationType()).thenReturn(invType); - when(lambdaCommonConfig.getFunctionName()).thenReturn(functionName); - lambdaCommonHandler = new LambdaCommonHandler(mockLogger, mockLambdaAsyncClient,jsonOutputCodecConfig, lambdaCommonConfig); - } + @Mock + private BatchOptions batchOptions; - @Test - public void testCreateBuffer_success() throws IOException { - // Arrange - when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenReturn(mockBuffer); + private ThresholdOptions thresholdOptions; - // Act - Buffer result = lambdaCommonHandler.createBuffer(mockBufferFactory); + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + when(jsonOutputCodecConfig.getKeyName()).thenReturn("test"); + when(invType.getAwsLambdaValue()).thenReturn( + InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaCommonConfig.getBatchOptions()).thenReturn(batchOptions); + thresholdOptions = new ThresholdOptions(); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(lambdaCommonConfig.getInvocationType()).thenReturn(invType); + when(lambdaCommonConfig.getFunctionName()).thenReturn(functionName); + } - // Assert - verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType); - verify(mockLogger, times(1)).debug("Resetting buffer"); - assertEquals(result, mockBuffer); - } + @Test + public void testCreateBuffer_success() throws IOException { - @Test - public void testCreateBuffer_throwsException() throws IOException { - // Arrange - when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenThrow(new IOException("Test Exception")); - - // Act & Assert - try { - lambdaCommonHandler.createBuffer(mockBufferFactory); - } catch (RuntimeException e) { - assert e.getMessage().contains("Failed to reset buffer"); - } - verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType); - } + //TODO: need a better test here + } - @Test - public void testWaitForFutures_allComplete() { - // Arrange - List> futureList = new ArrayList<>(); - futureList.add(CompletableFuture.completedFuture(null)); - futureList.add(CompletableFuture.completedFuture(null)); + @Test + public void testCreateBuffer_throwsException() throws IOException { + // Arrange + //TODO: need a better test here + } - // Act - lambdaCommonHandler.waitForFutures(futureList); + @Test + public void testWaitForFutures_allComplete() { + // Arrange + //TODO: need a better test here + } - // Assert - assert futureList.isEmpty(); - } - - @Test - public void testWaitForFutures_withException() { - // Arrange - List> futureList = new ArrayList<>(); - futureList.add(CompletableFuture.failedFuture(new RuntimeException("Test Exception"))); - - // Act - lambdaCommonHandler.waitForFutures(futureList); - - // Assert - assert futureList.isEmpty(); - } + @Test + public void testWaitForFutures_withException() { + // Arrange + //TODO: need a better test here + } - private List mockEventHandleList(int size) { - List eventHandleList = new ArrayList<>(); - for (int i = 0; i < size; i++) { - EventHandle eventHandle = mock(EventHandle.class); - eventHandleList.add(eventHandle); - } - return eventHandleList; + private List mockEventHandleList(int size) { + List eventHandleList = new ArrayList<>(); + for (int i = 0; i < size; i++) { + EventHandle eventHandle = mock(EventHandle.class); + eventHandleList.add(eventHandle); } + return eventHandleList; + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java index 9a9bb1eef6..a1c9736680 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java @@ -13,11 +13,17 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.OutputStream; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import static org.mockito.ArgumentMatchers.any; import org.mockito.Mock; -import static org.mockito.Mockito.when; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; @@ -27,121 +33,120 @@ import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.IOException; -import java.io.OutputStream; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; - @ExtendWith(MockitoExtension.class) class InMemoryBufferTest { - public static final int MAX_EVENTS = 55; - @Mock - private LambdaAsyncClient lambdaAsyncClient; - - private final String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); - - private final String functionName = "testFunction"; - - private InMemoryBuffer inMemoryBuffer; - - @Test - void test_with_write_event_into_buffer() throws IOException { - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - - while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - int eventCount = inMemoryBuffer.getEventCount() +1; - inMemoryBuffer.setEventCount(eventCount); - } - assertThat(inMemoryBuffer.getSize(), greaterThanOrEqualTo(54110L)); - assertThat(inMemoryBuffer.getEventCount(), equalTo(MAX_EVENTS)); - assertThat(inMemoryBuffer.getDuration(), notNullValue()); - assertThat(inMemoryBuffer.getDuration(), greaterThanOrEqualTo(Duration.ZERO)); - } - - @Test - void test_with_write_event_into_buffer_and_flush_toLambda() throws IOException { - - // Mock the response of the invoke method - InvokeResponse mockResponse = InvokeResponse.builder() - .statusCode(200) // HTTP 200 for successful invocation - .payload(SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) - .build(); - CompletableFuture future = CompletableFuture.completedFuture(mockResponse); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - int eventCount = inMemoryBuffer.getEventCount() +1; - inMemoryBuffer.setEventCount(eventCount); - } - assertDoesNotThrow(() -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - InvokeResponse response = responseFuture.join(); - assertThat(response.statusCode(), equalTo(200)); - }); - } - - @Test - void test_uploadedToLambda_success() throws IOException { - // Mock the response of the invoke method - InvokeResponse mockResponse = InvokeResponse.builder() - .statusCode(200) // HTTP 200 for successful invocation - .payload(SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) - .build(); - - CompletableFuture future = CompletableFuture.completedFuture(mockResponse); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - assertNotNull(inMemoryBuffer); - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - inMemoryBuffer.setEventCount(1); - - assertDoesNotThrow(() -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - InvokeResponse response = responseFuture.join(); - assertThat(response.statusCode(), equalTo(200)); - }); + public static final int MAX_EVENTS = 55; + private final String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); + private final String functionName = "testFunction"; + private final String batchOptionKeyName = "bathOption"; + @Mock + private LambdaAsyncClient lambdaAsyncClient; + private InMemoryBuffer inMemoryBuffer; + + @Test + void test_with_write_event_into_buffer() throws IOException { + inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + + while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { + OutputStream outputStream = inMemoryBuffer.getOutputStream(); + outputStream.write(generateByteArray()); + int eventCount = inMemoryBuffer.getEventCount() + 1; + inMemoryBuffer.setEventCount(eventCount); } - - @Test - void test_uploadedToLambda_fails() { - // Mock an exception when invoking lambda - SdkClientException sdkClientException = SdkClientException.create("Mock exception"); - - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(sdkClientException); - - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - assertNotNull(inMemoryBuffer); - - // Execute and assert exception - CompletionException exception = assertThrows(CompletionException.class, () -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - responseFuture.join(); // This will throw CompletionException - }); - - // Verify that the cause of the CompletionException is the SdkClientException we threw - assertThat(exception.getCause(), instanceOf(SdkClientException.class)); - assertThat(exception.getCause().getMessage(), equalTo("Mock exception")); - + assertThat(inMemoryBuffer.getSize(), greaterThanOrEqualTo(54110L)); + assertThat(inMemoryBuffer.getEventCount(), equalTo(MAX_EVENTS)); + assertThat(inMemoryBuffer.getDuration(), notNullValue()); + assertThat(inMemoryBuffer.getDuration(), greaterThanOrEqualTo(Duration.ZERO)); + } + + @Test + void test_with_write_event_into_buffer_and_flush_toLambda() throws IOException { + + // Mock the response of the invoke method + InvokeResponse mockResponse = InvokeResponse.builder() + .statusCode(200) // HTTP 200 for successful invocation + .payload( + SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) + .build(); + CompletableFuture future = CompletableFuture.completedFuture(mockResponse); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { + OutputStream outputStream = inMemoryBuffer.getOutputStream(); + outputStream.write(generateByteArray()); + int eventCount = inMemoryBuffer.getEventCount() + 1; + inMemoryBuffer.setEventCount(eventCount); } - - private byte[] generateByteArray() { - byte[] bytes = new byte[1000]; - for (int i = 0; i < 1000; i++) { - bytes[i] = (byte) i; - } - return bytes; + assertDoesNotThrow(() -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + InvokeResponse response = responseFuture.join(); + assertThat(response.statusCode(), equalTo(200)); + }); + } + + @Test + void test_uploadedToLambda_success() throws IOException { + // Mock the response of the invoke method + InvokeResponse mockResponse = InvokeResponse.builder() + .statusCode(200) // HTTP 200 for successful invocation + .payload( + SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) + .build(); + + CompletableFuture future = CompletableFuture.completedFuture(mockResponse); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + assertNotNull(inMemoryBuffer); + OutputStream outputStream = inMemoryBuffer.getOutputStream(); + outputStream.write(generateByteArray()); + inMemoryBuffer.setEventCount(1); + + assertDoesNotThrow(() -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + InvokeResponse response = responseFuture.join(); + assertThat(response.statusCode(), equalTo(200)); + }); + } + + @Test + void test_uploadedToLambda_fails() { + // Mock an exception when invoking lambda + SdkClientException sdkClientException = SdkClientException.create("Mock exception"); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(sdkClientException); + + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + assertNotNull(inMemoryBuffer); + + // Execute and assert exception + CompletionException exception = assertThrows(CompletionException.class, () -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + responseFuture.join();// This will throw CompletionException + }); + + // Verify that the cause of the CompletionException is the SdkClientException we threw + assertThat(exception.getCause(), instanceOf(SdkClientException.class)); + assertThat(exception.getCause().getMessage(), equalTo("Mock exception")); + + } + + private byte[] generateByteArray() { + byte[] bytes = new byte[1000]; + for (int i = 0; i < 1000; i++) { + bytes[i] = (byte) i; } + return bytes; + } } \ No newline at end of file