Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow connection to multiple address simultaneously #31

Merged
merged 24 commits into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6c8110d
add SharedSubscriptionPredicate
ajatprabha Aug 18, 2023
a7ce755
add MultiConnectionMode option
ajatprabha Aug 21, 2023
cb3c0cf
resubscribe on address updates to shared subscriptions
ajatprabha Aug 21, 2023
55bbc85
update cache keys
ajatprabha Aug 21, 2023
5a31047
refactor code
ajatprabha Aug 21, 2023
594dee5
add go 1.21 to test matrix
ajatprabha Aug 22, 2023
a6da169
add logging interface
ajatprabha Aug 22, 2023
c7ccbf7
filter newer go version support required packages
ajatprabha Aug 22, 2023
bc81f0c
lint code
ajatprabha Aug 22, 2023
daabf62
use wrapped subscriber to avoid unnecessary map writes
ajatprabha Aug 22, 2023
8a507fd
add multierror messages
ajatprabha Aug 22, 2023
78c5de6
handle execOneRandom case for resuming subscriptions
ajatprabha Aug 24, 2023
8df50bf
[WIP] handle execOneRoundRobin case for publishes
ajatprabha Aug 24, 2023
4468cfe
remove go 1.19 from runners
ajatprabha Aug 24, 2023
d74ba43
bump github.com/gojekfarm/xtools/generic version
ajatprabha Aug 29, 2023
4e59b69
add spec for resubscribing to subscriptions in single connection mode
ajatprabha Sep 1, 2023
b811977
log error when reloading multiple clients
ajatprabha Sep 1, 2023
9f89411
refactor code
ajatprabha Sep 1, 2023
1143dcf
use different clientID for different connections
ajatprabha Sep 4, 2023
c03d223
remove MultiConnectionMode option
ajatprabha Sep 5, 2023
165515a
add specs for error scenarios
ajatprabha Sep 5, 2023
490f3c0
take mutex lock while connecting clients too
ajatprabha Sep 11, 2023
df5dca4
add suffix explicitly post checking base value
ajatprabha Sep 12, 2023
962ecbb
log credential fetch errors
ajatprabha Sep 26, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ on:
jobs:
test:
runs-on: ubuntu-latest
container: golang:1.20
container: golang:1.21
strategy:
matrix:
go-version: [1.19.x, 1.20.x]
go-version: [1.20.x, 1.21.x]
services:
mqtt:
image: emqx/emqx:latest
Expand All @@ -34,14 +34,14 @@ jobs:
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
${{ runner.os }}-go-${{ matrix.go-version }}-
- name: Tools bin cache
uses: actions/cache@v3
with:
path: .bin
key: ${{ runner.os }}-${{ hashFiles('Makefile') }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('Makefile') }}
- name: Install jq
uses: dcarbone/[email protected]
- name: Test
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | s
PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST))))
LOCAL_GO_BIN_DIR := $(PROJECT_DIR)/.bin
BIN_DIR := $(if $(LOCAL_GO_BIN_DIR),$(LOCAL_GO_BIN_DIR),$(GOPATH)/bin)
GO_MINOR_VERSION := $(shell go version | cut -d' ' -f3 | cut -d'.' -f2)
GO_BUILD_DIRS := $(foreach dir,$(ALL_GO_MOD_DIRS),$(shell GO_MOD_VERSION=$$(grep "go 1.[0-9]*" $(dir)/go.mod | cut -d' ' -f2 | cut -d'.' -f2) && [ -n "$$GO_MOD_VERSION" ] && [ $(GO_MINOR_VERSION) -ge $$GO_MOD_VERSION ] && echo $(dir)))

fmt:
@$(call run-go-mod-dir,go vet ./...,"go fmt")
Expand Down Expand Up @@ -89,7 +91,7 @@ endef
# a go.mod file
define run-go-mod-dir
set -e; \
for dir in $(ALL_GO_MOD_DIRS); do \
for dir in $(GO_BUILD_DIRS); do \
[ -z $(2) ] || echo "$(2) $${dir}/..."; \
cd "$(PROJECT_DIR)/$${dir}" && $(1); \
done;
Expand Down
109 changes: 63 additions & 46 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"os"
"sync"
"sync/atomic"
Expand All @@ -19,8 +20,11 @@ var newClientFunc = defaultNewClientFunc()

