Skip to content

Commit

Permalink
Address thread safety for lambda processor and additional fixes
Browse files (browse the repository at this point in the history)
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg committed Nov 9, 2024
1 parent 60990cd commit 12aa4aa
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ public LambdaCommonHandler(
this.bufferFactory = bufferFactory;
}

public Buffer createBuffer(Buffer currentBuffer) {
public Buffer createBuffer() {
try {
LOG.debug("Resetting buffer");
currentBuffer = bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType);
return currentBuffer;
return bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType);
} catch (IOException e) {
throw new RuntimeException("Failed to reset buffer", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ public interface Buffer {

public Long getPayloadRequestSize();

public Long getPayloadResponseSize();

public Duration stopLatencyWatch();


void reset();

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public class InMemoryBuffer implements Buffer {
private StopWatch lambdaLatencyWatch;
private long payloadRequestSize;
private long payloadResponseSize;
private boolean isCodecStarted;
private final List<Record<Event>> records;


Expand All @@ -53,7 +52,6 @@ public InMemoryBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName,
bufferWatch.start();
lambdaLatencyWatch = new StopWatch();
eventCount = 0;
isCodecStarted = false;
payloadRequestSize = 0;
payloadResponseSize = 0;
}
Expand Down Expand Up @@ -86,7 +84,6 @@ public void reset() {
eventCount = 0;
bufferWatch.reset();
lambdaLatencyWatch.reset();
isCodecStarted = false;
payloadRequestSize = 0;
payloadResponseSize = 0;
}
Expand Down Expand Up @@ -160,13 +157,10 @@ public Long getPayloadRequestSize() {
return payloadRequestSize;
}

public Long getPayloadResponseSize() {
return payloadResponseSize;
}

public StopWatch getBufferWatch() {return bufferWatch;}

public StopWatch getLambdaLatencyWatch(){return lambdaLatencyWatch;}


}

Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

public class AggregateResponseEventHandlingStrategy implements ResponseEventHandlingStrategy {

private static final Logger LOG = LoggerFactory.getLogger(AggregateResponseEventHandlingStrategy.class);

@Override
public void handleEvents(List<Event> parsedEvents, List<Record<Event>> originalRecords, List<Record<Event>> resultRecords, Buffer flushedBuffer) {
public void handleEvents(List<Event> parsedEvents, List<Record<Event>> originalRecords,
List<Record<Event>> resultRecords, Buffer flushedBuffer) {

Event originalEvent = originalRecords.get(0).getData();
DefaultEventHandle eventHandle = (DefaultEventHandle) originalEvent.getEventHandle();
AcknowledgementSet originalAcknowledgementSet = eventHandle.getAcknowledgementSet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;

@DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class)
public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
Expand All @@ -76,6 +77,7 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Eve
private final LambdaAsyncClient lambdaAsyncClient;
private final AtomicLong requestPayloadMetric;
private final AtomicLong responsePayloadMetric;
private final ReentrantLock reentrantLock;
OutputCodecContext codecContext = null;
LambdaCommonHandler lambdaCommonHandler;
InputCodec responseCodec = null;
Expand All @@ -97,6 +99,8 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl
this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC);
this.requestPayloadMetric = pluginMetrics.gauge(REQUEST_PAYLOAD_SIZE, new AtomicLong());
this.responsePayloadMetric = pluginMetrics.gauge(RESPONSE_PAYLOAD_SIZE, new AtomicLong());

reentrantLock = new ReentrantLock();

functionName = lambdaProcessorConfig.getFunctionName();
whenCondition = lambdaProcessorConfig.getWhenCondition();
Expand All @@ -122,7 +126,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl
maxBytes = batchOptions.getThresholdOptions().getMaximumSize();
maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut();
invocationType = lambdaProcessorConfig.getInvocationType().getAwsLambdaValue();
futureList = new ArrayList<>();
futureList = Collections.synchronizedList(new ArrayList<>());

lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(), lambdaProcessorConfig.getMaxConnectionRetries(), awsCredentialsSupplier, lambdaProcessorConfig.getConnectionTimeout());

Expand All @@ -137,7 +141,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl

// Initialize LambdaCommonHandler
lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType, bufferFactory);
currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer();

LOG.info("LambdaFunctionName:{} , responseEventsMatch:{}, invocationType:{}", functionName,
lambdaProcessorConfig.getResponseEventsMatch(), invocationType);
Expand All @@ -150,42 +154,47 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
}

