Skip to content

Commit

Permalink
Merge pull request #57 from fracasula/manual-acks
Browse files Browse the repository at this point in the history
Manual acknowledgments
  • Loading branch information
Al S-M authored Jul 2, 2021
2 parents 59f0c54 + 8e1297f commit 1aa48ad
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 34 deletions.
2 changes: 1 addition & 1 deletion packets/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (p *Publish) Buffers() net.Buffers {
var b bytes.Buffer
writeString(p.Topic, &b)
if p.QoS > 0 {
writeUint16(p.PacketID, &b)
_ = writeUint16(p.PacketID, &b)
}
idvp := p.Properties.Pack(PUBLISH)
encodeVBIdirect(len(idvp), &b)
Expand Down
79 changes: 79 additions & 0 deletions paho/acks_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package paho

import (
"errors"
"sync"

"github.com/eclipse/paho.golang/packets"
)

var (
ErrPacketNotFound = errors.New("packet not found")
)

type acksTracker struct {
mx sync.Mutex
order []packet
}

func (t *acksTracker) add(pb *packets.Publish) {
t.mx.Lock()
defer t.mx.Unlock()

for _, v := range t.order {
if v.pb.PacketID == pb.PacketID {
return // already added
}
}

t.order = append(t.order, packet{pb: pb})
}

func (t *acksTracker) markAsAcked(pb *packets.Publish) error {
t.mx.Lock()
defer t.mx.Unlock()

for k, v := range t.order {
if pb.PacketID == v.pb.PacketID {
t.order[k].acknowledged = true
return nil
}
}

return ErrPacketNotFound
}

func (t *acksTracker) flush(do func([]*packets.Publish)) {
t.mx.Lock()
defer t.mx.Unlock()

var (
buf []*packets.Publish
)
for _, v := range t.order {
if v.acknowledged {
buf = append(buf, v.pb)
} else {
break
}
}

if len(buf) == 0 {
return
}

do(buf)
t.order = t.order[len(buf):]
}

// reset should be used upon disconnections
func (t *acksTracker) reset() {
t.mx.Lock()
defer t.mx.Unlock()
t.order = nil
}

type packet struct {
pb *packets.Publish
acknowledged bool
}
85 changes: 85 additions & 0 deletions paho/acks_tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package paho

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/eclipse/paho.golang/packets"
)

func TestAcksTracker(t *testing.T) {
var (
at acksTracker
p1 = &packets.Publish{PacketID: 1}
p2 = &packets.Publish{PacketID: 2}
p3 = &packets.Publish{PacketID: 3}
p4 = &packets.Publish{PacketID: 4} // to test not found
)

t.Run("flush-empty", func(t *testing.T) {
at.flush(func(_ []*packets.Publish) {
t.Fatal("flush should not call 'do' since no packets have been added nor acknowledged")
})
})

t.Run("flush-without-acking", func(t *testing.T) {
at.add(p1)
at.add(p2)
at.add(p3)
require.Equal(t, ErrPacketNotFound, at.markAsAcked(p4))
at.flush(func(_ []*packets.Publish) {
t.Fatal("flush should not call 'do' since no packets have been acknowledged so far")
})
})

t.Run("ack-in-the-middle", func(t *testing.T) {
require.NoError(t, at.markAsAcked(p3))
at.flush(func(_ []*packets.Publish) {
t.Fatal("flush should not call 'do' since p1 and p2 have not been acknowledged yet")
})
})

t.Run("idempotent-acking", func(t *testing.T) {
require.NoError(t, at.markAsAcked(p3))
require.NoError(t, at.markAsAcked(p3))
require.NoError(t, at.markAsAcked(p3))
})

t.Run("ack-first", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p1))
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p1}, pbs, "Only p1 expected even though p3 was acked, p2 is still missing")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("ack-after-flush", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p2))
at.add(p4) // this should just be appended and not flushed (yet)
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p2, p3}, pbs, "Only p2 and p3 expected, p1 was flushed in the previous call")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("ack-last", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p4))
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p4}, pbs, "Only p4 expected, the rest was flushed in previous calls")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("flush-after-acking-everything", func(t *testing.T) {
at.flush(func(_ []*packets.Publish) {
t.Fatal("no call to 'do' expected, we flushed all packets already")
})
})
}
139 changes: 106 additions & 33 deletions paho/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package paho

