Skip to content

Commit

Permalink
[fix][client] Make the whole grabCnx() progress atomic (apache#20595)
Browse files Browse the repository at this point in the history
  • Loading branch information
BewareMyPower authored Jun 30, 2023
1 parent c82825b commit 2bede01
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pulsar.client.impl;

import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.pulsar.client.api.ProducerConsumerBase;
import org.apache.pulsar.common.util.FutureUtil;
import org.awaitility.Awaitility;
import org.awaitility.core.ConditionTimeoutException;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Slf4j
@Test(groups = "broker-impl")
public class ConnectionHandlerTest extends ProducerConsumerBase {

private static final Backoff BACKOFF = new BackoffBuilder().setInitialTime(1, TimeUnit.MILLISECONDS)
.setMandatoryStop(1, TimeUnit.SECONDS)
.setMax(3, TimeUnit.SECONDS).create();
private final ExecutorService executor = Executors.newFixedThreadPool(4);

@BeforeClass(alwaysRun = true)
@Override
protected void setup() throws Exception {
super.internalSetup();
super.producerBaseSetup();
}

@AfterClass
@Override
protected void cleanup() throws Exception {
super.internalCleanup();
executor.shutdown();
}

@Test(timeOut = 30000)
public void testSynchronousGrabCnx() {
for (int i = 0; i < 10; i++) {
final CompletableFuture<Integer> future = new CompletableFuture<>();
final int index = i;
final ConnectionHandler handler = new ConnectionHandler(
new MockedHandlerState((PulsarClientImpl) pulsarClient, "my-topic"), BACKOFF,
cnx -> {
future.complete(index);
return CompletableFuture.completedFuture(null);
});
handler.grabCnx();
Assert.assertEquals(future.join(), i);
}
}

@Test
public void testConcurrentGrabCnx() {
final AtomicInteger cnt = new AtomicInteger(0);
final ConnectionHandler handler = new ConnectionHandler(
new MockedHandlerState((PulsarClientImpl) pulsarClient, "my-topic"), BACKOFF,
cnx -> {
cnt.incrementAndGet();
return CompletableFuture.completedFuture(null);
});
final int numGrab = 10;
for (int i = 0; i < numGrab; i++) {
handler.grabCnx();
}
Awaitility.await().atMost(Duration.ofSeconds(3)).until(() -> cnt.get() > 0);
Assert.assertThrows(ConditionTimeoutException.class,
() -> Awaitility.await().atMost(Duration.ofMillis(500)).until(() -> cnt.get() == numGrab));
Assert.assertEquals(cnt.get(), 1);
}

@Test
public void testDuringConnectInvokeCount() throws IllegalAccessException {
// 1. connectionOpened completes with null
final AtomicBoolean duringConnect = spy(new AtomicBoolean());
final ConnectionHandler handler1 = new ConnectionHandler(
new MockedHandlerState((PulsarClientImpl) pulsarClient, "my-topic"), BACKOFF,
cnx -> CompletableFuture.completedFuture(null));
FieldUtils.writeField(handler1, "duringConnect", duringConnect, true);
handler1.grabCnx();
Awaitility.await().atMost(Duration.ofSeconds(3)).until(() -> !duringConnect.get());
verify(duringConnect, times(1)).compareAndSet(false, true);
verify(duringConnect, times(1)).set(false);

// 2. connectionFailed is called
final ConnectionHandler handler2 = new ConnectionHandler(
new MockedHandlerState((PulsarClientImpl) pulsarClient, null), new MockedBackoff(),
cnx -> CompletableFuture.completedFuture(null));
FieldUtils.writeField(handler2, "duringConnect", duringConnect, true);
handler2.grabCnx();
Awaitility.await().atMost(Duration.ofSeconds(3)).until(() -> !duringConnect.get());
verify(duringConnect, times(2)).compareAndSet(false, true);
verify(duringConnect, times(2)).set(false);

// 3. connectionOpened completes exceptionally
final ConnectionHandler handler3 = new ConnectionHandler(
new MockedHandlerState((PulsarClientImpl) pulsarClient, "my-topic"), new MockedBackoff(),
cnx -> FutureUtil.failedFuture(new RuntimeException("fail")));
FieldUtils.writeField(handler3, "duringConnect", duringConnect, true);
handler3.grabCnx();
Awaitility.await().atMost(Duration.ofSeconds(3)).until(() -> !duringConnect.get());
verify(duringConnect, times(3)).compareAndSet(false, true);
verify(duringConnect, times(3)).set(false);
}

private static class MockedHandlerState extends HandlerState {

public MockedHandlerState(PulsarClientImpl client, String topic) {
super(client, topic);
}

@Override
String getHandlerName() {
return "mocked";
}
}

private static class MockedBackoff extends Backoff {

// Set a large backoff so that reconnection won't happen in tests
public MockedBackoff() {
super(1, TimeUnit.HOURS, 2, TimeUnit.HOURS, 1, TimeUnit.HOURS);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import org.apache.pulsar.client.api.PulsarClientException;
Expand All @@ -41,10 +42,16 @@ public class ConnectionHandler {
// Start with -1L because it gets incremented before sending on the first connection
private volatile long epoch = -1L;
protected volatile long lastConnectionClosedTimestamp = 0L;
private final AtomicBoolean duringConnect = new AtomicBoolean(false);

interface Connection {
void connectionFailed(PulsarClientException exception);
void connectionOpened(ClientCnx cnx);

/**
* @apiNote If the returned future is completed exceptionally, reconnectLater will be called.
*/
CompletableFuture<Void> connectionOpened(ClientCnx cnx);
default void connectionFailed(PulsarClientException e) {
}
}

protected Connection connection;
Expand All @@ -69,6 +76,11 @@ protected void grabCnx() {
state.topic, state.getHandlerName(), state.getState());
return;
}
if (!duringConnect.compareAndSet(false, true)) {
log.info("[{}] [{}] Skip grabbing the connection since there is a pending connection",
state.topic, state.getHandlerName());
return;
}

try {
CompletableFuture<ClientCnx> cnxFuture;
Expand All @@ -81,7 +93,8 @@ protected void grabCnx() {
} else {
cnxFuture = state.client.getConnection(state.topic); //
}
cnxFuture.thenAccept(cnx -> connection.connectionOpened(cnx)) //
cnxFuture.thenCompose(cnx -> connection.connectionOpened(cnx))
.thenAccept(__ -> duringConnect.set(false))
.exceptionally(this::handleConnectionError);
} catch (Throwable t) {
log.warn("[{}] [{}] Exception thrown while getting connection: ", state.topic, state.getHandlerName(), t);
Expand All @@ -90,25 +103,27 @@ protected void grabCnx() {
}

private Void handleConnectionError(Throwable exception) {
log.warn("[{}] [{}] Error connecting to broker: {}",
state.topic, state.getHandlerName(), exception.getMessage());
if (exception instanceof PulsarClientException) {
connection.connectionFailed((PulsarClientException) exception);
} else if (exception.getCause() instanceof PulsarClientException) {
connection.connectionFailed((PulsarClientException) exception.getCause());
} else {
connection.connectionFailed(new PulsarClientException(exception));
}

State state = this.state.getState();
if (state == State.Uninitialized || state == State.Connecting || state == State.Ready) {
reconnectLater(exception);
try {
log.warn("[{}] [{}] Error connecting to broker: {}",
state.topic, state.getHandlerName(), exception.getMessage());
if (exception instanceof PulsarClientException) {
connection.connectionFailed((PulsarClientException) exception);
} else if (exception.getCause() instanceof PulsarClientException) {
connection.connectionFailed((PulsarClientException) exception.getCause());
} else {
connection.connectionFailed(new PulsarClientException(exception));
}
} catch (Throwable throwable) {
log.error("[{}] [{}] Unexpected exception after the connection",
state.topic, state.getHandlerName(), throwable);
}

reconnectLater(exception);
return null;
}

protected void reconnectLater(Throwable exception) {
void reconnectLater(Throwable exception) {
duringConnect.set(false);
CLIENT_CNX_UPDATER.set(this, null);
if (!isValidStateForReconnection()) {
log.info("[{}] [{}] Ignoring reconnection request (state: {})",
Expand All @@ -132,6 +147,7 @@ protected void reconnectLater(Throwable exception) {

public void connectionClosed(ClientCnx cnx) {
lastConnectionClosedTimestamp = System.currentTimeMillis();
duringConnect.set(false);
state.client.getCnxPool().releaseConnection(cnx);
if (CLIENT_CNX_UPDATER.compareAndSet(this, cnx, null)) {
if (!isValidStateForReconnection()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,16 +760,17 @@ public void negativeAcknowledge(Message<?> message) {
}

@Override
public void connectionOpened(final ClientCnx cnx) {
public CompletableFuture<Void> connectionOpened(final ClientCnx cnx) {
previousExceptions.clear();

if (getState() == State.Closing || getState() == State.Closed) {
final State state = getState();
if (state == State.Closing || state == State.Closed) {
setState(State.Closed);
closeConsumerTasks();
deregisterFromClientCnx();
client.cleanupConsumer(this);
clearReceiverQueue();
return;
return CompletableFuture.completedFuture(null);
}

log.info("[{}][{}] Subscribing to topic on cnx {}, consumerId {}",
Expand Down Expand Up @@ -823,6 +824,7 @@ public void connectionOpened(final ClientCnx cnx) {
&& startMessageId.equals(initialStartMessageId)) ? startMessageRollbackDurationInSec : 0;

// synchronized this, because redeliverUnAckMessage eliminate the epoch inconsistency between them
final CompletableFuture<Void> future = new CompletableFuture<>();
synchronized (this) {
setClientCnx(cnx);
ByteBuf request = Commands.newSubscribe(topic, subscription, consumerId, requestId, getSubType(),
Expand All @@ -844,6 +846,7 @@ public void connectionOpened(final ClientCnx cnx) {
deregisterFromClientCnx();
client.cleanupConsumer(this);
cnx.channel().close();
future.complete(null);
return;
}
}
Expand All @@ -856,12 +859,14 @@ public void connectionOpened(final ClientCnx cnx) {
if (!(firstTimeConnect && hasParentConsumer) && getCurrentReceiverQueueSize() != 0) {
increaseAvailablePermits(cnx, getCurrentReceiverQueueSize());
}
future.complete(null);
}).exceptionally((e) -> {
deregisterFromClientCnx();
if (getState() == State.Closing || getState() == State.Closed) {
// Consumer was closed while reconnecting, close the connection to make sure the broker
// drops the consumer on its side
cnx.channel().close();
future.complete(null);
return null;
}
log.warn("[{}][{}] Failed to subscribe to topic on {}", topic,
Expand All @@ -879,7 +884,7 @@ public void connectionOpened(final ClientCnx cnx) {
if (e.getCause() instanceof PulsarClientException
&& PulsarClientException.isRetriableError(e.getCause())
&& System.currentTimeMillis() < SUBSCRIBE_DEADLINE_UPDATER.get(ConsumerImpl.this)) {
reconnectLater(e.getCause());
future.completeExceptionally(e.getCause());
} else if (!subscribeFuture.isDone()) {
// unable to create new consumer, fail operation
setState(State.Failed);
Expand All @@ -903,11 +908,16 @@ public void connectionOpened(final ClientCnx cnx) {
topic, subscription, cnx.channel().remoteAddress());
} else {
// consumer was subscribed and connected but we got some error, keep trying
reconnectLater(e.getCause());
future.completeExceptionally(e.getCause());
}

if (!future.isDone()) {
future.complete(null);
}
return null;
});
}
return future;
}

protected void consumerIsReconnectedToBroker(ClientCnx cnx, int currentQueueSize) {
Expand Down Expand Up @@ -991,7 +1001,7 @@ public void connectionFailed(PulsarClientException exception) {
setState(State.Failed);
if (nonRetriableError) {
log.info("[{}] Consumer creation failed for consumer {} with unretriableError {}",
topic, consumerId, exception);
topic, consumerId, exception.getMessage());
} else {
log.info("[{}] Consumer creation failed for consumer {} after timeout", topic, consumerId);
}
Expand Down Expand Up @@ -2590,10 +2600,6 @@ void deregisterFromClientCnx() {
setClientCnx(null);
}

void reconnectLater(Throwable exception) {
this.connectionHandler.reconnectLater(exception);
}

void grabCnx() {
this.connectionHandler.grabCnx();
}
Expand Down
Loading

0 comments on commit 2bede01

Please sign in to comment.