Fix consumer synchronization. Fix consumer to use user-specified groupId (opensearch-project#3100)

* Fix consumer synchronization. Fix consumer to use user-specified groupId

Signed-off-by: Krishna Kondaka <[email protected]>

* Fix check style error

Signed-off-by: Krishna Kondaka <[email protected]>

* Fix consumeRecords() to retry when it encounters an exception

Signed-off-by: Krishna Kondaka <[email protected]>

---------

Signed-off-by: Krishna Kondaka <[email protected]>
Co-authored-by: Krishna Kondaka <[email protected]>
kkondaka and Krishna Kondaka authored Aug 2, 2023
1 parent a890deb commit d3a9099
Showing 8 changed files with 105 additions and 86 deletions.
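In outline, the consumer rework stops touching the KafkaConsumer from acknowledgement callbacks: negative acknowledgements and newly assigned partitions are only recorded in a shared partitionsToReset set, and the polling thread seeks those partitions back to their last committed offsets before its next poll, which is why the synchronized(consumer) blocks around poll(), seek(), and commitSync() could be dropped. The try/catch also moves inside the while loop, so an error in one iteration is logged and retried instead of ending the thread. The sketch below only illustrates that shape; it is a simplified stand-in for KafkaSourceCustomConsumer, and the class name, topic name, and stop()/markForReset() helpers are placeholders rather than the plugin's API.

import java.time.Duration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

// Simplified stand-in for KafkaSourceCustomConsumer; illustrates the threading model only.
class ConsumerLoopSketch implements Runnable {
    private final KafkaConsumer<String, String> consumer;                 // only the polling thread touches this
    private final Set<TopicPartition> partitionsToReset = new HashSet<>();
    private volatile boolean shutdown = false;

    ConsumerLoopSketch(final KafkaConsumer<String, String> consumer) {
        this.consumer = consumer;
    }

    // Called from acknowledgement callbacks or rebalance listeners on other threads:
    // just remember the partition, never seek the consumer here.
    void markForReset(final TopicPartition partition) {
        synchronized (partitionsToReset) {
            partitionsToReset.add(partition);
        }
    }

    void stop() {
        shutdown = true;
    }

    @Override
    public void run() {
        consumer.subscribe(List.of("placeholder-topic"));
        while (!shutdown) {
            try {
                synchronized (partitionsToReset) {
                    for (final TopicPartition partition : partitionsToReset) {
                        final OffsetAndMetadata committed = consumer.committed(partition);
                        if (committed != null) {
                            consumer.seek(partition, committed.offset()); // seek happens on the polling thread
                        }
                    }
                    partitionsToReset.clear();
                }
                final ConsumerRecords<String, String> records = consumer.poll(Duration.ofSeconds(2));
                // ... hand records to the buffer, track offsets, commitSync() periodically ...
            } catch (final Exception e) {
                // Log and keep looping; a failed iteration is retried instead of killing the thread.
            }
        }
    }
}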
KafkaSourceJsonTypeIT.java
@@ -83,6 +83,7 @@ public class KafkaSourceJsonTypeIT {
private String bootstrapServers;
private String testKey;
private String testTopic;
private String testGroup;

public KafkaSource createObjectUnderTest() {
return new KafkaSource(sourceConfig, pluginMetrics, acknowledgementSetManager, pipelineDescription);
@@ -112,7 +113,7 @@ public void setup() {
} catch (Exception e){}

testKey = RandomStringUtils.randomAlphabetic(5);
final String testGroup = "TestGroup_"+RandomStringUtils.randomAlphabetic(6);
testGroup = "TestGroup_"+RandomStringUtils.randomAlphabetic(6);
testTopic = "TestJsonTopic_"+RandomStringUtils.randomAlphabetic(5);
jsonTopic = mock(TopicConfig.class);
when(jsonTopic.getName()).thenReturn(testTopic);
@@ -337,6 +338,7 @@ public void TestJsonRecordsWithKafkaKeyModeAsMetadata() throws Exception {
Thread.sleep(1000);
}
kafkaSource.start(buffer);
assertThat(kafkaSource.getConsumer().groupMetadata().groupId(), equalTo(testGroup));
produceJsonRecords(bootstrapServers, topicName, numRecords);
int numRetries = 0;
while (numRetries++ < 10 && (receivedRecords.size() != numRecords)) {
TopicConfig.java
@@ -18,23 +18,24 @@
* pipelines.yaml
*/
public class TopicConfig {
private static final String AUTO_COMMIT = "false";
private static final Duration DEFAULT_COMMIT_INTERVAL = Duration.ofSeconds(5);
private static final Duration DEFAULT_SESSION_TIMEOUT = Duration.ofSeconds(45);
private static final int MAX_RETRY_ATTEMPT = Integer.MAX_VALUE;
static final boolean DEFAULT_AUTO_COMMIT = false;
static final Duration DEFAULT_COMMIT_INTERVAL = Duration.ofSeconds(5);
static final Duration DEFAULT_SESSION_TIMEOUT = Duration.ofSeconds(45);
static final int DEFAULT_MAX_RETRY_ATTEMPT = Integer.MAX_VALUE;
static final String DEFAULT_AUTO_OFFSET_RESET = "latest";
static final Duration THREAD_WAITING_TIME = Duration.ofSeconds(5);
private static final Duration MAX_RECORD_FETCH_TIME = Duration.ofSeconds(4);
private static final Duration BUFFER_DEFAULT_TIMEOUT = Duration.ofSeconds(5);
private static final Duration MAX_RETRY_DELAY = Duration.ofSeconds(1);
private static final Integer FETCH_MAX_BYTES = 52428800;
private static final Integer FETCH_MAX_WAIT = 500;
private static final Integer FETCH_MIN_BYTES = 1;
private static final Duration RETRY_BACKOFF = Duration.ofSeconds(100);
private static final Duration MAX_POLL_INTERVAL = Duration.ofSeconds(300000);
private static final Integer CONSUMER_MAX_POLL_RECORDS = 500;
static final Duration DEFAULT_THREAD_WAITING_TIME = Duration.ofSeconds(5);
static final Duration DEFAULT_MAX_RECORD_FETCH_TIME = Duration.ofSeconds(4);
static final Duration DEFAULT_BUFFER_TIMEOUT = Duration.ofSeconds(5);
static final Duration DEFAULT_MAX_RETRY_DELAY = Duration.ofSeconds(1);
static final Integer DEFAULT_FETCH_MAX_BYTES = 52428800;
static final Integer DEFAULT_FETCH_MAX_WAIT = 500;
static final Integer DEFAULT_FETCH_MIN_BYTES = 1;
static final Duration DEFAULT_RETRY_BACKOFF = Duration.ofSeconds(10);
static final Duration DEFAULT_RECONNECT_BACKOFF = Duration.ofSeconds(10);
static final Duration DEFAULT_MAX_POLL_INTERVAL = Duration.ofSeconds(300000);
static final Integer DEFAULT_CONSUMER_MAX_POLL_RECORDS = 500;
static final Integer DEFAULT_NUM_OF_WORKERS = 2;
static final Duration HEART_BEAT_INTERVAL_DURATION = Duration.ofSeconds(5);
static final Duration DEFAULT_HEART_BEAT_INTERVAL_DURATION = Duration.ofSeconds(5);

@JsonProperty("name")
@NotNull
@@ -54,18 +55,18 @@ public class TopicConfig {
@JsonProperty("max_retry_attempts")
@Valid
@Size(min = 1, max = Integer.MAX_VALUE, message = " Max retry attempts should lies between 1 and Integer.MAX_VALUE")
private Integer maxRetryAttempts = MAX_RETRY_ATTEMPT;
private Integer maxRetryAttempts = DEFAULT_MAX_RETRY_ATTEMPT;

@JsonProperty("max_retry_delay")
@Valid
@Size(min = 1)
private Duration maxRetryDelay = MAX_RETRY_DELAY;
private Duration maxRetryDelay = DEFAULT_MAX_RETRY_DELAY;

@JsonProperty("serde_format")
private MessageFormat serdeFormat= MessageFormat.PLAINTEXT;

@JsonProperty("auto_commit")
private Boolean autoCommit = false;
private Boolean autoCommit = DEFAULT_AUTO_COMMIT;

@JsonProperty("commit_interval")
@Valid
@@ -86,47 +87,50 @@ public class TopicConfig {
private String groupName;

@JsonProperty("thread_waiting_time")
private Duration threadWaitingTime = THREAD_WAITING_TIME;
private Duration threadWaitingTime = DEFAULT_THREAD_WAITING_TIME;

@JsonProperty("max_record_fetch_time")
private Duration maxRecordFetchTime = MAX_RECORD_FETCH_TIME;
private Duration maxRecordFetchTime = DEFAULT_MAX_RECORD_FETCH_TIME;

@JsonProperty("buffer_default_timeout")
@Valid
@Size(min = 1)
private Duration bufferDefaultTimeout = BUFFER_DEFAULT_TIMEOUT;
private Duration bufferDefaultTimeout = DEFAULT_BUFFER_TIMEOUT;

@JsonProperty("fetch_max_bytes")
@Valid
@Size(min = 1, max = 52428800)
private Integer fetchMaxBytes = FETCH_MAX_BYTES;
private Integer fetchMaxBytes = DEFAULT_FETCH_MAX_BYTES;

@JsonProperty("fetch_max_wait")
@Valid
@Size(min = 1)
private Integer fetchMaxWait = FETCH_MAX_WAIT;
private Integer fetchMaxWait = DEFAULT_FETCH_MAX_WAIT;

@JsonProperty("fetch_min_bytes")
@Size(min = 1)
@Valid
private Integer fetchMinBytes = FETCH_MIN_BYTES;
private Integer fetchMinBytes = DEFAULT_FETCH_MIN_BYTES;

@JsonProperty("key_mode")
private KafkaKeyMode kafkaKeyMode = KafkaKeyMode.INCLUDE_AS_FIELD;

@JsonProperty("retry_backoff")
private Duration retryBackoff = RETRY_BACKOFF;
private Duration retryBackoff = DEFAULT_RETRY_BACKOFF;

@JsonProperty("reconnect_backoff")
private Duration reconnectBackoff = DEFAULT_RECONNECT_BACKOFF;

@JsonProperty("max_poll_interval")
private Duration maxPollInterval = MAX_POLL_INTERVAL;
private Duration maxPollInterval = DEFAULT_MAX_POLL_INTERVAL;

@JsonProperty("consumer_max_poll_records")
private Integer consumerMaxPollRecords = CONSUMER_MAX_POLL_RECORDS;
private Integer consumerMaxPollRecords = DEFAULT_CONSUMER_MAX_POLL_RECORDS;

@JsonProperty("heart_beat_interval")
@Valid
@Size(min = 1)
private Duration heartBeatInterval= HEART_BEAT_INTERVAL_DURATION;
private Duration heartBeatInterval= DEFAULT_HEART_BEAT_INTERVAL_DURATION;

public String getGroupId() {
return groupId;
@@ -220,6 +224,10 @@ public Duration getRetryBackoff() {
return retryBackoff;
}

public Duration getReconnectBackoff() {
return reconnectBackoff;
}

public void setRetryBackoff(Duration retryBackoff) {
this.retryBackoff = retryBackoff;
}
KafkaSourceCustomConsumer.java
@@ -15,6 +15,7 @@
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.CommitFailedException;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.RecordDeserializationException;
import org.apache.kafka.common.TopicPartition;
import org.apache.avro.generic.GenericRecord;
import org.opensearch.dataprepper.model.log.JacksonLog;
@@ -40,6 +41,8 @@
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat;
import com.amazonaws.services.schemaregistry.serializers.json.JsonDataWithSchema;
@@ -68,6 +71,7 @@ public class KafkaSourceCustomConsumer implements Runnable, ConsumerRebalanceLis
private static final ObjectMapper objectMapper = new ObjectMapper();
private final JsonFactory jsonFactory = new JsonFactory();
private Map<TopicPartition, OffsetAndMetadata> offsetsToCommit;
private Set<TopicPartition> partitionsToReset;
private final AcknowledgementSetManager acknowledgementSetManager;
private final Map<Integer, TopicPartitionCommitTracker> partitionCommitTrackerMap;
private final Counter positiveAcknowledgementSetCounter;
@@ -95,6 +99,7 @@ public KafkaSourceCustomConsumer(final KafkaConsumer consumer,
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginMetrics = pluginMetrics;
this.partitionCommitTrackerMap = new HashMap<>();
this.partitionsToReset = new HashSet<>();
this.schema = MessageFormat.getByMessageFormatByName(schemaType);
Duration bufferTimeout = Duration.ofSeconds(1);
this.bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, bufferTimeout);
@@ -121,29 +126,21 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, Range<Lo
try {
int partitionId = partition.partition();
if (!partitionCommitTrackerMap.containsKey(partitionId)) {
OffsetAndMetadata committedOffsetAndMetadata = null;
synchronized(consumer) {
committedOffsetAndMetadata = consumer.committed(partition);
}
OffsetAndMetadata committedOffsetAndMetadata = consumer.committed(partition);
Long committedOffset = Objects.nonNull(committedOffsetAndMetadata) ? committedOffsetAndMetadata.offset() : null;
partitionCommitTrackerMap.put(partitionId, new TopicPartitionCommitTracker(partition, committedOffset));
}
OffsetAndMetadata offsetAndMetadata = partitionCommitTrackerMap.get(partitionId).addCompletedOffsets(offsetRange);
updateOffsetsToCommit(partition, offsetAndMetadata);
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon positive acknowledgement "+partition, e);
LOG.error("Failed to seek to last committed offset upon positive acknowledgement {}", partition, e);
}
});
} else {
negativeAcknowledgementSetCounter.increment();
offsets.forEach((partition, offsetRange) -> {
try {
synchronized(consumer) {
OffsetAndMetadata committedOffsetAndMetadata = consumer.committed(partition);
consumer.seek(partition, committedOffsetAndMetadata);
}
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon negative acknowledgement "+partition, e);
synchronized(partitionsToReset) {
partitionsToReset.add(partition);
}
});
}
@@ -157,10 +154,7 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, Range<Lo

public <T> void consumeRecords() throws Exception {
try {
ConsumerRecords<String, T> records = null;
synchronized(consumer) {
records = consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2);
}
ConsumerRecords<String, T> records = consumer.poll(topicConfig.getThreadWaitingTime().toMillis()/2);
if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
Map<TopicPartition, Range<Long>> offsets = new HashMap<>();
AcknowledgementSet acknowledgementSet = null;
@@ -176,12 +170,27 @@ public <T> void consumeRecords() throws Exception {
}
}
} catch (AuthenticationException e) {
LOG.warn("Authentication Error while doing poll(). Will retry after 10 seconds", e);
LOG.warn("Access Denied while doing poll(). Will retry after 10 seconds", e);
Thread.sleep(10000);
} catch (RecordDeserializationException e) {
LOG.warn("Serialization error - topic {} partition {} offset {}, seeking past the error record",
e.topicPartition().topic(), e.topicPartition().partition(), e.offset());
consumer.seek(e.topicPartition(), e.offset()+1);
}
}

private void commitOffsets() {
private void resetOrCommitOffsets() {
synchronized(partitionsToReset) {
partitionsToReset.forEach(partition -> {
try {
final OffsetAndMetadata offsetAndMetadata = consumer.committed(partition);
consumer.seek(partition, offsetAndMetadata);
} catch (Exception e) {
LOG.error("Failed to seek to last committed offset upon negative acknowledgement {}", partition, e);
}
});
partitionsToReset.clear();
}
if (topicConfig.getAutoCommit()) {
return;
}
@@ -194,13 +203,11 @@ private void commitOffsets() {
return;
}
try {
synchronized(consumer) {
consumer.commitSync();
}
consumer.commitSync();
offsetsToCommit.clear();
lastCommitTime = currentTimeMillis;
} catch (CommitFailedException e) {
LOG.error("Failed to commit offsets in topic "+topicName, e);
LOG.error("Failed to commit offsets in topic {}", topicName, e);
}
}
}
@@ -211,14 +218,14 @@ Map<TopicPartition, OffsetAndMetadata> getOffsetsToCommit() {

@Override
public void run() {
try {
consumer.subscribe(Arrays.asList(topicName));
while (!shutdownInProgress.get()) {
consumer.subscribe(Arrays.asList(topicName));
while (!shutdownInProgress.get()) {
try {
resetOrCommitOffsets();
consumeRecords();
commitOffsets();
} catch (Exception exp) {
LOG.error("Error while reading the records from the topic...", exp);
}
} catch (Exception exp) {
LOG.error("Error while reading the records from the topic...", exp);
}
}

@@ -306,9 +313,8 @@ public void shutdownConsumer(){
@Override
public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
for (TopicPartition topicPartition : partitions) {
synchronized(consumer) {
Long committedOffset = consumer.committed(topicPartition).offset();
consumer.seek(topicPartition, committedOffset);
synchronized(partitionsToReset) {
partitionsToReset.add(topicPartition);
}
}
}
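A second behavioral detail from the consumer diff above: when poll() fails with a RecordDeserializationException, the consumer now seeks one offset past the offending record, so a single undecodable message cannot permanently stall its partition. The snippet below is a minimal, self-contained illustration of that pattern using only the Kafka client API; the PollHelper class and method name are placeholders, not plugin code.

import java.time.Duration;

import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.errors.RecordDeserializationException;

// Sketch only: skip past a record that cannot be deserialized and keep consuming.
final class PollHelper {
    private PollHelper() {
    }

    static <K, V> ConsumerRecords<K, V> pollSkippingBadRecords(final KafkaConsumer<K, V> consumer) {
        try {
            return consumer.poll(Duration.ofSeconds(2));
        } catch (final RecordDeserializationException e) {
            // The exception carries the partition and offset of the record that failed to deserialize.
            consumer.seek(e.topicPartition(), e.offset() + 1);
            return ConsumerRecords.empty();
        }
    }
}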
KafkaSource.java
@@ -95,6 +95,7 @@ public class KafkaSource implements Source<Record<Event>> {
private final Counter kafkaWorkerThreadProcessingErrors;
private final PluginMetrics pluginMetrics;
private KafkaSourceCustomConsumer consumer;
private KafkaConsumer kafkaConsumer;
private String pipelineName;
private String consumerGroupID;
private String schemaType = MessageFormat.PLAINTEXT.toString();
@@ -125,7 +126,6 @@ public void start(Buffer<Record<Event>> buffer) {
int numWorkers = topic.getWorkers();
executorService = Executors.newFixedThreadPool(numWorkers);
IntStream.range(0, numWorkers + 1).forEach(index -> {
KafkaConsumer kafkaConsumer;
switch (schema) {
case JSON:
kafkaConsumer = new KafkaConsumer<String, JsonNode>(consumerProperties);
@@ -185,6 +185,9 @@ private long calculateLongestThreadWaitingTime() {
orElse(1L);
}

KafkaConsumer getConsumer() {
return kafkaConsumer;
}

private Properties getConsumerProperties(final TopicConfig topicConfig) {
Properties properties = new Properties();
@@ -361,6 +364,8 @@ private void setPropertiesForSchemaType(Properties properties, TopicConfig topic

private void setConsumerTopicProperties(Properties properties, TopicConfig topicConfig) {
properties.put(ConsumerConfig.GROUP_ID_CONFIG, consumerGroupID);
properties.put(ConsumerConfig.RETRY_BACKOFF_MS_CONFIG, ((Long)topicConfig.getRetryBackoff().toMillis()).intValue());
properties.put(ConsumerConfig.RECONNECT_BACKOFF_MS_CONFIG, ((Long)topicConfig.getReconnectBackoff().toMillis()).intValue());
properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
topicConfig.getAutoCommit());
properties.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG,
@@ -16,10 +16,7 @@
import software.amazon.awssdk.services.kafka.KafkaClient;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersRequest;
import software.amazon.awssdk.services.kafka.model.GetBootstrapBrokersResponse;
import software.amazon.awssdk.services.kafka.model.InternalServerErrorException;
import software.amazon.awssdk.services.kafka.model.ConflictException;
import software.amazon.awssdk.services.kafka.model.ForbiddenException;
import software.amazon.awssdk.services.kafka.model.UnauthorizedException;
import software.amazon.awssdk.services.kafka.model.KafkaException;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.services.sts.model.StsException;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
@@ -214,17 +211,15 @@ public static String getBootStrapServersForMsk(final AwsIamAuthConfig awsIamAuth
retryable = false;
try {
result = kafkaClient.getBootstrapBrokers(request);
} catch (InternalServerErrorException | ConflictException | ForbiddenException | UnauthorizedException | StsException e) {
} catch (KafkaException | StsException e) {
LOG.debug("Failed to get bootstrap server information from MSK. Retrying...", e);

retryable = true;
try {
Thread.sleep(10000);
} catch (InterruptedException exp) {}
} catch (Exception e) {
throw new RuntimeException("Failed to get bootstrap server information from MSK.", e);
}
} while (retryable && numRetries++ < MAX_KAFKA_CLIENT_RETRIES);
} while (numRetries++ < MAX_KAFKA_CLIENT_RETRIES);
if (Objects.isNull(result)) {
throw new RuntimeException("Failed to get bootstrap server information from MSK after trying multiple times with retryable exceptions.");
}