// Client allows to communicate with an MQTT broker
type Client struct {
options *clientOptions
mqttClient mqtt.Client
options *clientOptions

subscriptions map[string]*subscriptionMeta
mqttClient mqtt.Client
mqttClients map[string]mqtt.Client

publisher Publisher
subscriber Subscriber
Expand All @@ -29,7 +33,10 @@ type Client struct {
sMiddlewares []subscribeMiddleware
usMiddlewares []unsubscribeMiddleware

mu sync.RWMutex
rrCounter *atomicCounter
rndPool *sync.Pool
clientMu sync.RWMutex
subMu sync.RWMutex
}

// NewClient creates the Client struct with the clientOptions provided,
Expand All @@ -46,10 +53,17 @@ func NewClient(opts ...ClientOption) (*Client, error) {
return nil, fmt.Errorf("at least WithAddress or WithResolver ClientOption should be used")
}

c := &Client{options: co}
c := &Client{
options: co,
subscriptions: map[string]*subscriptionMeta{},
rrCounter: &atomicCounter{value: 0},
rndPool: &sync.Pool{New: func() any {
return rand.New(rand.NewSource(time.Now().UnixNano()))
}},
}

if len(co.brokerAddress) != 0 {
c.mqttClient = newClientFunc.Load().(func(*mqtt.ClientOptions) mqtt.Client)(toClientOptions(c, c.options))
c.mqttClient = newClientFunc.Load().(func(*mqtt.ClientOptions) mqtt.Client)(toClientOptions(c, c.options, ""))
}

c.publisher = publishHandler(c)
Expand All @@ -61,13 +75,15 @@ func NewClient(opts ...ClientOption) (*Client, error) {

// IsConnected checks whether the client is connected to the broker
func (c *Client) IsConnected() bool {
var online bool
val := &atomic.Bool{}

err := c.execute(func(cc mqtt.Client) {
online = cc.IsConnectionOpen()
})
return c.execute(func(cc mqtt.Client) error {
if cc.IsConnectionOpen() {
val.CompareAndSwap(false, true)
}

return err == nil && online
return nil
}, execAll) == nil && val.Load()
}

// Start will attempt to connect to the broker.
Expand Down Expand Up @@ -105,22 +121,11 @@ func (c *Client) Run(ctx context.Context) error {
}

func (c *Client) stop() error {
return c.execute(func(cc mqtt.Client) {
return c.execute(func(cc mqtt.Client) error {
cc.Disconnect(uint(c.options.gracefulShutdownPeriod / time.Millisecond))
})
}

func (c *Client) execute(f func(mqtt.Client)) error {
c.mu.RLock()
defer c.mu.RUnlock()

if c.mqttClient == nil {
return ErrClientNotInitialized
}

f(c.mqttClient)

return nil
return nil
}, execAll)
}

func (c *Client) handleToken(ctx context.Context, t mqtt.Token, timeoutErr error) error {
Expand Down Expand Up @@ -158,42 +163,45 @@ func (c *Client) runResolver() error {
case <-time.After(c.options.connectTimeout):
return ErrConnectTimeout
case addrs := <-c.options.resolver.UpdateChan():
c.attemptConnection(addrs)
if err := c.attemptConnections(addrs); err != nil {
return err
}
}

go c.watchAddressUpdates(c.options.resolver)

return nil
}

func (c *Client) runConnect() (err error) {
func (c *Client) runConnect() error {
if len(c.options.brokerAddress) == 0 {
return nil
}

if e := c.execute(func(cc mqtt.Client) {
return c.execute(func(cc mqtt.Client) error {
t := cc.Connect()
if !t.WaitTimeout(c.options.connectTimeout) {
err = ErrConnectTimeout

return
return ErrConnectTimeout
}

err = t.Error()
}); e != nil {
err = e
}
return t.Error()
}, execAll)
}

return
func (c *Client) attemptSingleConnection(addrs []TCPAddress) error {
cc := c.newClient(addrs, 0)
c.reloadClient(cc)

return c.resumeSubscriptions()
}

func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {
func toClientOptions(c *Client, o *clientOptions, idSuffix string) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions()

if hostname, err := os.Hostname(); o.clientID == "" && err == nil {
opts.SetClientID(hostname)
opts.SetClientID(fmt.Sprintf("%s%s", hostname, idSuffix))
} else {
opts.SetClientID(o.clientID)
opts.SetClientID(fmt.Sprintf("%s%s", o.clientID, idSuffix))
}

setCredentials(o, opts)
Expand All @@ -215,21 +223,30 @@ func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {

func setCredentials(o *clientOptions, opts *mqtt.ClientOptions) {
if o.credentialFetcher != nil {
ctx, cancel := context.WithTimeout(context.Background(), o.credentialFetchTimeout)
defer cancel()

if c, err := o.credentialFetcher.Credentials(ctx); err == nil {
opts.SetUsername(c.Username)
opts.SetPassword(c.Password)
refreshCredentialsWithFetcher(o, opts)

return
}
return
}

opts.SetUsername(o.username)
opts.SetPassword(o.password)
}

func refreshCredentialsWithFetcher(o *clientOptions, opts *mqtt.ClientOptions) {
ctx, cancel := context.WithTimeout(context.Background(), o.credentialFetchTimeout)
defer cancel()

c, err := o.credentialFetcher.Credentials(ctx)
if err != nil {
o.logger.Error(ctx, err, map[string]any{"message": "failed to fetch credentials"})

return
}

opts.SetUsername(c.Username)
opts.SetPassword(c.Password)
}

func formatAddressWithProtocol(opts *clientOptions) string {
if opts.tlsConfig != nil {
return fmt.Sprintf("tls://%s", opts.brokerAddress)
Expand Down
55 changes: 40 additions & 15 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package courier
import (
"crypto/tls"
"fmt"
"strings"
"time"
)

var inMemoryPersistence = NewMemoryStore()

func defaultSharedSubscriptionPredicate(topic string) bool {
return strings.HasPrefix(topic, "$share/")
}

// ClientOption allows to configure the behaviour of a Client.
type ClientOption interface{ apply(*clientOptions) }

Expand Down Expand Up @@ -194,6 +199,22 @@ func WithExponentialStartOptions(options ...StartOption) ClientOption {
})
}

// SharedSubscriptionPredicate allows to configure the predicate function that determines
// whether a topic is a shared subscription topic.
type SharedSubscriptionPredicate func(topic string) bool

func (ssp SharedSubscriptionPredicate) apply(o *clientOptions) { o.sharedSubscriptionPredicate = ssp }

// UseMultiConnectionMode allows to configure the client to use multiple connections when available.
//
// This is useful when working with shared subscriptions and multiple connections can be created
// to subscribe on the same application.
var UseMultiConnectionMode = multiConnMode{}

type multiConnMode struct{}
ajatprabha marked this conversation as resolved.
Show resolved Hide resolved

func (mcm multiConnMode) apply(o *clientOptions) { o.multiConnectionMode = true }

type clientOptions struct {
username, clientID, password,
brokerAddress string
Expand All @@ -202,17 +223,19 @@ type clientOptions struct {

tlsConfig *tls.Config

autoReconnect, maintainOrder, cleanSession bool
autoReconnect, maintainOrder, cleanSession, multiConnectionMode bool

connectTimeout, writeTimeout, keepAlive,
maxReconnectInterval, gracefulShutdownPeriod,
credentialFetchTimeout time.Duration

startOptions *startOptions

onConnectHandler OnConnectHandler
onConnectionLostHandler OnConnectionLostHandler
onReconnectHandler OnReconnectHandler
onConnectHandler OnConnectHandler
onConnectionLostHandler OnConnectionLostHandler
onReconnectHandler OnReconnectHandler
sharedSubscriptionPredicate SharedSubscriptionPredicate
logger Logger

newEncoder EncoderFunc
newDecoder DecoderFunc
Expand All @@ -225,16 +248,18 @@ func (f optionFunc) apply(o *clientOptions) { f(o) }

func defaultClientOptions() *clientOptions {
return &clientOptions{
autoReconnect: true,
maintainOrder: true,
connectTimeout: 15 * time.Second,
writeTimeout: 10 * time.Second,
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
autoReconnect: true,
maintainOrder: true,
connectTimeout: 15 * time.Second,
writeTimeout: 10 * time.Second,
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
sharedSubscriptionPredicate: defaultSharedSubscriptionPredicate,
logger: defaultLogger,
}
}
Loading