diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index e5792a46c6..7afce256f0 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -32,11 +32,10 @@ public LambdaCommonHandler( this.bufferFactory = bufferFactory; } - public Buffer createBuffer(Buffer currentBuffer) { + public Buffer createBuffer() { try { LOG.debug("Resetting buffer"); - currentBuffer = bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType); - return currentBuffer; + return bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType); } catch (IOException e) { throw new RuntimeException("Failed to reset buffer", e); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java index 9c99d2fa47..878d5e9033 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java @@ -42,11 +42,8 @@ public interface Buffer { public Long getPayloadRequestSize(); - public Long getPayloadResponseSize(); - public Duration stopLatencyWatch(); - void reset(); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java index 297482c360..109a141e09 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java @@ -39,7 +39,6 @@ public class InMemoryBuffer implements Buffer { private StopWatch lambdaLatencyWatch; private long payloadRequestSize; private long payloadResponseSize; - private boolean isCodecStarted; private final List> records; @@ -53,7 +52,6 @@ public InMemoryBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName, bufferWatch.start(); lambdaLatencyWatch = new StopWatch(); eventCount = 0; - isCodecStarted = false; payloadRequestSize = 0; payloadResponseSize = 0; } @@ -86,7 +84,6 @@ public void reset() { eventCount = 0; bufferWatch.reset(); lambdaLatencyWatch.reset(); - isCodecStarted = false; payloadRequestSize = 0; payloadResponseSize = 0; } @@ -160,13 +157,10 @@ public Long getPayloadRequestSize() { return payloadRequestSize; } - public Long getPayloadResponseSize() { - return payloadResponseSize; - } - public StopWatch getBufferWatch() {return bufferWatch;} public StopWatch getLambdaLatencyWatch(){return lambdaLatencyWatch;} + } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java index fc56738c21..7d32a4f380 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java @@ -5,13 +5,19 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.List; public class AggregateResponseEventHandlingStrategy implements ResponseEventHandlingStrategy { + private static final Logger LOG = LoggerFactory.getLogger(AggregateResponseEventHandlingStrategy.class); + @Override - public void handleEvents(List parsedEvents, List> originalRecords, List> resultRecords, Buffer flushedBuffer) { + public void handleEvents(List parsedEvents, List> originalRecords, + List> resultRecords, Buffer flushedBuffer) { + Event originalEvent = originalRecords.get(0).getData(); DefaultEventHandle eventHandle = (DefaultEventHandle) originalEvent.getEventHandle(); AcknowledgementSet originalAcknowledgementSet = eventHandle.getAcknowledgementSet(); diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index b1e74ed096..43160162c4 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -51,6 +51,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { @@ -76,6 +77,7 @@ public class LambdaProcessor extends AbstractProcessor, Record(); + futureList = Collections.synchronizedList(new ArrayList<>()); lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(), lambdaProcessorConfig.getMaxConnectionRetries(), awsCredentialsSupplier, lambdaProcessorConfig.getConnectionTimeout()); @@ -137,7 +141,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl // Initialize LambdaCommonHandler lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType, bufferFactory); - currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch); + currentBufferPerBatch = lambdaCommonHandler.createBuffer(); LOG.info("LambdaFunctionName:{} , responseEventsMatch:{}, invocationType:{}", functionName, lambdaProcessorConfig.getResponseEventsMatch(), invocationType); @@ -150,42 +154,47 @@ public Collection> doExecute(Collection> records) { } //lambda mutates event - List> resultRecords = new ArrayList<>(); + List> resultRecords = Collections.synchronizedList(new ArrayList<>()); - for (Record record : records) { - final Event event = record.getData(); - - // If the condition is false, add the event to resultRecords as-is - if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { - resultRecords.add(record); - continue; - } + reentrantLock.lock(); + try { + for (Record record : records) { + final Event event = record.getData(); - try { - if (currentBufferPerBatch.getEventCount() == 0) { - requestCodec.start(currentBufferPerBatch.getOutputStream(), event, codecContext); + // If the condition is false, add the event to resultRecords as-is + if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { + resultRecords.add(record); + continue; } - requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream()); - currentBufferPerBatch.addRecord(record); - flushToLambdaIfNeeded(resultRecords, false); - } catch (Exception e) { - LOG.error(NOISY, "Exception while processing event {}", event, e); - handleFailure(e, currentBufferPerBatch, resultRecords); - currentBufferPerBatch.reset(); + try { + if (currentBufferPerBatch.getEventCount() == 0) { + requestCodec.start(currentBufferPerBatch.getOutputStream(), event, codecContext); + } + requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream()); + currentBufferPerBatch.addRecord(record); + + flushToLambdaIfNeeded(resultRecords, false); + } catch (Exception e) { + LOG.error(NOISY, "Exception while processing event {}", event, e); + handleFailure(e, currentBufferPerBatch, resultRecords); + currentBufferPerBatch.reset(); + } } - } - // Flush any remaining events in the buffer after processing all records - if (currentBufferPerBatch.getEventCount() > 0) { - LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount()); - try { - flushToLambdaIfNeeded(resultRecords, true); // Force flush remaining events - currentBufferPerBatch.reset(); - } catch (Exception e) { - LOG.error("Exception while flushing remaining events", e); - handleFailure(e, currentBufferPerBatch, resultRecords); + // Flush any remaining events in the buffer after processing all records + if (currentBufferPerBatch.getEventCount() > 0) { + LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount()); + try { + flushToLambdaIfNeeded(resultRecords, true); // Force flush remaining events + currentBufferPerBatch.reset(); + } catch (Exception e) { + LOG.error("Exception while flushing remaining events", e); + handleFailure(e, currentBufferPerBatch, resultRecords); + } } + } finally { + reentrantLock.unlock(); } lambdaCommonHandler.waitForFutures(futureList); @@ -213,7 +222,10 @@ void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush handleLambdaResponse(resultRecords, flushedBuffer, eventCount, response); }).exceptionally(throwable -> { //Failure handler - LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {}, event in batch:{} | Exception: ", functionName, currentBufferPerBatch.getRecords().get(0), throwable); + List> bufferRecords = flushedBuffer.getRecords(); + Record eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0); + LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}", + functionName, eventRecord == null? "null":eventRecord.getData(), throwable); requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); responsePayloadMetric.set(0); Duration latency = flushedBuffer.stopLatencyWatch(); @@ -225,11 +237,11 @@ void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush futureList.add(processingFuture); // Create a new buffer for the next batch - currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch); + currentBufferPerBatch = lambdaCommonHandler.createBuffer(); } catch (IOException e) { LOG.error(NOISY, "Exception while flushing to lambda", e); handleFailure(e, currentBufferPerBatch, resultRecords); - currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch); + currentBufferPerBatch = lambdaCommonHandler.createBuffer(); } } } @@ -260,38 +272,37 @@ private void handleLambdaResponse(List> resultRecords, Buffer flus */ void convertLambdaResponseToEvent(final List> resultRecords, final InvokeResponse lambdaResponse, Buffer flushedBuffer) { try { + List parsedEvents = new ArrayList<>(); + List> originalRecords = flushedBuffer.getRecords(); SdkBytes payload = lambdaResponse.payload(); - // Handle null or empty payload if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { - LOG.error(NOISY, "Lambda response payload is null or empty"); - throw new RuntimeException("Lambda response payload is null or empty"); - } - - // Record payload sizes - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - responsePayloadMetric.set(payload.asByteArray().length); - - LOG.debug("Response payload:{}", payload.asUtf8String()); - InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); - List parsedEvents = new ArrayList<>(); + LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events"); + // Set metrics + requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + responsePayloadMetric.set(0); + } else { + // Set metrics + requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + responsePayloadMetric.set(payload.asByteArray().length); + + LOG.debug("Response payload:{}", payload.asUtf8String()); + InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); + //Convert to response codec + try { + responseCodec.parse(inputStream, record -> { + Event event = record.getData(); + parsedEvents.add(event); + }); + } catch (IOException ex) { + throw new RuntimeException(ex); + } - //Convert to response codec - try { - responseCodec.parse(inputStream, record -> { - Event event = record.getData(); - parsedEvents.add(event); - }); - } catch (IOException ex) { - throw new RuntimeException(ex); + LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize()); + responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); } - List> originalRecords = flushedBuffer.getRecords(); - - LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize()); - - responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); } catch (Exception e) { LOG.error(NOISY, "Error converting Lambda response to Event"); @@ -307,12 +318,19 @@ void convertLambdaResponseToEvent(final List> resultRecords, final * Batch fails and tag each event in that Batch. */ void handleFailure(Throwable e, Buffer flushedBuffer, List> resultRecords) { - if (flushedBuffer.getEventCount() > 0) { - numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); - } else { - numberOfRecordsFailedCounter.increment(); + try { + if (flushedBuffer.getEventCount() > 0) { + numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); + } + + addFailureTags(flushedBuffer, resultRecords); + LOG.error(NOISY, "Failed to process batch due to error: ", e); + } catch(Exception ex){ + LOG.error(NOISY, "Exception in handleFailure while processing failure for buffer: ", ex); } + } + private void addFailureTags(Buffer flushedBuffer, List> resultRecords) { // Add failure tags to each event in the batch for (Record record : flushedBuffer.getRecords()) { Event event = record.getData(); @@ -324,7 +342,6 @@ void handleFailure(Throwable e, Buffer flushedBuffer, List> result } resultRecords.add(record); } - LOG.error(NOISY, "Failed to process batch due to error: ", e); } @@ -335,12 +352,19 @@ public void prepareForShutdown() { @Override public boolean isReadyForShutdown() { - return false; + //check if there are no pending futures + synchronized (futureList) { + return futureList.isEmpty() && currentBufferPerBatch.getEventCount() == 0; + } } @Override public void shutdown() { - + // Cancel any pending futures + for (CompletableFuture future : futureList) { + future.cancel(true); + } + futureList.clear(); } } \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java index 48f17d976c..981c520973 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java @@ -9,6 +9,7 @@ import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.model.configuration.PluginSetting; @@ -37,6 +38,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -67,7 +69,6 @@ public class LambdaSinkService { private final String invocationType; private final BufferFactory bufferFactory; private final DlqPushHandler dlqPushHandler; - private final List events; private final BatchOptions batchOptions; private int maxEvents = 0; private ByteCount maxBytes = null; @@ -107,8 +108,7 @@ public LambdaSinkService(final LambdaAsyncClient lambdaAsyncClient, final Lambda maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); invocationType = lambdaSinkConfig.getInvocationType().getAwsLambdaValue(); - events = new ArrayList(); - futureList = new ArrayList<>(); + futureList = Collections.synchronizedList(new ArrayList<>()); this.bufferFactory = bufferFactory; @@ -123,14 +123,16 @@ public void output(Collection> records) { return; } - List> resultRecords = new ArrayList<>(); + //Result from lambda is not currently processes. + List> resultRecords = null; + reentrantLock.lock(); try { for (Record record : records) { final Event event = record.getData(); if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { - resultRecords.add(record); + releaseEventHandle(event, true); continue; } try { @@ -167,6 +169,12 @@ public void output(Collection> records) { // Wait for all futures to complete lambdaCommonHandler.waitForFutures(futureList); + // Release event handles for records not sent to Lambda + for (Record record : records) { + Event event = record.getData(); + releaseEventHandle(event, true); + } + } void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush) { @@ -182,22 +190,13 @@ void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush // Handle future CompletableFuture processingFuture = future.thenAccept(response -> { - // Success handler - boolean success = lambdaCommonHandler.checkStatusCode(response); - if(success) { - LOG.info("Successfully flushed {} events", eventCount); - numberOfRecordsSuccessCounter.increment(eventCount); - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - Duration latency = flushedBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - } else { - // Non-2xx status code treated as failure - handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), - flushedBuffer); - } + handleLambdaResponse(flushedBuffer, eventCount, response); }).exceptionally(throwable -> { // Failure handler - LOG.error("Exception occurred while invoking Lambda. Function: {}, event in batch:{} | Exception: ", functionName, currentBufferPerBatch.getRecords().get(0), throwable); + List> bufferRecords = flushedBuffer.getRecords(); + Record eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0); + LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}", + functionName, eventRecord == null? "null":eventRecord.getData(), throwable); requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); responsePayloadMetric.set(0); Duration latency = flushedBuffer.stopLatencyWatch(); @@ -209,28 +208,30 @@ void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush futureList.add(processingFuture); // Create a new buffer for the next batch - currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch); + currentBufferPerBatch = lambdaCommonHandler.createBuffer(); } catch (IOException e) { LOG.error("Exception while flushing to lambda", e); handleFailure(e, currentBufferPerBatch); - currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch); + currentBufferPerBatch = lambdaCommonHandler.createBuffer(); } } } void handleFailure(Throwable throwable, Buffer flushedBuffer) { - if (currentBufferPerBatch.getEventCount() > 0) { - numberOfRecordsFailedCounter.increment(currentBufferPerBatch.getEventCount()); - } else { - numberOfRecordsFailedCounter.increment(); - } + try { + if (flushedBuffer.getEventCount() > 0) { + numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); + } - SdkBytes payload = currentBufferPerBatch.getPayload(); - if (dlqPushHandler != null) { - dlqPushHandler.perform(pluginSetting, new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); - releaseEventHandlesPerBatch(true, flushedBuffer); - } else { - releaseEventHandlesPerBatch(false, flushedBuffer); + SdkBytes payload = flushedBuffer.getPayload(); + if (dlqPushHandler != null) { + dlqPushHandler.perform(pluginSetting, new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); + releaseEventHandlesPerBatch(true, flushedBuffer); + } else { + releaseEventHandlesPerBatch(false, flushedBuffer); + } + } catch (Exception ex){ + LOG.error("Exception occured during error handling"); } } @@ -241,6 +242,18 @@ private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer) List> records = flushedBuffer.getRecords(); for (Record record : records) { Event event = record.getData(); + releaseEventHandle(event, success); + } + } + + /** + * Releases the event handle based on processing success. + * + * @param event the event to release + * @param success indicates if processing was successful + */ + private void releaseEventHandle(Event event, boolean success) { + if (event != null) { EventHandle eventHandle = event.getEventHandle(); if (eventHandle != null) { eventHandle.release(success); @@ -248,4 +261,26 @@ private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer) } } + private void handleLambdaResponse(Buffer flushedBuffer, int eventCount, InvokeResponse response) { + boolean success = lambdaCommonHandler.checkStatusCode(response); + if (success) { + LOG.info("Successfully flushed {} events", eventCount); + SdkBytes payload = response.payload(); + if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { + responsePayloadMetric.set(0); + } else { + responsePayloadMetric.set(payload.asByteArray().length); + } + //metrics + requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); + numberOfRecordsSuccessCounter.increment(eventCount); + Duration latency = flushedBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + } + else { + // Non-2xx status code treated as failure + handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer); + } + } + } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java index 86c273bcd2..4a80aa0b34 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java @@ -61,7 +61,7 @@ public void testCreateBuffer_success() throws IOException { when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenReturn(mockBuffer); // Act - Buffer result = lambdaCommonHandler.createBuffer(mockBuffer); + Buffer result = lambdaCommonHandler.createBuffer(); // Assert verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType); @@ -76,7 +76,7 @@ public void testCreateBuffer_throwsException() throws IOException { // Act & Assert try { - lambdaCommonHandler.createBuffer(mockBuffer); + lambdaCommonHandler.createBuffer(); } catch (RuntimeException e) { assert e.getMessage().contains("Failed to reset buffer"); } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index d7cdc5148b..85242464b2 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -135,7 +135,7 @@ public void setUp() throws Exception { when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(30)); // Mock lambdaCommonHandler.createBuffer() to return currentBufferPerBatch - when(lambdaCommonHandler.createBuffer(any())).thenReturn(currentBufferPerBatch); + when(lambdaCommonHandler.createBuffer()).thenReturn(currentBufferPerBatch); // Mock currentBufferPerBatch.reset() doNothing().when(currentBufferPerBatch).reset(); @@ -243,7 +243,7 @@ public void testDoExecute_WithRecords_SuccessfulProcessing() throws Exception { }).when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); // Mock lambdaCommonHandler.createBuffer() to return currentBufferPerBatch - when(lambdaCommonHandler.createBuffer(any())).thenReturn(currentBufferPerBatch); + when(lambdaCommonHandler.createBuffer()).thenReturn(currentBufferPerBatch); setupTestObject(); populatePrivateFields(); @@ -287,6 +287,127 @@ public void testHandleFailure() throws Exception { verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); } + @Test + public void testDoExecute_WithSingleRecord_SuccessfulProcessing() throws Exception { + // Arrange + setupTestObject(); + populatePrivateFields(); + + Event event = mock(Event.class); + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + when(currentBufferPerBatch.getRecords()).thenReturn((List>) records); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(CompletableFuture.completedFuture(invokeResponse)); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\":\"value\"}]")); + + // Act + Collection> result = lambdaProcessor.doExecute(records); + + // Assert + assertEquals(1, result.size()); + verify(requestCodec, times(1)).writeEvent(eq(event), any()); + verify(currentBufferPerBatch, times(1)).flushToLambda(any()); + } + + @Test + public void testDoExecute_WithMultipleRecords_BatchProcessing() throws Exception { + // Arrange + setupTestObject(); + populatePrivateFields(); + + List> records = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + records.add(new Record<>(mock(Event.class))); + } + + when(currentBufferPerBatch.getRecords()).thenReturn((List>) records); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1).thenReturn(2).thenReturn(3).thenReturn(4).thenReturn(5); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(CompletableFuture.completedFuture(invokeResponse)); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\":\"value\"}]")); + + // Act + Collection> result = lambdaProcessor.doExecute(records); + + // Assert + assertEquals(5, result.size()); + verify(requestCodec, times(5)).writeEvent(any(), any()); + verify(currentBufferPerBatch, times(1)).flushToLambda(any()); + } + + @Test + public void testDoExecute_WithExceptionDuringProcessing() throws Exception { + // Arrange + setupTestObject(); + populatePrivateFields(); + + Event event = mock(Event.class); + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + when(currentBufferPerBatch.getRecords()).thenReturn((List>) records); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(currentBufferPerBatch.flushToLambda(any())).thenThrow(new RuntimeException("Test exception")); + + // Act + Collection> result = lambdaProcessor.doExecute(records); + + // Assert + assertEquals(1, result.size()); + verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); + } + + @Test + public void testDoExecute_WithEmptyResponse() throws Exception { + // Arrange + setupTestObject(); + populatePrivateFields(); + + Event event = mock(Event.class); + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + when(currentBufferPerBatch.getRecords()).thenReturn((List>) records); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(CompletableFuture.completedFuture(invokeResponse)); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("")); + + // Act + Collection> result = lambdaProcessor.doExecute(records); + + // Assert + assertEquals(0, result.size()); + verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + } + + @Test + public void testDoExecute_WithNullResponse() throws Exception { + // Arrange + setupTestObject(); + populatePrivateFields(); + + Event event = mock(Event.class); + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + when(currentBufferPerBatch.getRecords()).thenReturn((List>) records); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(CompletableFuture.completedFuture(invokeResponse)); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + when(invokeResponse.payload()).thenReturn(null); + + // Act + Collection> result = lambdaProcessor.doExecute(records); + + // Assert + assertEquals(0, result.size()); + verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + } + @Test public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing() throws Exception { // Arrange diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java index 06f05e4414..34cb5f4771 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java @@ -148,7 +148,7 @@ public void setUp() { // Mock LambdaCommonHandler lambdaCommonHandler = mock(LambdaCommonHandler.class); - when(lambdaCommonHandler.createBuffer(any())).thenReturn(currentBufferPerBatch); + when(lambdaCommonHandler.createBuffer()).thenReturn(currentBufferPerBatch); doNothing().when(currentBufferPerBatch).reset(); lambdaSinkService = new LambdaSinkService( @@ -246,7 +246,8 @@ public void testOutput_ExceptionDuringProcessing() throws Exception { when(lambdaSinkConfig.getWhenCondition()).thenReturn(null); // Mock event handling to throw exception when writeEvent is called - when(currentBufferPerBatch.getEventCount()).thenReturn(0); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); doNothing().when(requestCodec).start(any(), eq(event), any()); doThrow(new IOException("Test IOException")).when(requestCodec).writeEvent(eq(event), any()); @@ -263,7 +264,7 @@ public void testOutput_ExceptionDuringProcessing() throws Exception { // Assert verify(requestCodec, times(1)).start(any(), eq(event), any()); verify(requestCodec, times(1)).writeEvent(eq(event), any()); - verify(numberOfRecordsFailedCounter, times(1)).increment(); + verify(numberOfRecordsFailedCounter, times(1)).increment(1); }