Add configurable timeouts for MQTT connect and publish operations.

Two hidden flags, mqtt-connect-timeout and mqtt-publish-timeout (30s default
each), replace the unbounded token.Wait() calls with token.WaitTimeout().

Review fixes folded into this revision of the patch:
  * handleControlMessage is now given mqtt-connect-timeout as its
    connectTimeout argument; it previously received mqtt-publish-timeout twice.
  * The publish paths stop after logging a timeout instead of falling through
    to the token.Error() check (and, where present, the "published message"
    debug log).

The blob IDs on the index lines were not recomputed for the revised post-image.

diff --git a/cmd/yggd/main.go b/cmd/yggd/main.go
index 904a78e4..f0f192f0 100644
--- a/cmd/yggd/main.go
+++ b/cmd/yggd/main.go
@@ -121,6 +121,18 @@ func main() {
 			Value:  0 * time.Second,
 			Hidden: true,
 		}),
+		altsrc.NewDurationFlag(&cli.DurationFlag{
+			Name:   "mqtt-connect-timeout",
+			Usage:  "Sets the time to wait before giving up to `DURATION` when connecting to an MQTT broker",
+			Value:  30 * time.Second,
+			Hidden: true,
+		}),
+		altsrc.NewDurationFlag(&cli.DurationFlag{
+			Name:   "mqtt-publish-timeout",
+			Usage:  "Sets the time to wait before giving up to `DURATION` when publishing a message to an MQTT broker",
+			Value:  30 * time.Second,
+			Hidden: true,
+		}),
 	}
 
 	// This BeforeFunc will load flag values from a config file only if the
