From 27f583911319a524ed495233559e4a9d4f2a62d1 Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Fri, 4 Jun 2021 10:56:05 +0200 Subject: [PATCH 1/5] Acks tracker --- paho/acks_tracker.go | 79 +++++++++++++++++++++++++++++++++++++++ paho/acks_tracker_test.go | 79 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 paho/acks_tracker.go create mode 100644 paho/acks_tracker_test.go diff --git a/paho/acks_tracker.go b/paho/acks_tracker.go new file mode 100644 index 0000000..47f11cb --- /dev/null +++ b/paho/acks_tracker.go @@ -0,0 +1,79 @@ +package paho + +import ( + "errors" + "sync" + + "github.com/eclipse/paho.golang/packets" +) + +var ( + ErrPacketNotFound = errors.New("packet not found") +) + +type acksTracker struct { + mx sync.Mutex + order []packet +} + +func (t *acksTracker) add(pb *packets.Publish) { + t.mx.Lock() + defer t.mx.Unlock() + + for _, v := range t.order { + if v.pb.PacketID == pb.PacketID { + return // already added + } + } + + t.order = append(t.order, packet{pb: pb}) +} + +func (t *acksTracker) markAsAcked(pb *packets.Publish) error { + t.mx.Lock() + defer t.mx.Unlock() + + for k, v := range t.order { + if pb.PacketID == v.pb.PacketID { + t.order[k].acknowledged = true + return nil + } + } + + return ErrPacketNotFound +} + +func (t *acksTracker) flush(do func([]*packets.Publish)) { + t.mx.Lock() + defer t.mx.Unlock() + + var ( + buf []*packets.Publish + ) + for _, v := range t.order { + if v.acknowledged { + buf = append(buf, v.pb) + } else { + break + } + } + + if len(buf) == 0 { + return + } + + do(buf) + t.order = t.order[len(buf):] +} + +// reset should be used upon disconnections +func (t *acksTracker) reset() { + t.mx.Lock() + defer t.mx.Unlock() + t.order = nil +} + +type packet struct { + pb *packets.Publish + acknowledged bool +} diff --git a/paho/acks_tracker_test.go b/paho/acks_tracker_test.go new file mode 100644 index 0000000..6362117 --- /dev/null +++ b/paho/acks_tracker_test.go @@ -0,0 +1,79 @@ +package paho + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/eclipse/paho.golang/packets" +) + +func TestAcksTracker(t *testing.T) { + var ( + at acksTracker + p1 = &packets.Publish{PacketID: 1} + p2 = &packets.Publish{PacketID: 2} + p3 = &packets.Publish{PacketID: 3} + p4 = &packets.Publish{PacketID: 4} // to test not found + ) + + t.Run("flush-without-acking", func(t *testing.T) { + at.add(p1) + at.add(p2) + at.add(p3) + require.Equal(t, ErrPacketNotFound, at.markAsAcked(p4)) + at.flush(func(_ []*packets.Publish) { + t.Fatal("flush should not call 'do' since no packets have been acknowledged so far") + }) + }) + + t.Run("ack-in-the-middle", func(t *testing.T) { + require.NoError(t, at.markAsAcked(p3)) + at.flush(func(_ []*packets.Publish) { + t.Fatal("flush should not call 'do' since p1 and p2 have not been acknowledged yet") + }) + }) + + t.Run("idempotent-acking", func(t *testing.T) { + require.NoError(t, at.markAsAcked(p3)) + require.NoError(t, at.markAsAcked(p3)) + require.NoError(t, at.markAsAcked(p3)) + }) + + t.Run("ack-first", func(t *testing.T) { + var flushCalled bool + require.NoError(t, at.markAsAcked(p1)) + at.flush(func(pbs []*packets.Publish) { + require.Equal(t, []*packets.Publish{p1}, pbs, "Only p1 expected even though p3 was acked, p2 is still missing") + flushCalled = true + }) + require.True(t, flushCalled) + }) + + t.Run("ack-after-flush", func(t *testing.T) { + var flushCalled bool + require.NoError(t, at.markAsAcked(p2)) + at.add(p4) // this should just be appended and not flushed (yet) + at.flush(func(pbs []*packets.Publish) { + require.Equal(t, []*packets.Publish{p2, p3}, pbs, "Only p2 and p3 expected, p1 was flushed in the previous call") + flushCalled = true + }) + require.True(t, flushCalled) + }) + + t.Run("ack-last", func(t *testing.T) { + var flushCalled bool + require.NoError(t, at.markAsAcked(p4)) + at.flush(func(pbs []*packets.Publish) { + require.Equal(t, []*packets.Publish{p4}, pbs, "Only p4 expected, the rest was flushed in previous calls") + flushCalled = true + }) + require.True(t, flushCalled) + }) + + t.Run("flush-after-acking-everything", func(t *testing.T) { + at.flush(func(_ []*packets.Publish) { + t.Fatal("no call to 'do' expected, we flushed all packets") + }) + }) +} From 9374cf96d83bd93af2507ea8adaa22501c85b275 Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Fri, 4 Jun 2021 10:56:15 +0200 Subject: [PATCH 2/5] Linter err --- packets/publish.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packets/publish.go b/packets/publish.go index 933fb85..7f9aa9e 100644 --- a/packets/publish.go +++ b/packets/publish.go @@ -50,7 +50,7 @@ func (p *Publish) Buffers() net.Buffers { var b bytes.Buffer writeString(p.Topic, &b) if p.QoS > 0 { - writeUint16(p.PacketID, &b) + _ = writeUint16(p.PacketID, &b) } idvp := p.Properties.Pack(PUBLISH) encodeVBIdirect(len(idvp), &b) From a47276e020a1bc234d42eee3fc0cbcf79b140f4e Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Fri, 4 Jun 2021 10:56:21 +0200 Subject: [PATCH 3/5] Manual acks --- paho/client.go | 139 +++++++++++++++++++++++++++++++++----------- paho/client_test.go | 87 +++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 33 deletions(-) diff --git a/paho/client.go b/paho/client.go index ec35a26..d1f64cf 100644 --- a/paho/client.go +++ b/paho/client.go @@ -2,6 +2,7 @@ package paho import ( "context" + "errors" "fmt" "math" "net" @@ -20,30 +21,46 @@ const ( MQTTv5 MQTTVersion = 5 ) +const defaultSendAckInterval = 50 * time.Millisecond + +var ( + ErrManualAcknowledgmentDisabled = errors.New("manual acknowledgments disabled") +) + type ( // ClientConfig are the user configurable options for the client, an // instance of this struct is passed into NewClient(), not all options // are required to be set, defaults are provided for Persistence, MIDs, // PingHandler, PacketTimeout and Router. ClientConfig struct { - ClientID string - Conn net.Conn - MIDs MIDService - AuthHandler Auther - PingHandler Pinger - Router Router - Persistence Persistence - PacketTimeout time.Duration + ClientID string + Conn net.Conn + MIDs MIDService + AuthHandler Auther + PingHandler Pinger + Router Router + Persistence Persistence + PacketTimeout time.Duration + // OnServerDisconnect is called only when a packets.DISCONNECT is received from server OnServerDisconnect func(*Disconnect) - // Only called when receiving packets.DISCONNECT from server + // OnClientError is for example called on net.Error OnClientError func(error) - // Client error call, For example: net.Error - PublishHook func(*Publish) // PublishHook allows a user provided function to be called before // a Publish packet is sent allowing it to inspect or modify the // Publish, an example of the utility of this is provided in the // Topic Alias Handler extension which will automatically assign // and use topic alias values rather than topic strings. + PublishHook func(*Publish) + // EnableManualAcknowledgment is used to control the acknowledgment of packets manually. + // BEWARE that the MQTT specs require clients to send acknowledgments in the order in which the corresponding + // PUBLISH packets were received. + // Consider the following scenario: the client receives packets 1,2,3,4 + // If you acknowledge 3 first, no ack is actually sent to the server but it's buffered until also 1 and 2 + // are acknowledged. + EnableManualAcknowledgment bool + // SendAcksInterval is used only when EnableManualAcknowledgment is true + // it determines how often the client tries to send a batch of acknowledgments in the right order + SendAcksInterval time.Duration } // Client is the struct representing an MQTT client Client struct { @@ -53,6 +70,7 @@ type ( raCtx *CPContext stop chan struct{} publishPackets chan *packets.Publish + acksTracker acksTracker workers sync.WaitGroup serverProps CommsProperties clientProps CommsProperties @@ -283,9 +301,73 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { c.incoming() }() + if c.EnableManualAcknowledgment { + c.debug.Println("starting acking routine") + + c.acksTracker.reset() + sendAcksInterval := defaultSendAckInterval + if c.SendAcksInterval != 0 { + sendAcksInterval = c.SendAcksInterval + } + + c.workers.Add(1) + go func() { + defer c.workers.Done() + defer c.debug.Println("returning from ack tracker routine") + t := time.NewTicker(sendAcksInterval) + for { + select { + case <-c.stop: + return + case <-t.C: + c.acksTracker.flush(func(pbs []*packets.Publish) { + for _, pb := range pbs { + c.ack(pb) + } + }) + } + } + }() + } + return ca, nil } +func (c *Client) Ack(pb *Publish) error { + if !c.EnableManualAcknowledgment { + return ErrManualAcknowledgmentDisabled + } + if pb.QoS == 0 { + return nil + } + return c.acksTracker.markAsAcked(pb.Packet()) +} + +func (c *Client) ack(pb *packets.Publish) { + switch pb.QoS { + case 1: + pa := packets.Puback{ + Properties: &packets.Properties{}, + PacketID: pb.PacketID, + } + c.debug.Println("sending PUBACK") + _, err := pa.WriteTo(c.Conn) + if err != nil { + c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err) + } + case 2: + pr := packets.Pubrec{ + Properties: &packets.Properties{}, + PacketID: pb.PacketID, + } + c.debug.Printf("sending PUBREC") + _, err := pr.WriteTo(c.Conn) + if err != nil { + c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err) + } + } +} + func (c *Client) routePublishPackets() { for { select { @@ -295,29 +377,18 @@ func (c *Client) routePublishPackets() { if !open { return } - c.Router.Route(pb) - switch pb.QoS { - case 1: - pa := packets.Puback{ - Properties: &packets.Properties{}, - PacketID: pb.PacketID, - } - c.debug.Println("sending PUBACK") - _, err := pa.WriteTo(c.Conn) - if err != nil { - c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err) - } - case 2: - pr := packets.Pubrec{ - Properties: &packets.Properties{}, - PacketID: pb.PacketID, - } - c.debug.Printf("sending PUBREC") - _, err := pr.WriteTo(c.Conn) - if err != nil { - c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err) - } + + if !c.ClientConfig.EnableManualAcknowledgment { + c.Router.Route(pb) + c.ack(pb) + continue } + + if pb.QoS != 0 { + c.acksTracker.add(pb) + } + + c.Router.Route(pb) } } } @@ -458,6 +529,8 @@ func (c *Client) close() { c.debug.Println("ping stopped") _ = c.Conn.Close() c.debug.Println("conn closed") + c.acksTracker.reset() + c.debug.Println("acks tracker reset") } // error is called to signify that an error situation has occurred, this diff --git a/paho/client_test.go b/paho/client_test.go index f3c1afa..fd50c9c 100644 --- a/paho/client_test.go +++ b/paho/client_test.go @@ -447,6 +447,93 @@ func TestClientReceiveAndAckInOrder(t *testing.T) { ) } +func TestManualAcksInOrder(t *testing.T) { + ts := newTestServer() + ts.SetResponse(packets.CONNACK, &packets.Connack{ + ReasonCode: 0, + SessionPresent: false, + Properties: &packets.Properties{ + MaximumPacketSize: Uint32(12345), + MaximumQOS: Byte(1), + ReceiveMaximum: Uint16(12345), + TopicAliasMaximum: Uint16(200), + }, + }) + go ts.Run() + defer ts.Stop() + + var ( + wg sync.WaitGroup + actualPublishPackets []packets.Publish + expectedPacketsCount = 3 + ) + + wg.Add(expectedPacketsCount) + c := NewClient(ClientConfig{ + Conn: ts.ClientConn(), + EnableManualAcknowledgment: true, + }) + c.Router = NewSingleHandlerRouter(func(p *Publish) { + defer wg.Done() + actualPublishPackets = append(actualPublishPackets, *p.Packet()) + require.NoError(t, c.Ack(p)) + }) + require.NotNil(t, c) + c.SetDebugLogger(log.New(os.Stderr, "RECEIVEORDER: ", log.LstdFlags)) + t.Cleanup(c.close) + + ctx := context.Background() + ca, err := c.Connect(ctx, &Connect{ + KeepAlive: 30, + ClientID: "testClient", + CleanStart: true, + Properties: &ConnectProperties{ + ReceiveMaximum: Uint16(200), + }, + }) + require.Nil(t, err) + assert.Equal(t, uint8(0), ca.ReasonCode) + + var expectedPublishPackets []packets.Publish + for i := 1; i <= expectedPacketsCount; i++ { + p := packets.Publish{ + PacketID: uint16(i), + Topic: fmt.Sprintf("test/%d", i), + Payload: []byte(fmt.Sprintf("test payload %d", i)), + QoS: 1, + Properties: &packets.Properties{ + User: make([]packets.User, 0), + }, + } + expectedPublishPackets = append(expectedPublishPackets, p) + require.NoError(t, ts.SendPacket(&p)) + } + + wg.Wait() + + require.Equal(t, expectedPublishPackets, actualPublishPackets) + expectedAcks := []packets.Puback{ + {PacketID: 1, ReasonCode: 0, Properties: &packets.Properties{}}, + {PacketID: 2, ReasonCode: 0, Properties: &packets.Properties{}}, + {PacketID: 3, ReasonCode: 0, Properties: &packets.Properties{}}, + } + require.Eventually(t, + func() bool { + return cmp.Equal(expectedAcks, ts.ReceivedPubacks()) + }, + time.Second, + 10*time.Millisecond, + cmp.Diff(expectedAcks, ts.ReceivedPubacks()), + ) + + // Test QoS 0 packets are ignored + require.NoError(t, c.Ack(&Publish{QoS: 0, PacketID: 11233})) + + // Test packets not found + require.True(t, errors.Is(c.Ack(&Publish{QoS: 1, PacketID: 123}), ErrPacketNotFound)) + require.True(t, errors.Is(c.Ack(&Publish{QoS: 2, PacketID: 65535}), ErrPacketNotFound)) +} + func TestReceiveServerDisconnect(t *testing.T) { rChan := make(chan struct{}) ts := newTestServer() From 641ac7a03c9fca5d2ba20b0574e05c89c4e6b1d4 Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Mon, 7 Jun 2021 09:11:31 +0200 Subject: [PATCH 4/5] Minor make up --- paho/acks_tracker_test.go | 8 +++++++- paho/client.go | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paho/acks_tracker_test.go b/paho/acks_tracker_test.go index 6362117..15f1839 100644 --- a/paho/acks_tracker_test.go +++ b/paho/acks_tracker_test.go @@ -17,6 +17,12 @@ func TestAcksTracker(t *testing.T) { p4 = &packets.Publish{PacketID: 4} // to test not found ) + t.Run("flush-empty", func(t *testing.T) { + at.flush(func(_ []*packets.Publish) { + t.Fatal("flush should not call 'do' since no packets have been added nor acknowledged") + }) + }) + t.Run("flush-without-acking", func(t *testing.T) { at.add(p1) at.add(p2) @@ -73,7 +79,7 @@ func TestAcksTracker(t *testing.T) { t.Run("flush-after-acking-everything", func(t *testing.T) { at.flush(func(_ []*packets.Publish) { - t.Fatal("no call to 'do' expected, we flushed all packets") + t.Fatal("no call to 'do' expected, we flushed all packets already") }) }) } diff --git a/paho/client.go b/paho/client.go index d1f64cf..ad54e9c 100644 --- a/paho/client.go +++ b/paho/client.go @@ -59,7 +59,7 @@ type ( // are acknowledged. EnableManualAcknowledgment bool // SendAcksInterval is used only when EnableManualAcknowledgment is true - // it determines how often the client tries to send a batch of acknowledgments in the right order + // it determines how often the client tries to send a batch of acknowledgments in the right order to the server. SendAcksInterval time.Duration } // Client is the struct representing an MQTT client @@ -306,7 +306,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { c.acksTracker.reset() sendAcksInterval := defaultSendAckInterval - if c.SendAcksInterval != 0 { + if c.SendAcksInterval > 0 { sendAcksInterval = c.SendAcksInterval } From 8e1297f73f5810ed5c0fb6ac6987db165d1c86e9 Mon Sep 17 00:00:00 2001 From: Francesco Casula Date: Wed, 16 Jun 2021 11:25:22 +0200 Subject: [PATCH 5/5] panic: send on closed channel --- paho/client.go | 22 +++++++++++++---- paho/client_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/paho/client.go b/paho/client.go index ad54e9c..5dcbeb2 100644 --- a/paho/client.go +++ b/paho/client.go @@ -399,10 +399,10 @@ func (c *Client) routePublishPackets() { // Disconnect, the Stop channel is closed or there is an error reading // a packet from the network connection func (c *Client) incoming() { + defer c.debug.Println("client stopping, incoming stopping") for { select { case <-c.stop: - c.debug.Println("client stopping, incoming stopping") return default: recv, err := packets.ReadPacket(c.Conn) @@ -436,7 +436,15 @@ func (c *Client) incoming() { case packets.PUBLISH: pb := recv.Content.(*packets.Publish) c.debug.Printf("received QoS%d PUBLISH", pb.QoS) - c.publishPackets <- pb + c.mu.Lock() + select { + case <-c.stop: + c.mu.Unlock() + return + default: + c.publishPackets <- pb + c.mu.Unlock() + } case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK: c.debug.Printf("received %s packet with id %d", recv.PacketType(), recv.PacketID()) if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx != nil { @@ -550,12 +558,18 @@ func (c *Client) error(e error) { // is received. func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, error) { c.debug.Println("client initiated reauthentication") - c.mu.Lock() - defer c.mu.Unlock() + c.mu.Lock() + if c.raCtx != nil { + c.mu.Unlock() + return nil, fmt.Errorf("previous authentication is still in progress") + } c.raCtx = &CPContext{ctx, make(chan packets.ControlPacket, 1)} + c.mu.Unlock() defer func() { + c.mu.Lock() c.raCtx = nil + c.mu.Unlock() }() c.debug.Println("sending AUTH") diff --git a/paho/client_test.go b/paho/client_test.go index fd50c9c..3796ca2 100644 --- a/paho/client_test.go +++ b/paho/client_test.go @@ -750,6 +750,63 @@ func TestCloseDeadlock(t *testing.T) { wg.Wait() } +func TestSendOnClosedChannel(t *testing.T) { + ts := newTestServer() + ts.SetResponse(packets.CONNACK, &packets.Connack{ + ReasonCode: 0, + SessionPresent: false, + Properties: &packets.Properties{ + MaximumPacketSize: Uint32(12345), + MaximumQOS: Byte(1), + ReceiveMaximum: Uint16(12345), + TopicAliasMaximum: Uint16(200), + }, + }) + go ts.Run() + defer ts.Stop() + + c := NewClient(ClientConfig{ + Conn: ts.ClientConn(), + }) + require.NotNil(t, c) + + if testing.Verbose() { + l := log.New(os.Stdout, t.Name(), log.LstdFlags) + c.SetDebugLogger(l) + c.SetErrorLogger(l) + } + + ctx := context.Background() + ca, err := c.Connect(ctx, &Connect{ + KeepAlive: 30, + ClientID: "testClient", + CleanStart: true, + Properties: &ConnectProperties{ + ReceiveMaximum: Uint16(200), + }, + }) + require.Nil(t, err) + assert.Equal(t, uint8(0), ca.ReasonCode) + + go func() { + for i := uint16(0); true; i++ { + err := ts.SendPacket(&packets.Publish{ + Payload: []byte("ciao"), + Topic: "test", + PacketID: i, + QoS: 1, + }) + if err != nil { + t.Logf("Send packet error: %v", err) + return + } + } + }() + + time.Sleep(10 * time.Millisecond) + c.close() +} + func isChannelClosed(ch chan struct{}) (closed bool) { defer func() { err, ok := recover().(error)