//lambda mutates event
List<Record<Event>> resultRecords = new ArrayList<>();
List<Record<Event>> resultRecords = Collections.synchronizedList(new ArrayList<>());

for (Record<Event> record : records) {
final Event event = record.getData();

// If the condition is false, add the event to resultRecords as-is
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
resultRecords.add(record);
continue;
}
reentrantLock.lock();
try {
for (Record<Event> record : records) {
final Event event = record.getData();

try {
if (currentBufferPerBatch.getEventCount() == 0) {
requestCodec.start(currentBufferPerBatch.getOutputStream(), event, codecContext);
// If the condition is false, add the event to resultRecords as-is
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
resultRecords.add(record);
continue;
}
requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream());
currentBufferPerBatch.addRecord(record);

flushToLambdaIfNeeded(resultRecords, false);
} catch (Exception e) {
LOG.error(NOISY, "Exception while processing event {}", event, e);
handleFailure(e, currentBufferPerBatch, resultRecords);
currentBufferPerBatch.reset();
try {
if (currentBufferPerBatch.getEventCount() == 0) {
requestCodec.start(currentBufferPerBatch.getOutputStream(), event, codecContext);
}
requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream());
currentBufferPerBatch.addRecord(record);

flushToLambdaIfNeeded(resultRecords, false);
} catch (Exception e) {
LOG.error(NOISY, "Exception while processing event {}", event, e);
handleFailure(e, currentBufferPerBatch, resultRecords);
currentBufferPerBatch.reset();
}
}
}

// Flush any remaining events in the buffer after processing all records
if (currentBufferPerBatch.getEventCount() > 0) {
LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount());
try {
flushToLambdaIfNeeded(resultRecords, true); // Force flush remaining events
currentBufferPerBatch.reset();
} catch (Exception e) {
LOG.error("Exception while flushing remaining events", e);
handleFailure(e, currentBufferPerBatch, resultRecords);
// Flush any remaining events in the buffer after processing all records
if (currentBufferPerBatch.getEventCount() > 0) {
LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount());
try {
flushToLambdaIfNeeded(resultRecords, true); // Force flush remaining events
currentBufferPerBatch.reset();
} catch (Exception e) {
LOG.error("Exception while flushing remaining events", e);
handleFailure(e, currentBufferPerBatch, resultRecords);
}
}
} finally {
reentrantLock.unlock();
}

