Fix for kafka source issue opensearch-project#3247 (offset commit stops on deserialization error)

Signed-off-by: Hardeep Singh <[email protected]>
hshardeesi committed Aug 25, 2023
1 parent 1bcf9f6 commit d41b5b0
Showing 4 changed files with 202 additions and 32 deletions.
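
The change below handles Kafka's RecordDeserializationException in the poll loop: it logs the bad record, seeks the consumer past the offending offset, and, when acknowledgements are enabled, records the skipped offset in the partition's commit tracker so offset commits keep advancing instead of stalling on the undeserializable record. As rough orientation, a minimal sketch of that seek-past-the-bad-record pattern is shown here; the class name, topic name, and poll settings are illustrative only and are not part of this commit.

import java.time.Duration;
import java.util.List;

import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.errors.RecordDeserializationException;

// Illustrative poll loop: skip a record that cannot be deserialized and keep committing.
public class SkipBadRecordExample {
    public static void pollLoop(final KafkaConsumer<String, String> consumer) {
        consumer.subscribe(List.of("example-topic"));   // hypothetical topic name
        while (true) {
            try {
                final ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
                records.forEach(record -> System.out.println(record.value()));
                consumer.commitSync();   // commits continue to advance past skipped records
            } catch (final RecordDeserializationException e) {
                // Seek one offset past the record that failed to deserialize so the
                // next poll does not get stuck retrying the same offset.
                consumer.seek(e.topicPartition(), e.offset() + 1);
            }
        }
    }
}

In the commit itself, the equivalent seek happens in KafkaSourceCustomConsumer's RecordDeserializationException handler, and addAcknowledgedOffsets() marks the skipped offset as completed so the acknowledgement-based commit path does not stall on it.
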
@@ -33,6 +33,7 @@
import org.opensearch.dataprepper.plugins.kafka.configuration.KafkaSourceConfig;
import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig;
import org.opensearch.dataprepper.plugins.kafka.util.KafkaTopicMetrics;
import org.opensearch.dataprepper.plugins.kafka.util.LogRateLimiter;
import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -84,6 +85,7 @@ public class KafkaSourceCustomConsumer implements Runnable, ConsumerRebalanceLis
private long metricsUpdatedTime;
private final AtomicInteger numberOfAcksPending;
private long numRecordsCommitted = 0;
private final LogRateLimiter errLogRateLimiter;

public KafkaSourceCustomConsumer(final KafkaConsumer consumer,
final AtomicBoolean shutdownInProgress,
@@ -114,6 +116,7 @@ public KafkaSourceCustomConsumer(final KafkaConsumer consumer,
this.bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, bufferTimeout);
this.lastCommitTime = System.currentTimeMillis();
this.numberOfAcksPending = new AtomicInteger(0);
this.errLogRateLimiter = new LogRateLimiter(2, System.currentTimeMillis());
}

private long getCurrentTimeNanos() {
@@ -153,7 +156,8 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, CommitOf

public <T> void consumeRecords() throws Exception {
try {
ConsumerRecords<String, T> records = consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2);
ConsumerRecords<String, T> records =
consumer.poll(Duration.ofMillis(topicConfig.getThreadWaitingTime().toMillis()/2));
if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
Map<TopicPartition, CommitOffsetRange> offsets = new HashMap<>();
AcknowledgementSet acknowledgementSet = null;
@@ -176,19 +180,42 @@ public <T> void consumeRecords() throws Exception {
topicMetrics.getNumberOfPollAuthErrors().increment();
Thread.sleep(10000);
} catch (RecordDeserializationException e) {
LOG.warn("Deserialization error - topic {} partition {} offset {}",
e.topicPartition().topic(), e.topicPartition().partition(), e.offset());

LOG.warn("Deserialization error - topic {} partition {} offset {}. Error message: {}",
e.topicPartition().topic(), e.topicPartition().partition(), e.offset(), e.getMessage());
if (e.getCause() instanceof AWSSchemaRegistryException) {
LOG.warn("AWSSchemaRegistryException: {}. Retrying after 30 seconds", e.getMessage());
LOG.warn("Retrying after 30 seconds");
Thread.sleep(30000);
} else {
LOG.warn("Seeking past the error record", e);
consumer.seek(e.topicPartition(), e.offset()+1);
LOG.warn("Seeking past the error record");
consumer.seek(e.topicPartition(), e.offset() + 1);

// Update failed record offset in commitTracker because we are not
// processing it and seeking past the error record.
// Note: partitionCommitTrackerMap may not have the partition if this is
// the first set of records that hit a deserialization error
if (acknowledgementsEnabled && partitionCommitTrackerMap.containsKey(e.topicPartition().partition())) {
addAcknowledgedOffsets(e.topicPartition(), Range.of(e.offset(), e.offset()));
}
}

topicMetrics.getNumberOfDeserializationErrors().increment();
}
}

private void addAcknowledgedOffsets(final TopicPartition topicPartition, final Range<Long> offsetRange) {
final int partitionId = topicPartition.partition();
final TopicPartitionCommitTracker commitTracker = partitionCommitTrackerMap.get(partitionId);

if (Objects.isNull(commitTracker) && errLogRateLimiter.isAllowed(System.currentTimeMillis())) {
LOG.error("Commit tracker not found for TopicPartition: {}", topicPartition);
}

final OffsetAndMetadata offsetAndMetadata =
partitionCommitTrackerMap.get(partitionId).addCompletedOffsets(offsetRange);
updateOffsetsToCommit(topicPartition, offsetAndMetadata);
}

private void resetOffsets() {
if (partitionsToReset.size() > 0) {
partitionsToReset.forEach(partition -> {
@@ -211,19 +238,11 @@
}

void processAcknowledgedOffsets() {

acknowledgedOffsets.forEach(offsets -> {
offsets.forEach((partition, offsetRange) -> {
if (getPartitionEpoch(partition) == offsetRange.getEpoch()) {
try {
int partitionId = partition.partition();
if (partitionCommitTrackerMap.containsKey(partitionId)) {
final OffsetAndMetadata offsetAndMetadata =
partitionCommitTrackerMap.get(partitionId).addCompletedOffsets(offsetRange.getOffsets());
updateOffsetsToCommit(partition, offsetAndMetadata);
} else {
LOG.error("Commit tracker not found for topic: {} partition: {}", partition.topic(), partitionId);
}
addAcknowledgedOffsets(partition, offsetRange.getOffsets());
} catch (Exception e) {
LOG.error("Failed committed offsets upon positive acknowledgement {}", partition, e);
}
@@ -236,9 +255,8 @@ void processAcknowledgedOffsets() {
private void updateCommitCountMetric(final TopicPartition topicPartition, final OffsetAndMetadata offsetAndMetadata) {
if (acknowledgementsEnabled) {
final TopicPartitionCommitTracker commitTracker = partitionCommitTrackerMap.get(topicPartition.partition());
if (Objects.isNull(commitTracker)) {
LOG.error("Commit tracker not found for topic: {} partition: {}",
topicPartition.topic(), topicPartition.partition());
if (Objects.isNull(commitTracker) && errLogRateLimiter.isAllowed(System.currentTimeMillis())) {
LOG.error("Commit tracker not found for TopicPartition: {}", topicPartition);
return;
}
topicMetrics.getNumberOfRecordsCommitted().increment(commitTracker.getCommittedRecordCount());
@@ -374,7 +392,9 @@ private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, fin
for (TopicPartition topicPartition : records.partitions()) {
final long partitionEpoch = getPartitionEpoch(topicPartition);
if (acknowledgementsEnabled && partitionEpoch == 0) {
//ToDo: Add metric
if (errLogRateLimiter.isAllowed(System.currentTimeMillis())) {
LOG.error("Lost ownership of partition {}", topicPartition);
}
continue;
}

@@ -480,4 +500,4 @@ final String getTopicPartitionOffset(final Map<TopicPartition, Long> offsetMap,
final Long offset = offsetMap.get(topicPartition);
return Objects.isNull(offset) ? "-" : offset.toString();
}
}
}
@@ -0,0 +1,36 @@
package org.opensearch.dataprepper.plugins.kafka.util;

// Poor man's implementation of a rate limiter for logging error messages.
// TODO: Use a token bucket as a generic rate limiter.
public class LogRateLimiter {
public static final int MILLIS_PER_SECOND = 1000;
public static final int MAX_LOGS_PER_SECOND = 1000;
private int tokens;
private long lastMs;
private long replenishInterval;

public LogRateLimiter(final int ratePerSecond, final long currentMs) {
if (ratePerSecond <= 0 || ratePerSecond > MAX_LOGS_PER_SECOND) {
throw new IllegalArgumentException(
String.format("Invalid arguments. ratePerSecond should be > 0 and at most %s", MAX_LOGS_PER_SECOND)
);
}
replenishInterval = MILLIS_PER_SECOND / ratePerSecond;
lastMs = currentMs;
tokens = 1;
}

public boolean isAllowed(final long currentMs) {
if ((currentMs - lastMs) >= replenishInterval) {
tokens = 1;
lastMs = currentMs;
}

if (tokens == 0) {
return false;
}

tokens--;
return true;
}
}
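
For illustration, calling code can gate a noisy error log with this class roughly as follows. The wrapper class below is a hypothetical example and not part of the commit; the constructor arguments and isAllowed() call mirror how KafkaSourceCustomConsumer uses the limiter above.

import org.opensearch.dataprepper.plugins.kafka.util.LogRateLimiter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Hypothetical caller: emit this error at most twice per second.
class RateLimitedErrorLogExample {
    private static final Logger LOG = LoggerFactory.getLogger(RateLimitedErrorLogExample.class);
    private final LogRateLimiter errLogRateLimiter = new LogRateLimiter(2, System.currentTimeMillis());

    void reportMissingCommitTracker(final Object topicPartition) {
        if (errLogRateLimiter.isAllowed(System.currentTimeMillis())) {
            LOG.error("Commit tracker not found for TopicPartition: {}", topicPartition);
        }
    }
}
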
@@ -50,7 +50,6 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -138,8 +137,8 @@ private BlockingBuffer<Record<Event>> getBuffer() {
@Test
public void testPlainTextConsumeRecords() throws InterruptedException {
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic);
when(kafkaConsumer.poll(anyLong())).thenReturn(consumerRecords);
consumerRecords = createPlainTextRecords(topic, 0L);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer = createObjectUnderTest("plaintext", false);

try {
@@ -176,8 +175,8 @@ public void testPlainTextConsumeRecords() throws InterruptedException {
@Test
public void testPlainTextConsumeRecordsWithAcknowledgements() throws InterruptedException {
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic);
when(kafkaConsumer.poll(anyLong())).thenReturn(consumerRecords);
consumerRecords = createPlainTextRecords(topic, 0L);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer = createObjectUnderTest("plaintext", true);

try {
@@ -223,8 +222,8 @@ public void testPlainTextConsumeRecordsWithAcknowledgements() throws Interrupted
@Test
public void testPlainTextConsumeRecordsWithNegativeAcknowledgements() throws InterruptedException {
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic);
when(kafkaConsumer.poll(anyLong())).thenReturn(consumerRecords);
consumerRecords = createPlainTextRecords(topic, 0L);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer = createObjectUnderTest("plaintext", true);

try {
@@ -266,7 +265,7 @@ public void testJsonConsumeRecords() throws InterruptedException, Exception {
when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.JSON);
when(topicConfig.getKafkaKeyMode()).thenReturn(KafkaKeyMode.INCLUDE_AS_FIELD);
consumerRecords = createJsonRecords(topic);
when(kafkaConsumer.poll(anyLong())).thenReturn(consumerRecords);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer = createObjectUnderTest("json", false);

consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testJsonPartition)));
@@ -296,10 +295,68 @@ public void testJsonConsumeRecords() throws InterruptedException, Exception {
}
}

private ConsumerRecords createPlainTextRecords(String topic) {
@Test
public void testJsonDeserializationErrorWithAcknowledgements() throws Exception {
String topic = topicConfig.getName();
when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.JSON);
when(topicConfig.getKafkaKeyMode()).thenReturn(KafkaKeyMode.INCLUDE_AS_FIELD);
consumer = createObjectUnderTest("json", true);
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testJsonPartition)));

consumerRecords = createPlainTextRecords(topic, 98L);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer.consumeRecords();

Map.Entry<Collection<Record<Event>>, CheckpointState> bufferRecords = buffer.read(1000);
ArrayList<Record<Event>> bufferedRecords = new ArrayList<>(bufferRecords.getKey());
Assertions.assertEquals(0, bufferedRecords.size());
Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = consumer.getOffsetsToCommit();
Assertions.assertEquals(offsetsToCommit.size(), 0);

consumerRecords = createJsonRecords(topic);
when(kafkaConsumer.poll(any(Duration.class))).thenReturn(consumerRecords);
consumer.consumeRecords();

bufferRecords = buffer.read(1000);
bufferedRecords = new ArrayList<>(bufferRecords.getKey());
Assertions.assertEquals(2, bufferedRecords.size());
offsetsToCommit = consumer.getOffsetsToCommit();
Assertions.assertEquals(offsetsToCommit.size(), 0);

for (Record<Event> record: bufferedRecords) {
Event event = record.getData();
Map<String, Object> eventMap = event.toMap();
String kafkaKey = event.get("kafka_key", String.class);
assertTrue(kafkaKey.equals(testKey1) || kafkaKey.equals(testKey2));
if (kafkaKey.equals(testKey1)) {
testMap1.forEach((k, v) -> assertThat(eventMap, hasEntry(k,v)));
}
if (kafkaKey.equals(testKey2)) {
testMap2.forEach((k, v) -> assertThat(eventMap, hasEntry(k,v)));
}
event.getEventHandle().release(true);
}
// Wait for acknowledgement callback function to run
try {
Thread.sleep(10000);
} catch (Exception e){}

consumer.processAcknowledgedOffsets();
offsetsToCommit = consumer.getOffsetsToCommit();
Assertions.assertEquals(offsetsToCommit.size(), 1);
offsetsToCommit.forEach((topicPartition, offsetAndMetadata) -> {
Assertions.assertEquals(topicPartition.partition(), testJsonPartition);
Assertions.assertEquals(topicPartition.topic(), topic);
Assertions.assertEquals(offsetAndMetadata.offset(), 102L);
});

}


private ConsumerRecords createPlainTextRecords(String topic, final long startOffset) {
Map<TopicPartition, List<ConsumerRecord>> records = new HashMap<>();
ConsumerRecord<String, String> record1 = new ConsumerRecord<>(topic, testPartition, 0L, testKey1, testValue1);
ConsumerRecord<String, String> record2 = new ConsumerRecord<>(topic, testPartition, 1L, testKey2, testValue2);
ConsumerRecord<String, String> record1 = new ConsumerRecord<>(topic, testPartition, startOffset, testKey1, testValue1);
ConsumerRecord<String, String> record2 = new ConsumerRecord<>(topic, testPartition, startOffset+1, testKey2, testValue2);
records.put(new TopicPartition(topic, testPartition), Arrays.asList(record1, record2));
return new ConsumerRecords(records);
}
@@ -312,6 +369,5 @@ private ConsumerRecords createJsonRecords(String topic) throws Exception {
records.put(new TopicPartition(topic, testJsonPartition), Arrays.asList(record1, record2));
return new ConsumerRecords(records);
}

}

@@ -0,0 +1,58 @@
package org.opensearch.dataprepper.plugins.kafka.util;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

@ExtendWith(MockitoExtension.class)
public class LogRateLimiterTest {

@Test
public void testRateLimiter() {
long currentMs = System.currentTimeMillis();
LogRateLimiter objUnderTest = new LogRateLimiter(10, currentMs);
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 50;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 50;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 876;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));

currentMs = System.currentTimeMillis();
objUnderTest = new LogRateLimiter(2, currentMs);
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 100;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 200;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 500;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));

currentMs = System.nanoTime();
objUnderTest = new LogRateLimiter(1000, currentMs);
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 1;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
assertThat(objUnderTest.isAllowed(currentMs), equalTo(false));
currentMs += 2;
assertThat(objUnderTest.isAllowed(currentMs), equalTo(true));
}

@Test
public void testRateLimiterInvalidArgs() {
assertThrows(
IllegalArgumentException.class,
() -> new LogRateLimiter(1345, System.currentTimeMillis())
);
}
}
