Skip to content

Commit

Permalink
Allow connection to multiple address simultaneously (#31)
Browse files Browse the repository at this point in the history
* add SharedSubscriptionPredicate

* add MultiConnectionMode option

* resubscribe on address updates to shared subscriptions

* add go 1.21 to test matrix

* add logging interface

* use wrapped subscriber to avoid unnecessary map writes

* add multierror messages

* handle execOneRandom case for resuming subscriptions

* handle execOneRoundRobin case for publishes

* remove go 1.19 from runners

* log error when reloading multiple clients

* use different clientID for different connections

* take mutex lock while connecting clients too

* add suffix explicitly post checking base value

* log credential fetch errors
  • Loading branch information
ajatprabha authored Oct 1, 2023
1 parent 5f1f085 commit 56a982e
Show file tree
Hide file tree
Showing 30 changed files with 1,678 additions and 279 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ on:
jobs:
test:
runs-on: ubuntu-latest
container: golang:1.20
container: golang:1.21
strategy:
matrix:
go-version: [1.19.x, 1.20.x]
go-version: [1.20.x, 1.21.x]
services:
mqtt:
image: emqx/emqx:latest
Expand All @@ -34,14 +34,14 @@ jobs:
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
${{ runner.os }}-go-${{ matrix.go-version }}-
- name: Tools bin cache
uses: actions/cache@v3
with:
path: .bin
key: ${{ runner.os }}-${{ hashFiles('Makefile') }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('Makefile') }}
- name: Install jq
uses: dcarbone/[email protected]
- name: Test
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | s
PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST))))
LOCAL_GO_BIN_DIR := $(PROJECT_DIR)/.bin
BIN_DIR := $(if $(LOCAL_GO_BIN_DIR),$(LOCAL_GO_BIN_DIR),$(GOPATH)/bin)
GO_MINOR_VERSION := $(shell go version | cut -d' ' -f3 | cut -d'.' -f2)
GO_BUILD_DIRS := $(foreach dir,$(ALL_GO_MOD_DIRS),$(shell GO_MOD_VERSION=$$(grep "go 1.[0-9]*" $(dir)/go.mod | cut -d' ' -f2 | cut -d'.' -f2) && [ -n "$$GO_MOD_VERSION" ] && [ $(GO_MINOR_VERSION) -ge $$GO_MOD_VERSION ] && echo $(dir)))

fmt:
@$(call run-go-mod-dir,go vet ./...,"go fmt")
Expand Down Expand Up @@ -89,7 +91,7 @@ endef
# a go.mod file
define run-go-mod-dir
set -e; \
for dir in $(ALL_GO_MOD_DIRS); do \
for dir in $(GO_BUILD_DIRS); do \
[ -z $(2) ] || echo "$(2) $${dir}/..."; \
cd "$(PROJECT_DIR)/$${dir}" && $(1); \
done;
Expand Down
109 changes: 63 additions & 46 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"os"
"sync"
"sync/atomic"
Expand All @@ -19,8 +20,11 @@ var newClientFunc = defaultNewClientFunc()