lambdaCommonHandler.waitForFutures(futureList);
Expand Down Expand Up @@ -213,7 +222,10 @@ void flushToLambdaIfNeeded(List<Record<Event>> resultRecords, boolean forceFlush
handleLambdaResponse(resultRecords, flushedBuffer, eventCount, response);
}).exceptionally(throwable -> {
//Failure handler
LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {}, event in batch:{} | Exception: ", functionName, currentBufferPerBatch.getRecords().get(0), throwable);
List<Record<Event>> bufferRecords = flushedBuffer.getRecords();
Record<Event> eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0);
LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}",
functionName, eventRecord == null? "null":eventRecord.getData(), throwable);
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
responsePayloadMetric.set(0);
Duration latency = flushedBuffer.stopLatencyWatch();
Expand All @@ -225,11 +237,11 @@ void flushToLambdaIfNeeded(List<Record<Event>> resultRecords, boolean forceFlush
futureList.add(processingFuture);

// Create a new buffer for the next batch
currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer();
} catch (IOException e) {
LOG.error(NOISY, "Exception while flushing to lambda", e);
handleFailure(e, currentBufferPerBatch, resultRecords);
currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer();
}
}
}
Expand Down Expand Up @@ -260,38 +272,37 @@ private void handleLambdaResponse(List<Record<Event>> resultRecords, Buffer flus
*/
void convertLambdaResponseToEvent(final List<Record<Event>> resultRecords, final InvokeResponse lambdaResponse, Buffer flushedBuffer) {
try {
List<Event> parsedEvents = new ArrayList<>();
List<Record<Event>> originalRecords = flushedBuffer.getRecords();

SdkBytes payload = lambdaResponse.payload();

// Handle null or empty payload
if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
LOG.error(NOISY, "Lambda response payload is null or empty");
throw new RuntimeException("Lambda response payload is null or empty");
}

// Record payload sizes
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
responsePayloadMetric.set(payload.asByteArray().length);

LOG.debug("Response payload:{}", payload.asUtf8String());
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
List<Event> parsedEvents = new ArrayList<>();
LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events");
// Set metrics
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
responsePayloadMetric.set(0);
} else {
// Set metrics
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
responsePayloadMetric.set(payload.asByteArray().length);

LOG.debug("Response payload:{}", payload.asUtf8String());
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
//Convert to response codec
try {
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
} catch (IOException ex) {
throw new RuntimeException(ex);
}

//Convert to response codec
try {
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
} catch (IOException ex) {
throw new RuntimeException(ex);
LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize());
responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);
}

List<Record<Event>> originalRecords = flushedBuffer.getRecords();

LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize());

responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);

} catch (Exception e) {
LOG.error(NOISY, "Error converting Lambda response to Event");
Expand All @@ -307,12 +318,19 @@ void convertLambdaResponseToEvent(final List<Record<Event>> resultRecords, final
* Batch fails and tag each event in that Batch.
*/
void handleFailure(Throwable e, Buffer flushedBuffer, List<Record<Event>> resultRecords) {
if (flushedBuffer.getEventCount() > 0) {
numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount());
} else {
numberOfRecordsFailedCounter.increment();
try {
if (flushedBuffer.getEventCount() > 0) {
numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount());
}

addFailureTags(flushedBuffer, resultRecords);
LOG.error(NOISY, "Failed to process batch due to error: ", e);
} catch(Exception ex){
LOG.error(NOISY, "Exception in handleFailure while processing failure for buffer: ", ex);
}
}

private void addFailureTags(Buffer flushedBuffer, List<Record<Event>> resultRecords) {
// Add failure tags to each event in the batch
for (Record<Event> record : flushedBuffer.getRecords()) {
Event event = record.getData();
Expand All @@ -324,7 +342,6 @@ void handleFailure(Throwable e, Buffer flushedBuffer, List<Record<Event>> result
}
resultRecords.add(record);
}
LOG.error(NOISY, "Failed to process batch due to error: ", e);
}


Expand All @@ -335,12 +352,19 @@ public void prepareForShutdown() {

@Override
public boolean isReadyForShutdown() {
return false;
//check if there are no pending futures
synchronized (futureList) {
return futureList.isEmpty() && currentBufferPerBatch.getEventCount() == 0;
}
}

/**
 * Cancels every future still tracked in {@code futureList} and empties the list.
 *
 * <p>In the visible constructor code {@code futureList} is assigned from
 * {@code Collections.synchronizedList(...)}. That wrapper only guards individual
 * calls; traversing it requires holding the list's own monitor, as
 * {@code isReadyForShutdown()} does. The pending futures are therefore copied
 * and the list cleared in one synchronized step, so a concurrent
 * {@code futureList.add(...)} can neither break the iteration nor slip a future
 * in between the loop and the clear.
 */
@Override
public void shutdown() {
    final List<CompletableFuture<Void>> pendingFutures;
    synchronized (futureList) {
        pendingFutures = new ArrayList<>(futureList);
        futureList.clear();
    }

    // Cancel outside the monitor so callbacks triggered by cancellation
    // never run while the list lock is held.
    for (CompletableFuture<Void> future : pendingFutures) {
        future.cancel(true);
    }
}

}
Loading

0 comments on commit 12aa4aa

Please sign in to comment.