diff --git a/paho/client.go b/paho/client.go index ad54e9c..5dcbeb2 100644 --- a/paho/client.go +++ b/paho/client.go @@ -399,10 +399,10 @@ func (c *Client) routePublishPackets() { // Disconnect, the Stop channel is closed or there is an error reading // a packet from the network connection func (c *Client) incoming() { + defer c.debug.Println("client stopping, incoming stopping") for { select { case <-c.stop: - c.debug.Println("client stopping, incoming stopping") return default: recv, err := packets.ReadPacket(c.Conn) @@ -436,7 +436,15 @@ func (c *Client) incoming() { case packets.PUBLISH: pb := recv.Content.(*packets.Publish) c.debug.Printf("received QoS%d PUBLISH", pb.QoS) - c.publishPackets <- pb + c.mu.Lock() + select { + case <-c.stop: + c.mu.Unlock() + return + default: + c.publishPackets <- pb + c.mu.Unlock() + } case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK: c.debug.Printf("received %s packet with id %d", recv.PacketType(), recv.PacketID()) if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx != nil { @@ -550,12 +558,18 @@ func (c *Client) error(e error) { // is received. func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, error) { c.debug.Println("client initiated reauthentication") - c.mu.Lock() - defer c.mu.Unlock() + c.mu.Lock() + if c.raCtx != nil { + c.mu.Unlock() + return nil, fmt.Errorf("previous authentication is still in progress") + } c.raCtx = &CPContext{ctx, make(chan packets.ControlPacket, 1)} + c.mu.Unlock() defer func() { + c.mu.Lock() c.raCtx = nil + c.mu.Unlock() }() c.debug.Println("sending AUTH") diff --git a/paho/client_test.go b/paho/client_test.go index fd50c9c..3796ca2 100644 --- a/paho/client_test.go +++ b/paho/client_test.go @@ -750,6 +750,63 @@ func TestCloseDeadlock(t *testing.T) { wg.Wait() } +func TestSendOnClosedChannel(t *testing.T) { + ts := newTestServer() + ts.SetResponse(packets.CONNACK, &packets.Connack{ + ReasonCode: 0, + SessionPresent: false, + Properties: &packets.Properties{ + MaximumPacketSize: Uint32(12345), + MaximumQOS: Byte(1), + ReceiveMaximum: Uint16(12345), + TopicAliasMaximum: Uint16(200), + }, + }) + go ts.Run() + defer ts.Stop() + + c := NewClient(ClientConfig{ + Conn: ts.ClientConn(), + }) + require.NotNil(t, c) + + if testing.Verbose() { + l := log.New(os.Stdout, t.Name(), log.LstdFlags) + c.SetDebugLogger(l) + c.SetErrorLogger(l) + } + + ctx := context.Background() + ca, err := c.Connect(ctx, &Connect{ + KeepAlive: 30, + ClientID: "testClient", + CleanStart: true, + Properties: &ConnectProperties{ + ReceiveMaximum: Uint16(200), + }, + }) + require.Nil(t, err) + assert.Equal(t, uint8(0), ca.ReasonCode) + + go func() { + for i := uint16(0); true; i++ { + err := ts.SendPacket(&packets.Publish{ + Payload: []byte("ciao"), + Topic: "test", + PacketID: i, + QoS: 1, + }) + if err != nil { + t.Logf("Send packet error: %v", err) + return + } + } + }() + + time.Sleep(10 * time.Millisecond) + c.close() +} + func isChannelClosed(ch chan struct{}) (closed bool) { defer func() { err, ok := recover().(error)