import (
"context"
"errors"
"fmt"
"math"
"net"
Expand All @@ -20,30 +21,46 @@ const (
MQTTv5 MQTTVersion = 5
)

const defaultSendAckInterval = 50 * time.Millisecond

var (
ErrManualAcknowledgmentDisabled = errors.New("manual acknowledgments disabled")
)

type (
// ClientConfig are the user configurable options for the client, an
// instance of this struct is passed into NewClient(), not all options
// are required to be set, defaults are provided for Persistence, MIDs,
// PingHandler, PacketTimeout and Router.
ClientConfig struct {
ClientID string
Conn net.Conn
MIDs MIDService
AuthHandler Auther
PingHandler Pinger
Router Router
Persistence Persistence
PacketTimeout time.Duration
ClientID string
Conn net.Conn
MIDs MIDService
AuthHandler Auther
PingHandler Pinger
Router Router
Persistence Persistence
PacketTimeout time.Duration
// OnServerDisconnect is called only when a packets.DISCONNECT is received from server
OnServerDisconnect func(*Disconnect)
// Only called when receiving packets.DISCONNECT from server
// OnClientError is for example called on net.Error
OnClientError func(error)
// Client error call, For example: net.Error
PublishHook func(*Publish)
// PublishHook allows a user provided function to be called before
// a Publish packet is sent allowing it to inspect or modify the
// Publish, an example of the utility of this is provided in the
// Topic Alias Handler extension which will automatically assign
// and use topic alias values rather than topic strings.
PublishHook func(*Publish)
// EnableManualAcknowledgment is used to control the acknowledgment of packets manually.
// BEWARE that the MQTT specs require clients to send acknowledgments in the order in which the corresponding
// PUBLISH packets were received.
// Consider the following scenario: the client receives packets 1,2,3,4
// If you acknowledge 3 first, no ack is actually sent to the server but it's buffered until also 1 and 2
// are acknowledged.
EnableManualAcknowledgment bool
// SendAcksInterval is used only when EnableManualAcknowledgment is true
// it determines how often the client tries to send a batch of acknowledgments in the right order to the server.
SendAcksInterval time.Duration
}
// Client is the struct representing an MQTT client
Client struct {
Expand All @@ -53,6 +70,7 @@ type (
raCtx *CPContext
stop chan struct{}
publishPackets chan *packets.Publish
acksTracker acksTracker
workers sync.WaitGroup
serverProps CommsProperties
clientProps CommsProperties
Expand Down Expand Up @@ -283,9 +301,73 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
c.incoming()
}()

if c.EnableManualAcknowledgment {
c.debug.Println("starting acking routine")

c.acksTracker.reset()
sendAcksInterval := defaultSendAckInterval
if c.SendAcksInterval > 0 {
sendAcksInterval = c.SendAcksInterval
}

c.workers.Add(1)
go func() {
defer c.workers.Done()
defer c.debug.Println("returning from ack tracker routine")
t := time.NewTicker(sendAcksInterval)
for {
select {
case <-c.stop:
return
case <-t.C:
c.acksTracker.flush(func(pbs []*packets.Publish) {
for _, pb := range pbs {
c.ack(pb)
}
})
}
}
}()
}

return ca, nil
}

func (c *Client) Ack(pb *Publish) error {
if !c.EnableManualAcknowledgment {
return ErrManualAcknowledgmentDisabled
}
if pb.QoS == 0 {
return nil
}
return c.acksTracker.markAsAcked(pb.Packet())
}

func (c *Client) ack(pb *packets.Publish) {
switch pb.QoS {
case 1:
pa := packets.Puback{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Println("sending PUBACK")
_, err := pa.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err)
}
case 2:
pr := packets.Pubrec{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Printf("sending PUBREC")
_, err := pr.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err)
}
}
}

func (c *Client) routePublishPackets() {
for {
select {
Expand All @@ -295,29 +377,18 @@ func (c *Client) routePublishPackets() {
if !open {
return
}
c.Router.Route(pb)
switch pb.QoS {
case 1:
pa := packets.Puback{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Println("sending PUBACK")
_, err := pa.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err)
}
case 2:
pr := packets.Pubrec{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Printf("sending PUBREC")
_, err := pr.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err)
}

if !c.ClientConfig.EnableManualAcknowledgment {
c.Router.Route(pb)
c.ack(pb)
continue
}

if pb.QoS != 0 {
c.acksTracker.add(pb)
}

c.Router.Route(pb)
}
}
}
Expand Down Expand Up @@ -466,6 +537,8 @@ func (c *Client) close() {
c.debug.Println("ping stopped")
_ = c.Conn.Close()
c.debug.Println("conn closed")
c.acksTracker.reset()
c.debug.Println("acks tracker reset")
}

// error is called to signify that an error situation has occurred, this
Expand Down
Loading

0 comments on commit 1aa48ad

Please sign in to comment.