Skip to content

Commit

Permalink
Resolve race conditions in MQTT connection logic (#1060)
Browse files Browse the repository at this point in the history
  • Loading branch information
MertCingoz authored Jan 8, 2025
1 parent c2c051e commit ee7c0cb
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 81 deletions.
2 changes: 1 addition & 1 deletion bin/test_redirect
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ log bin/reset_config $site_path $project_spec $device_id shutdown_config.json
bin/reset_config $site_path $project_spec $device_id shutdown_config.json

log And let it settle for last start termination...
sleep 120
sleep 125

tail out/pubber.log.2

Expand Down
41 changes: 12 additions & 29 deletions pubber/src/main/java/daq/pubber/Pubber.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.http.ConnectionClosedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import udmi.lib.base.MqttDevice;
import udmi.lib.base.MqttPublisher.PublisherException;
import udmi.lib.client.DeviceManager;
import udmi.lib.client.SystemManager;
import udmi.lib.intf.FamilyProvider;
Expand Down Expand Up @@ -103,7 +101,6 @@ public class Pubber extends PubberManager implements PubberUdmiPublisher {
private SchemaVersion targetSchema;
private int deviceUpdateCount = -1;
private PubberDeviceManager deviceManager;
private boolean isConnected;
private boolean isGatewayDevice;

/**
Expand Down Expand Up @@ -393,7 +390,7 @@ private void processDeviceMetadata(Metadata metadata) {
}

@Override
public void periodicUpdate() {
public synchronized void periodicUpdate() {
try {
deviceUpdateCount++;
checkSmokyFailure();
Expand All @@ -403,6 +400,7 @@ public void periodicUpdate() {
flushDirtyState();
} catch (Exception e) {
error("Fatal error during execution", e);
resetConnection(getWorkingEndpoint());
}
}

Expand All @@ -427,18 +425,13 @@ public void startConnection(Function<String, Boolean> connectionDone) {

private boolean attemptConnection() {
try {
isConnected = false;
deviceManager.stop();
super.stop();
if (deviceTarget == null || !deviceTarget.isActive()) {
error("Mqtt publisher not active");
disconnectMqtt();
initializeMqtt();
}
disconnectMqtt();
initializeMqtt();
registerMessageHandlers();
connect();
configLatchWait();
isConnected = true;
deviceManager.activate();
return true;
} catch (Exception e) {
Expand Down Expand Up @@ -515,22 +508,14 @@ public byte[] ensureKeyBytes() {
}

@Override
public void publisherException(Exception toReport) {
if (toReport instanceof PublisherException report) {
publisherHandler(report.getType(), report.getPhase(), report.getCause(),
report.getDeviceId());
} else if (toReport instanceof ConnectionClosedException) {
error("Connection closed, attempting reconnect...");
while (retriesRemaining.getAndDecrement() > 0) {
if (attemptConnection()) {
return;
}
public synchronized void reconnect() {
while (retriesRemaining.getAndDecrement() > 0) {
if (attemptConnection()) {
return;
}
error("Connection retry failed, giving up.");
deviceManager.systemLifecycle(SystemMode.TERMINATE);
} else {
error("Unknown exception type " + toReport.getClass(), toReport);
}
error("Connection retry failed, giving up.");
deviceManager.systemLifecycle(SystemMode.TERMINATE);
}

@Override
Expand All @@ -541,12 +526,10 @@ public void persistEndpoint(EndpointConfiguration endpoint) {
}

@Override
public void resetConnection(String targetEndpoint) {
public synchronized void resetConnection(String targetEndpoint) {
try {
config.endpoint = fromJsonString(targetEndpoint,
EndpointConfiguration.class);
disconnectMqtt();
initializeMqtt();
retriesRemaining.set(CONNECT_RETRIES);
startConnection(connectionDone);
} catch (Exception e) {
Expand Down Expand Up @@ -700,7 +683,7 @@ public void setConfigLatch(CountDownLatch countDownLatch) {

@Override
public boolean isConnected() {
return isConnected;
return deviceTarget != null && deviceTarget.isActive();
}

@Override
Expand Down
6 changes: 4 additions & 2 deletions pubber/src/main/java/daq/pubber/PubberGatewayManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import udmi.lib.ProtocolFamily;
import udmi.lib.client.GatewayManager;
import udmi.lib.client.ProxyDeviceHost;
Expand Down Expand Up @@ -47,8 +48,9 @@ public void setMetadata(Metadata metadata) {

@Override
public void activate() {
ifNotNullThen(proxyDevices, p -> p.values()
.parallelStream().forEach(ProxyDeviceHost::activate));
ifNotNullThen(proxyDevices, p -> CompletableFuture.runAsync(() -> p.values()
.parallelStream()
.forEach(ProxyDeviceHost::activate)));
}

@Override
Expand Down
39 changes: 27 additions & 12 deletions pubber/src/main/java/daq/pubber/PubberUdmiPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@
import java.util.concurrent.locks.Lock;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.http.ConnectionClosedException;
import udmi.lib.base.GatewayError;
import udmi.lib.base.MqttDevice;
import udmi.lib.base.MqttPublisher.FakeTopic;
import udmi.lib.base.MqttPublisher.InjectedMessage;
import udmi.lib.base.MqttPublisher.InjectedState;
import udmi.lib.base.MqttPublisher.PublisherException;
import udmi.lib.client.DeviceManager;
import udmi.lib.client.PointsetManager;
import udmi.lib.client.PointsetManager.ExtraPointsetEvent;
Expand Down Expand Up @@ -242,9 +244,9 @@ default void captureExceptions(String action, Runnable runnable) {
*/
default void disconnectMqtt() {
if (getDeviceTarget() != null) {
captureExceptions("closing mqtt publisher", () -> getDeviceTarget().close());
captureExceptions("shutting down mqtt publisher executor",
captureExceptions("Shutting down MQTT publisher executor",
() -> getDeviceTarget().shutdown());
captureExceptions("Closing MQTT publisher", () -> getDeviceTarget().close());
setDeviceTarget(null);
}
}
Expand Down Expand Up @@ -805,16 +807,12 @@ default void publishSynchronousState() {
}
}

default boolean publisherActive() {
return getDeviceTarget() != null && getDeviceTarget().isActive();
}

/**
* Publishes the current device state as a message to the publisher if the publisher is active. If
* the publisher is not active, it marks the state as dirty and returns without publishing.
*/
default void publishStateMessage() {
if (!publisherActive()) {
if (!isConnected()) {
markStateDirty(-1);
return;
}
Expand Down Expand Up @@ -898,8 +896,8 @@ private void publishDeviceMessage(String targetId, Object message) {
* configured.
*/
default void publishDeviceMessage(String targetId, Object message, Runnable callback) {
if (getDeviceTarget() == null) {
error("publisher not active");
if (!isConnected()) {
error(format("Publisher not active (%s)", targetId));
return;
}
String topicSuffix = MESSAGE_TOPIC_SUFFIX_MAP.get(message.getClass());
Expand Down Expand Up @@ -989,6 +987,10 @@ default void debug(String message, String detail) {

void startConnection(Function<String, Boolean> connectionDone);

void reconnect();

void resetConnection(String targetEndpoint);

/**
* Flushes the dirty state by publishing an asynchronous state change.
*/
Expand All @@ -1000,12 +1002,25 @@ default void flushDirtyState() {

byte[] ensureKeyBytes();

void publisherException(Exception toReport);
/**
* Handles exceptions related to the publisher and
* takes appropriate actions based on the exception type.
*
* @param toReport the exception to be handled;
*/
default void publisherException(Exception toReport) {
if (toReport instanceof PublisherException r) {
publisherHandler(r.getType(), r.getPhase(), r.getCause(), r.getDeviceId());
} else if (toReport instanceof ConnectionClosedException) {
warn("Connection closed, attempting reconnect...");
reconnect();
} else {
error("Unknown exception type " + toReport.getClass(), toReport);
}
}

void persistEndpoint(EndpointConfiguration endpoint);

void resetConnection(String targetEndpoint);

String traceTimestamp(String messageBase);

/**
Expand Down
77 changes: 40 additions & 37 deletions pubber/src/main/java/udmi/lib/base/MqttPublisher.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -100,6 +101,7 @@ public class MqttPublisher implements Publisher {

private final Map<String, MqttClient> mqttClients = new ConcurrentHashMap<>();
private final Map<String, Instant> reauthTimes = new ConcurrentHashMap<>();
ReentrantLock reconnectLock = new ReentrantLock();

private final ExecutorService publisherExecutor =
Executors.newFixedThreadPool(PUBLISH_THREAD_COUNT);
Expand Down Expand Up @@ -215,22 +217,32 @@ private void publishCore(String deviceId, String topicSuffix, Object data, Runna
callback.run();
}
} catch (Exception e) {
if (!isActive()) {
return;
}
errorCounter.incrementAndGet();
warn(format("Publish %s failed for %s: %s", topicSuffix, deviceId, e));
if (getGatewayId() == null) {
closeMqttClient(deviceId);
if (mqttClients.isEmpty()) {
warn("Last client closed, shutting down connection.");
close();
shutdown();
reconnect();
}
} else if (getGatewayId().equals(deviceId)) {
reconnect();
}
}
}

private synchronized void reconnect() {
if (isActive()) {
if (reconnectLock.tryLock()) {
try {
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
} finally {
reconnectLock.unlock();
}
} else if (getGatewayId().equals(deviceId)) {
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
}
}
}
Expand Down Expand Up @@ -268,7 +280,7 @@ private void closeMqttClient(String deviceId) {
if (removed != null) {
try {
if (removed.isConnected()) {
removed.disconnect();
removed.disconnectForcibly();
}
removed.close();
} catch (Exception e) {
Expand Down Expand Up @@ -298,7 +310,7 @@ public void close() {
@Override
public void shutdown() {
if (isActive()) {
publisherExecutor.shutdown();
publisherExecutor.shutdownNow();
}
}

Expand Down Expand Up @@ -532,7 +544,7 @@ private String getDeviceId(String topic) {
return topic.split("/")[splitIndex];
}

public void connect(String targetId, boolean clean) {
public synchronized void connect(String targetId, boolean clean) {
ifTrueThen(clean, () -> closeMqttClient(targetId));
getConnectedClient(targetId);
}
Expand Down Expand Up @@ -569,8 +581,10 @@ private boolean sendMessage(String deviceId, String mqttTopic,
return true;
}

private MqttClient getActiveClient(String targetId) {
checkAuthentication(targetId);
private synchronized MqttClient getActiveClient(String targetId) {
if (!checkAuthentication(targetId)) {
return null;
}
MqttClient client = getConnectedClient(targetId);
if (client.isConnected()) {
return client;
Expand All @@ -586,24 +600,16 @@ private void safeSleep(long timeoutMs) {
}
}

private void checkAuthentication(String targetId) {
private boolean checkAuthentication(String targetId) {
String authId = ofNullable(getGatewayId()).orElse(targetId);
Instant reAuthTime = reauthTimes.get(authId);
if (reAuthTime == null || Instant.now().isBefore(reAuthTime)) {
return;
return true;
}
warn("Authentication retry time reached for " + authId);
reauthTimes.remove(authId);
synchronized (mqttClients) {
try {
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
} catch (Exception e) {
throw new RuntimeException("While trying to reconnect mqtt client", e);
}
}
reconnect();
return false;
}

private MqttClient getConnectedClient(String deviceId) {
Expand Down Expand Up @@ -721,26 +727,23 @@ private class MqttCallbackHandler implements MqttCallback {

@Override
public void connectionLost(Throwable cause) {
boolean connected = cleanClients(deviceId).isConnected();
warn("MQTT Connection Lost: " + connected + cause);
close();
shutdown();
// Force reconnect to address potential bad states
onError.accept(new ConnectionClosedException());
if (isActive()) {
boolean connected = cleanClients(deviceId).isConnected();
warn(format("MQTT Connection Lost: %s %s", connected, cause));
reconnect();
}
}

@Override
public void deliveryComplete(IMqttDeliveryToken token) {
}

@Override
public void messageArrived(String topic, MqttMessage message) {
synchronized (MqttPublisher.this) {
try {
messageArrivedCore(topic, message);
} catch (Exception e) {
error("While processing message", deviceId, null, "handle", e);
}
public synchronized void messageArrived(String topic, MqttMessage message) {
try {
messageArrivedCore(topic, message);
} catch (Exception e) {
error("While processing message", deviceId, null, "handle", e);
}
}

Expand Down
Loading

0 comments on commit ee7c0cb

Please sign in to comment.