@@ -272,12 +284,21 @@ func main() {
 			log.Tracef("subscribed to topic: %v", topic)
 
 			topic = fmt.Sprintf("%v/%v/control/in", yggdrasil.TopicPrefix, ClientID)
-			client.Subscribe(topic, 1, func(c mqtt.Client, m mqtt.Message) {
-				go handleControlMessage(c, m)
+			client.Subscribe(topic, 1, func(client mqtt.Client, message mqtt.Message) {
+				go handleControlMessage(
+					client,
+					message,
+					c.Duration("mqtt-publish-timeout"), // publishTimeout
+					c.Duration("mqtt-connect-timeout"), // connectTimeout
+				)
 			})
 			log.Tracef("subscribed to topic: %v", topic)
 
-			go publishConnectionStatus(client, d.makeDispatchersMap())
+			go publishConnectionStatus(
+				client,
+				d.makeDispatchersMap(),
+				c.Duration("mqtt-publish-timeout"),
+			)
 		})
 		mqttClientOpts.SetDefaultPublishHandler(func(c mqtt.Client, m mqtt.Message) {
 			log.Errorf("unhandled message: %v", string(m.Payload()))
@@ -324,7 +345,18 @@ func main() {
 		)
 
 		mqttClient := mqtt.NewClient(mqttClientOpts)
-		if token := mqttClient.Connect(); token.Wait() && token.Error() != nil {
+		log.Infof("connecting to broker: %v", c.StringSlice("broker"))
+		token := mqttClient.Connect()
+		if !token.WaitTimeout(c.Duration("mqtt-connect-timeout")) {
+			return cli.Exit(
+				fmt.Errorf(
+					"cannot connect to broker: connection timeout: %v elapsed",
+					c.Duration("mqtt-connect-timeout"),
+				),
+				1,
+			)
+		}
+		if token.Error() != nil {
 			return cli.Exit(fmt.Errorf("cannot connect to broker: %w", token.Error()), 1)
 		}
 
@@ -350,7 +382,11 @@ func main() {
 					}
 				}
 				prevDispatchersHash.Store(sum)
-				go publishConnectionStatus(mqttClient, dispatchers)
+				go publishConnectionStatus(
+					mqttClient,
+					dispatchers,
+					c.Duration("mqtt-publish-timeout"),
+				)
 			}
 		}()
 
@@ -360,7 +396,7 @@ func main() {
 
 		// Start a goroutine that receives yggdrasil.Data values on a 'recv'
 		// channel and publish them to MQTT.
-		go publishReceivedData(mqttClient, d.recvQ)
+		go publishReceivedData(mqttClient, d.recvQ, c.Duration("mqtt-publish-timeout"))
 
 		// Locate and start worker child processes.
 		workerPath := filepath.Join(yggdrasil.LibexecDir, yggdrasil.LongName)
@@ -402,21 +438,25 @@ func main() {
 		// Start a goroutine that watches the tags file for write events and
 		// publishes connection status messages when the file changes.
 		go func() {
-			c := make(chan notify.EventInfo, 1)
+			events := make(chan notify.EventInfo, 1)
 
 			fp := filepath.Join(yggdrasil.SysconfDir, yggdrasil.LongName, "tags.toml")
 
-			if err := notify.Watch(fp, c, notify.InCloseWrite, notify.InDelete); err != nil {
+			if err := notify.Watch(fp, events, notify.InCloseWrite, notify.InDelete); err != nil {
 				log.Infof("cannot start watching '%v': %v", fp, err)
 				return
 			}
-			defer notify.Stop(c)
+			defer notify.Stop(events)
 
-			for e := range c {
+			for e := range events {
 				log.Debugf("received inotify event %v", e.Event())
 				switch e.Event() {
 				case notify.InCloseWrite, notify.InDelete:
-					go publishConnectionStatus(mqttClient, d.makeDispatchersMap())
+					go publishConnectionStatus(
+						mqttClient,
+						d.makeDispatchersMap(),
+						c.Duration("mqtt-publish-timeout"),
+					)
 				}
 			}
 		}()
diff --git a/cmd/yggd/mqtt.go b/cmd/yggd/mqtt.go
index 34204c52..3b0153b3 100644
--- a/cmd/yggd/mqtt.go
+++ b/cmd/yggd/mqtt.go
@@ -27,7 +27,12 @@ func handleDataMessage(client mqtt.Client, msg mqtt.Message, sendQ chan<- yggdra
 	sendQ <- data
 }
 
-func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
+func handleControlMessage(
+	client mqtt.Client,
+	msg mqtt.Message,
+	publishTimeout time.Duration,
+	connectTimeout time.Duration,
+) {
 	log.Debugf("received a message on topic %v", msg.Topic())
 
 	var cmd yggdrasil.Command
@@ -56,7 +61,12 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
 		}
 
 		topic := fmt.Sprintf("%v/%v/control/out", yggdrasil.TopicPrefix, ClientID)
-		if token := client.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+		token := client.Publish(topic, 1, false, data)
+		if !token.WaitTimeout(publishTimeout) {
+			log.Errorf("cannot publish message: connection timeout: %v elapsed", publishTimeout)
+			return // token did not complete in time; skip the error check
+		}
+		if token.Error() != nil {
 			log.Errorf("failed to publish message: %v", token.Error())
 		}
 	case yggdrasil.CommandNameDisconnect:
@@ -72,7 +82,12 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
 		}
 		time.Sleep(time.Duration(delay) * time.Second)
 
-		if token := client.Connect(); token.Wait() && token.Error() != nil {
+		token := client.Connect()
+		if !token.WaitTimeout(connectTimeout) {
+			log.Errorf("cannot reconnect to broker: connection timeout: %v elapsed", connectTimeout)
+			return
+		}
+		if token.Error() != nil {
 			log.Errorf("cannot reconnect to broker: %v", token.Error())
 			return
 		}
@@ -81,7 +96,11 @@ func handleControlMessage(client mqtt.Client, msg mqtt.Message) {
 	}
 }
 
-func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]string) {
+func publishConnectionStatus(
+	c mqtt.Client,
+	dispatchers map[string]map[string]string,
+	timeout time.Duration,
+) {
 	facts, err := yggdrasil.GetCanonicalFacts()
 	if err != nil {
 		log.Errorf("cannot get canonical facts: %v", err)
@@ -125,13 +144,18 @@ func publishConnectionStatus(c mqtt.Client, dispatchers map[string]map[string]st
 
 	topic := fmt.Sprintf("%v/%v/control/out", yggdrasil.TopicPrefix, ClientID)
 
-	if token := c.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+	token := c.Publish(topic, 1, false, data)
+	if !token.WaitTimeout(timeout) {
+		log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
+		return // token did not complete in time; do not log it as published
+	}
+	if token.Error() != nil {
 		log.Errorf("failed to publish message: %v", token.Error())
 	}
 	log.Debugf("published message %v to topic %v", msg.MessageID, topic)
 }
 
-func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
+func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data, timeout time.Duration) {
 	for d := range c {
 		topic := fmt.Sprintf("%v/%v/data/out", yggdrasil.TopicPrefix, ClientID)
 
@@ -141,7 +165,12 @@ func publishReceivedData(client mqtt.Client, c <-chan yggdrasil.Data) {
 			continue
 		}
 
-		if token := client.Publish(topic, 1, false, data); token.Wait() && token.Error() != nil {
+		token := client.Publish(topic, 1, false, data)
+		if !token.WaitTimeout(timeout) {
+			log.Errorf("cannot publish message: connection timeout: %v elapsed", timeout)
+			continue // token did not complete in time; do not log it as published
+		}
+		if token.Error() != nil {
 			log.Errorf("failed to publish message: %v", token.Error())
 		}
 		log.Debugf("published message %v to topic %v", d.MessageID, topic)