diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 5b65067543..fc29b5e904 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -6,6 +6,8 @@ package org.opensearch.dataprepper.plugins.lambda.common; import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + +import org.checkerframework.common.reflection.qual.Invoke; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; @@ -36,6 +38,7 @@ import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; +import java.util.function.BiFunction; public class LambdaCommonHandler { private final Logger LOG; @@ -43,32 +46,19 @@ public class LambdaCommonHandler { private final String functionName; private final String invocationType; private final LambdaCommonConfig config; - private final String whenCondition; BufferFactory bufferFactory; - final InputCodec responseCodec; - final ExpressionEvaluator expressionEvaluator; JsonOutputCodecConfig jsonOutputCodecConfig; private final int maxEvents; private final ByteCount maxBytes; private final Duration maxCollectionDuration; - private final ResponseEventHandlingStrategy responseStrategy; - public LambdaCommonHandler(final Logger log, final LambdaAsyncClient lambdaAsyncClient, final JsonOutputCodecConfig jsonOutputCodecConfig, - final InputCodec responseCodec, - final String whenCondition, - final ExpressionEvaluator expressionEvaluator, - final ResponseEventHandlingStrategy responseStrategy, final LambdaCommonConfig lambdaCommonConfig) { this.LOG = log; this.lambdaAsyncClient = lambdaAsyncClient; - this.responseStrategy = responseStrategy; this.config = lambdaCommonConfig; this.jsonOutputCodecConfig = jsonOutputCodecConfig; - this.whenCondition = whenCondition; - this.responseCodec = responseCodec; - this.expressionEvaluator = expressionEvaluator; this.functionName = config.getFunctionName(); this.invocationType = config.getInvocationType().getAwsLambdaValue(); maxEvents = lambdaCommonConfig.getBatchOptions().getThresholdOptions().getEventCount(); @@ -77,13 +67,6 @@ public LambdaCommonHandler(final Logger log, bufferFactory = new InMemoryBufferFactory(); } - public LambdaCommonHandler(final Logger log, - final LambdaAsyncClient lambdaAsyncClient, - final JsonOutputCodecConfig jsonOutputCodecConfig, - final LambdaCommonConfig lambdaCommonConfig) { - this(log, lambdaAsyncClient, jsonOutputCodecConfig, null, null, null, null, lambdaCommonConfig); - } - public Buffer createBuffer(BufferFactory bufferFactory) { try { LOG.debug("Resetting buffer"); @@ -116,23 +99,15 @@ public void waitForFutures(List> futureList) { } public List> sendRecords(Collection> records, - BiConsumer>> successHandler, BiConsumer>> failureHandler) { + BiFunction>> successHandler, + BiConsumer>> failureHandler) { List> resultRecords = Collections.synchronizedList(new ArrayList()); boolean createNewBuffer = true; Buffer currentBufferPerBatch = null; OutputCodec requestCodec = null; List futureList = new ArrayList<>(); for (Record record : records) { - final Event event = record.getData(); - - // If the condition is false, add the event to resultRecords as-is - if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { - synchronized(resultRecords) { - resultRecords.add(record); - } - continue; - } - + Event event = record.getData(); try { if (createNewBuffer) { currentBufferPerBatch = createBuffer(bufferFactory); @@ -157,8 +132,9 @@ public List> sendRecords(Collection> records, } boolean flushToLambdaIfNeeded(List> resultRecords, Buffer currentBufferPerBatch, - OutputCodec requestCodec, List futureList, BiConsumer>> successHandler, - BiConsumer>> failureHandler, boolean forceFlush) { + OutputCodec requestCodec, List futureList, + BiFunction>> successHandler, + BiConsumer>> failureHandler, boolean forceFlush) { LOG.debug("currentBufferPerBatchEventCount:{}, maxEvents:{}, maxBytes:{}, " + "maxCollectionDuration:{}, forceFlush:{} ", currentBufferPerBatch.getEventCount(), @@ -201,8 +177,9 @@ boolean flushToLambdaIfNeeded(List> resultRecords, Buffer currentB } private void handleLambdaResponse(List> resultRecords, Buffer flushedBuffer, - int eventCount, InvokeResponse response, BiConsumer>> successHandler, - BiConsumer>> failureHandler) { + int eventCount, InvokeResponse response, + BiFunction>> successHandler, + BiConsumer>> failureHandler) { boolean success = checkStatusCode(response); if (success) { LOG.info("Successfully flushed {} events", eventCount); @@ -212,70 +189,16 @@ private void handleLambdaResponse(List> resultRecords, Buffer flus Duration latency = flushedBuffer.stopLatencyWatch(); //lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); //totalFlushedEvents += eventCount; - - convertLambdaResponseToEvent(resultRecords, response, flushedBuffer, successHandler); + synchronized(resultRecords) { + resultRecords.addAll(successHandler.apply(flushedBuffer, response)); + } + //convertLambdaResponseToEvent(resultRecords, response, flushedBuffer, successHandler); } else { // Non-2xx status code treated as failure handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer, resultRecords, failureHandler); } } - /* - * Assumption: Lambda always returns json array. - * 1. If response has an array, we assume that we split the individual events. - * 2. If it is not an array, then create one event per response. - */ - void convertLambdaResponseToEvent(final List> resultRecords, final InvokeResponse lambdaResponse, - Buffer flushedBuffer, BiConsumer>> successHandler) { - try { - List parsedEvents = new ArrayList<>(); - List> originalRecords = flushedBuffer.getRecords(); - - SdkBytes payload = lambdaResponse.payload(); - // Handle null or empty payload - if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { - LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events"); - // Set metrics - //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - //responsePayloadMetric.set(0); - } else { - // Set metrics - //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - //responsePayloadMetric.set(payload.asByteArray().length); - - LOG.debug("Response payload:{}", payload.asUtf8String()); - InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); - //Convert to response codec - try { - responseCodec.parse(inputStream, record -> { - Event event = record.getData(); - parsedEvents.add(event); - }); - } catch (IOException ex) { - throw new RuntimeException(ex); - } - - LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + - "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), - flushedBuffer.getSize()); - synchronized(resultRecords) { - responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); - successHandler.accept(flushedBuffer, originalRecords); - } - } - } catch (Exception e) { - LOG.error(NOISY, "Error converting Lambda response to Event"); - // Metrics update - //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - //responsePayloadMetric.set(0); - //????? handleFailure(e, flushedBuffer, resultRecords, failureHandler); - } - } - - /* - * If one event in the Buffer fails, we consider that the entire - * Batch fails and tag each event in that Batch. - */ void handleFailure(Throwable e, Buffer flushedBuffer, List> resultRecords, BiConsumer>> failureHandler) { try { if (flushedBuffer.getEventCount() > 0) { @@ -291,4 +214,5 @@ void handleFailure(Throwable e, Buffer flushedBuffer, List> result } } + } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 275373dc1c..9bf3be3390 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -31,12 +31,21 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; + +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { @@ -64,6 +73,7 @@ public class LambdaProcessor extends AbstractProcessor, Record responseCodec; @DataPrepperPluginConstructor public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics, final LambdaProcessorConfig lambdaProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) { @@ -80,6 +90,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl tagsOnMatchFailure = lambdaProcessorConfig.getTagsOnMatchFailure(); + PluginModel responseCodecConfig = lambdaProcessorConfig.getResponseCodecConfig(); if (responseCodecConfig == null) { @@ -89,12 +100,15 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl codecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), responseCodecConfig.getPluginSettings()); } + responseCodec = ThreadLocal.withInitial(()->pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting)); jsonOutputCodecConfig = new JsonOutputCodecConfig(); jsonOutputCodecConfig.setKeyName(lambdaProcessorConfig.getBatchOptions().getKeyName()); lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(), lambdaProcessorConfig.getMaxConnectionRetries(), awsCredentialsSupplier, lambdaProcessorConfig.getConnectionTimeout()); + lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, lambdaProcessorConfig); + // Select the correct strategy based on the configuration if (lambdaProcessorConfig.getResponseEventsMatch()) { this.responseStrategy = new StrictResponseEventHandlingStrategy(); @@ -111,12 +125,83 @@ public Collection> doExecute(Collection> records) { } BufferFactory bufferFactory = new InMemoryBufferFactory(); // Setup request codec + List> resultRecords = new ArrayList<>(); + List> recordsToLambda = new ArrayList<>(); + for (Record record : records) { + final Event event = record.getData(); + // If the condition is false, add the event to resultRecords as-is + if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { + resultRecords.add(record); + continue; + } + recordsToLambda.add(record); + } + resultRecords.addAll(lambdaCommonHandler.sendRecords(recordsToLambda, + (inputBuffer, response)-> { + List> outputRecords = convertLambdaResponseToEvent(response, inputBuffer); + return outputRecords; + }, + (inputBuffer, outputRecords)-> { + addFailureTags(inputBuffer, outputRecords); + }) + ); + return resultRecords; + } + + List> convertLambdaResponseToEvent(final InvokeResponse lambdaResponse, Buffer flushedBuffer) { + List> originalRecords = flushedBuffer.getRecords(); + try { + List parsedEvents = new ArrayList<>(); - InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting); - lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, responseCodec, whenCondition, expressionEvaluator, responseStrategy, lambdaProcessorConfig); - return lambdaCommonHandler.sendRecords(records, (inputBuffer, resultRecords)->{}, (inputBuffer, resultRecords)->{ addFailureTags(inputBuffer, resultRecords);}); + + List> resultRecords = new ArrayList<>(); + SdkBytes payload = lambdaResponse.payload(); + // Handle null or empty payload + if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { + LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events"); + // Set metrics + //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + //responsePayloadMetric.set(0); + } else { + // Set metrics + //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + //responsePayloadMetric.set(payload.asByteArray().length); + + LOG.debug("Response payload:{}", payload.asUtf8String()); + InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); + //Convert to response codec + try { + responseCodec.get().parse(inputStream, record -> { + Event event = record.getData(); + parsedEvents.add(event); + }); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + + LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + + "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), + flushedBuffer.getSize()); + responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); + + } + return resultRecords; + } catch (Exception e) { + LOG.error(NOISY, "Error converting Lambda response to Event"); + // Metrics update + //requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + //responsePayloadMetric.set(0); + addFailureTags(flushedBuffer, originalRecords); + return originalRecords; + //????? handleFailure(e, flushedBuffer, resultRecords, failureHandler); + } } + /* + * If one event in the Buffer fails, we consider that the entire + * Batch fails and tag each event in that Batch. + */ + private void addFailureTags(Buffer flushedBuffer, List> resultRecords) { // Add failure tags to each event in the batch for (Record record : flushedBuffer.getRecords()) { diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java index 076b0139e7..282c389a09 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java @@ -117,8 +117,9 @@ public void output(Collection> records) { BufferFactory bufferFactory = new InMemoryBufferFactory(); lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, jsonOutputCodecConfig, lambdaSinkConfig); lambdaCommonHandler.sendRecords(records, - (inputBuffer, resultRecords)->{ + (inputBuffer, response)->{ releaseEventHandlesPerBatch(true, inputBuffer); + return null; }, (inputBuffer, resultRecords)->{ handleFailure(new RuntimeException("failed"), inputBuffer);