From 17504eec9715d5ef4f33a29707e07701853aa2ad Mon Sep 17 00:00:00 2001
From: Link Dupont
Date: Wed, 6 Dec 2023 11:25:58 -0500
Subject: [PATCH] feat(mqtt): Use Token.WaitTimeout
When connecting to a broker and publishing a message, use the
`Token.WaitTimeout` method instead of `Token.Wait`. `Token.Wait` waits
indefinitely, which can lead to situations when the cleint never
succeeds in connecting or publishing.
The timeout for each operation can be configured independently by
setting `mqtt-connect-timeout` and `mqtt-publish-timeout`. Both values
default to 30 seconds. The flags are hidden, as they should not commonly
be required to be changed by users.
Signed-off-by: Link Dupont
---
cmd/yggd/main.go | 62 +++++++++++++++++++++++++++++++++++++++---------
cmd/yggd/mqtt.go | 40 +++++++++++++++++++++++++------
2 files changed, 84 insertions(+), 18 deletions(-)
diff --git a/cmd/yggd/main.go b/cmd/yggd/main.go
index 904a78e4..f0f192f0 100644
--- a/cmd/yggd/main.go
+++ b/cmd/yggd/main.go
@@ -121,6 +121,18 @@ func main() {
Value: 0 * time.Second,
Hidden: true,
}),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "mqtt-connect-timeout",
+ Usage: "Sets the time to wait before giving up to `DURATION` when connecting to an MQTT broker",
+ Value: 30 * time.Second,
+ Hidden: true,
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "mqtt-publish-timeout",
+ Usage: "Sets the time to wait before giving up to `DURATION` when publishing a message to an MQTT broker",
+ Value: 30 * time.Second,
+ Hidden: true,
+ }),
}
// This BeforeFunc will load flag values from a config file only if the
@@ -272,12 +284,21 @@ func main() {
log.Tracef("subscribed to topic: %v", topic)
topic = fmt.Sprintf("%v/%v/control/in", yggdrasil.TopicPrefix, ClientID)
- client.Subscribe(topic, 1, func(c mqtt.Client, m mqtt.Message) {
- go handleControlMessage(c, m)
+ client.Subscribe(topic, 1, func(client mqtt.Client, message mqtt.Message) {
+ go handleControlMessage(
+ client,
+ message,
+ c.Duration("mqtt-publish-timeout"),
+ c.Duration("mqtt-publish-timeout"),
+ )
})
log.Tracef("subscribed to topic: %v", topic)
- go publishConnectionStatus(client, d.makeDispatchersMap())
+ go publishConnectionStatus(
+ client,
+ d.makeDispatchersMap(),
+ c.Duration("mqtt-publish-timeout"),
+ )
})
mqttClientOpts.SetDefaultPublishHandler(func(c mqtt.Client, m mqtt.Message) {
log.Errorf("unhandled message: %v", string(m.Payload()))
@@ -324,7 +345,18 @@ func main() {
)
mqttClient := mqtt.NewClient(mqttClientOpts)
- if token := mqttClient.Connect(); token.Wait() && token.Error() != nil {
+ log.Infof("connecting to broker: %v", c.StringSlice("broker"))
+ token := mqttClient.Connect()
+ if !token.WaitTimeout(c.Duration("mqtt-connect-timeout")) {
+ return cli.Exit(
+ fmt.Errorf(
+ "cannot connect to broker: connection timeout: %v elapsed",
+ c.Duration("mqtt-connect-timeout"),
+ ),
+ 1,
+ )
+ }
+ if token.Error() != nil {
return cli.Exit(fmt.Errorf("cannot connect to broker: %w", token.Error()), 1)
}
@@ -350,7 +382,11 @@ func main() {
}
}
prevDispatchersHash.Store(sum)
- go publishConnectionStatus(mqttClient, dispatchers)
+ go publishConnectionStatus(
+ mqttClient,
+ dispatchers,
+ c.Duration("mqtt-publish-timeout"),
+ )
}
}()
@@ -360,7 +396,7 @@ func main() {
// Start a goroutine that receives yggdrasil.Data values on a 'recv'
// channel and publish them to MQTT.
- go publishReceivedData(mqttClient, d.recvQ)
+ go publishReceivedData(mqttClient, d.recvQ, c.Duration("mqtt-publish-timeout"))
// Locate and start worker child processes.
workerPath := filepath.Join(yggdrasil.LibexecDir, yggdrasil.LongName)
@@ -402,21 +438,25 @@ func main() {
// Start a goroutine that watches the tags file for write events and
// publishes connection status messages when the file changes.
go func() {
- c := make(chan notify.EventInfo, 1)
+ events := make(chan notify.EventInfo, 1)
fp := filepath.Join(yggdrasil.SysconfDir, yggdrasil.LongName, "tags.toml")
- if err := notify.Watch(fp, c, notify.InCloseWrite, notify.InDelete); err != nil {
+ if err := notify.Watch(fp, events, notify.InCloseWrite, notify.InDelete); err != nil {
log.Infof("cannot start watching '%v': %v", fp, err)
return
}
- defer notify.Stop(c)
+ defer notify.Stop(events)
- for e := range c {
+ for e := range events {
log.Debugf("received inotify event %v", e.Event())
switch e.Event() {
case notify.InCloseWrite, notify.InDelete:
- go publishConnectionStatus(mqttClient, d.makeDispatchersMap())
+ go publishConnectionStatus(
+ mqttClient,
+ d.makeDispatchersMap(),
+ c.Duration("mqtt-publish-timeout"),
+ )
}
}
}()
diff --git a/cmd/yggd/mqtt.go b/cmd/yggd/mqtt.go
index 34204c52..3b0153b3 100644
--- a/cmd/yggd/mqtt.go
+++ b/cmd/yggd/mqtt.go
@@ -27,7 +27,12 @@ func handleDataMessage(client mqtt.Client, msg mqtt.Message, sendQ chan<- yggdra
sendQ <- data
}
-func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
+func handleControlMessage(
+ client mqtt.Client,
+ msg mqtt.Message,
+ publishTimeout time.Duration,
+ connectTimeout time.Duration,
+) {
log.Debugf("received a message on topic %v", msg.Topic())
var cmd yggdrasil.Command
@@ -56,7 +61,11 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
}
topic := fmt.Sprintf("%v/%v/control/out", yggdrasil.TopicPrefix, ClientID)
- if token := client.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+ token := client.Publish(topic, 1, false, data)
+ if !token.WaitTimeout(publishTimeout) {
+ log.Errorf("cannot publish message: connection timeout: %v elapsed", publishTimeout)
+ }
+ if token.Error() != nil {
log.Errorf("failed to publish message: %v", token.Error())
}
case yggdrasil.CommandNameDisconnect:
@@ -72,7 +81,12 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
}
time.Sleep(time.Duration(delay) * time.Second)
- if token := client.Connect(); token.Wait() && token.Error() != nil {
+ token := client.Connect()
+ if !token.WaitTimeout(connectTimeout) {
+ log.Errorf("cannot reconnect to broker: connection timeout: %v elapsed", connectTimeout)
+ return
+ }
+ if token.Error() != nil {
log.Errorf("cannot reconnect to broker: %v", token.Error())
return
}
@@ -81,7 +95,11 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
}
}
-func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]string) {
+func publishConnectionStatus(
+ c mqtt.Client,
+ dispatchers map[string]map[string]string,
+ timeout time.Duration,
+) {
facts, err := yggdrasil.GetCanonicalFacts()
if err != nil {
log.Errorf("cannot get canonical facts: %v", err)
@@ -125,13 +143,17 @@ func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]st
topic := fmt.Sprintf("%v/%v/control/out", yggdrasil.TopicPrefix, ClientID)
- if token := c.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+ token := c.Publish(topic, 1, false, data)
+ if !token.WaitTimeout(timeout) {
+ log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
+ }
+ if token.Error() != nil {
log.Errorf("failed to publish message: %v", token.Error())
}
log.Debugf("published message %v to topic %v", msg.MessageID, topic)
}
-func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
+func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data, timeout time.Duration) {
for d := range c {
topic := fmt.Sprintf("%v/%v/data/out", yggdrasil.TopicPrefix, ClientID)
@@ -141,7 +163,11 @@ func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
continue
}
- if token := client.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+ token := client.Publish(topic, 1, false, data)
+ if !token.WaitTimeout(timeout) {
+ log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
+ }
+ if token.Error() != nil {
log.Errorf("failed to publish message: %v", token.Error())
}
log.Debugf("published message %v to topic %v", d.MessageID, topic)