// Client allows to communicate with an MQTT broker
type Client struct {
options *clientOptions
mqttClient mqtt.Client
options *clientOptions

subscriptions map[string]*subscriptionMeta
mqttClient mqtt.Client
mqttClients map[string]mqtt.Client

publisher Publisher
subscriber Subscriber
Expand All @@ -29,7 +33,10 @@ type Client struct {
sMiddlewares []subscribeMiddleware
usMiddlewares []unsubscribeMiddleware

mu sync.RWMutex
rrCounter *atomicCounter
rndPool *sync.Pool
clientMu sync.RWMutex
subMu sync.RWMutex
}

// NewClient creates the Client struct with the clientOptions provided,
Expand All @@ -46,10 +53,17 @@ func NewClient(opts ...ClientOption) (*Client, error) {
return nil, fmt.Errorf("at least WithAddress or WithResolver ClientOption should be used")
}

c := &Client{options: co}
c := &Client{
options: co,
subscriptions: map[string]*subscriptionMeta{},
rrCounter: &atomicCounter{value: 0},
rndPool: &sync.Pool{New: func() any {
return rand.New(rand.NewSource(time.Now().UnixNano()))
}},
}

if len(co.brokerAddress) != 0 {
c.mqttClient = newClientFunc.Load().(func(*mqtt.ClientOptions) mqtt.Client)(toClientOptions(c, c.options))
c.mqttClient = newClientFunc.Load().(func(*mqtt.ClientOptions) mqtt.Client)(toClientOptions(c, c.options, ""))
}

c.publisher = publishHandler(c)
Expand All @@ -61,13 +75,15 @@ func NewClient(opts ...ClientOption) (*Client, error) {

// IsConnected checks whether the client is connected to the broker
func (c *Client) IsConnected() bool {
var online bool
val := &atomic.Bool{}

err := c.execute(func(cc mqtt.Client) {
online = cc.IsConnectionOpen()
})
return c.execute(func(cc mqtt.Client) error {
if cc.IsConnectionOpen() {
val.CompareAndSwap(false, true)
}

return err == nil && online
return nil
}, execAll) == nil && val.Load()
}

// Start will attempt to connect to the broker.
Expand Down Expand Up @@ -105,22 +121,11 @@ func (c *Client) Run(ctx context.Context) error {
}

func (c *Client) stop() error {
return c.execute(func(cc mqtt.Client) {
return c.execute(func(cc mqtt.Client) error {
cc.Disconnect(uint(c.options.gracefulShutdownPeriod / time.Millisecond))
})
}

func (c *Client) execute(f func(mqtt.Client)) error {
c.mu.RLock()
defer c.mu.RUnlock()

if c.mqttClient == nil {
return ErrClientNotInitialized
}

f(c.mqttClient)

return nil
return nil
}, execAll)
}

func (c *Client) handleToken(ctx context.Context, t mqtt.Token, timeoutErr error) error {
Expand Down Expand Up @@ -158,42 +163,45 @@ func (c *Client) runResolver() error {
case <-time.After(c.options.connectTimeout):
return ErrConnectTimeout
case addrs := <-c.options.resolver.UpdateChan():
c.attemptConnection(addrs)
if err := c.attemptConnections(addrs); err != nil {
return err
}
}

go c.watchAddressUpdates(c.options.resolver)

return nil
}

func (c *Client) runConnect() (err error) {
func (c *Client) runConnect() error {
if len(c.options.brokerAddress) == 0 {
return nil
}

if e := c.execute(func(cc mqtt.Client) {
return c.execute(func(cc mqtt.Client) error {
t := cc.Connect()
if !t.WaitTimeout(c.options.connectTimeout) {
err = ErrConnectTimeout

return
return ErrConnectTimeout
}

err = t.Error()
}); e != nil {
err = e
}
return t.Error()
}, execAll)
}

return
func (c *Client) attemptSingleConnection(addrs []TCPAddress) error {
cc := c.newClient(addrs, 0)
c.reloadClient(cc)

return c.resumeSubscriptions()
}

func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {
func toClientOptions(c *Client, o *clientOptions, idSuffix string) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions()

if hostname, err := os.Hostname(); o.clientID == "" && err == nil {
opts.SetClientID(hostname)
opts.SetClientID(fmt.Sprintf("%s%s", hostname, idSuffix))
} else {
opts.SetClientID(o.clientID)
opts.SetClientID(fmt.Sprintf("%s%s", o.clientID, idSuffix))
}

setCredentials(o, opts)
Expand All @@ -215,21 +223,30 @@ func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {

func setCredentials(o *clientOptions, opts *mqtt.ClientOptions) {
if o.credentialFetcher != nil {
ctx, cancel := context.WithTimeout(context.Background(), o.credentialFetchTimeout)
defer cancel()

if c, err := o.credentialFetcher.Credentials(ctx); err == nil {
opts.SetUsername(c.Username)
opts.SetPassword(c.Password)
refreshCredentialsWithFetcher(o, opts)

return
}
return
}

opts.SetUsername(o.username)
opts.SetPassword(o.password)
}

func refreshCredentialsWithFetcher(o *clientOptions, opts *mqtt.ClientOptions) {
ctx, cancel := context.WithTimeout(context.Background(), o.credentialFetchTimeout)
defer cancel()

c, err := o.credentialFetcher.Credentials(ctx)
if err != nil {
o.logger.Error(ctx, err, map[string]any{"message": "failed to fetch credentials"})

return
}

opts.SetUsername(c.Username)
opts.SetPassword(c.Password)
}

func formatAddressWithProtocol(opts *clientOptions) string {
if opts.tlsConfig != nil {
return fmt.Sprintf("tls://%s", opts.brokerAddress)
Expand Down
55 changes: 40 additions & 15 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package courier
import (
"crypto/tls"
"fmt"
"strings"
"time"
)

var inMemoryPersistence = NewMemoryStore()

func defaultSharedSubscriptionPredicate(topic string) bool {
return strings.HasPrefix(topic, "$share/")
}

// ClientOption allows to configure the behaviour of a Client.
type ClientOption interface{ apply(*clientOptions) }

Expand Down Expand Up @@ -194,6 +199,22 @@ func WithExponentialStartOptions(options ...StartOption) ClientOption {
})
}

// SharedSubscriptionPredicate allows to configure the predicate function that determines
// whether a topic is a shared subscription topic.
type SharedSubscriptionPredicate func(topic string) bool

func (ssp SharedSubscriptionPredicate) apply(o *clientOptions) { o.sharedSubscriptionPredicate = ssp }

// UseMultiConnectionMode allows to configure the client to use multiple connections when available.
//
// This is useful when working with shared subscriptions and multiple connections can be created
// to subscribe on the same application.
var UseMultiConnectionMode = multiConnMode{}

type multiConnMode struct{}

func (mcm multiConnMode) apply(o *clientOptions) { o.multiConnectionMode = true }

type clientOptions struct {
username, clientID, password,
brokerAddress string
Expand All @@ -202,17 +223,19 @@ type clientOptions struct {

tlsConfig *tls.Config

autoReconnect, maintainOrder, cleanSession bool
autoReconnect, maintainOrder, cleanSession, multiConnectionMode bool

connectTimeout, writeTimeout, keepAlive,
maxReconnectInterval, gracefulShutdownPeriod,
credentialFetchTimeout time.Duration

startOptions *startOptions

onConnectHandler OnConnectHandler
onConnectionLostHandler OnConnectionLostHandler
onReconnectHandler OnReconnectHandler
onConnectHandler OnConnectHandler
onConnectionLostHandler OnConnectionLostHandler
onReconnectHandler OnReconnectHandler
sharedSubscriptionPredicate SharedSubscriptionPredicate
logger Logger

newEncoder EncoderFunc
newDecoder DecoderFunc
Expand All @@ -225,16 +248,18 @@ func (f optionFunc) apply(o *clientOptions) { f(o) }

func defaultClientOptions() *clientOptions {
return &clientOptions{
autoReconnect: true,
maintainOrder: true,
connectTimeout: 15 * time.Second,
writeTimeout: 10 * time.Second,
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
autoReconnect: true,
maintainOrder: true,
connectTimeout: 15 * time.Second,
writeTimeout: 10 * time.Second,
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
sharedSubscriptionPredicate: defaultSharedSubscriptionPredicate,
logger: defaultLogger,
}
}
Loading

0 comments on commit 56a982e

Please sign in to comment.