diff --git a/channel.go b/channel.go index 09ce37e..96ebd0b 100644 --- a/channel.go +++ b/channel.go @@ -1826,8 +1826,8 @@ func (ch *Channel) Reject(tag uint64, requeue bool) error { // GetNextPublishSeqNo returns the sequence number of the next message to be // published, when in confirm mode. func (ch *Channel) GetNextPublishSeqNo() uint64 { - ch.confirms.m.Lock() - defer ch.confirms.m.Unlock() + ch.confirms.publishedMut.Lock() + defer ch.confirms.publishedMut.Unlock() return ch.confirms.published + 1 } diff --git a/integration_test.go b/integration_test.go index f92d788..50c6507 100644 --- a/integration_test.go +++ b/integration_test.go @@ -2025,6 +2025,59 @@ func TestIntegrationGetNextPublishSeqNo(t *testing.T) { } } +func TestIntegrationGetNextPublishSeqNoRace(t *testing.T) { + if c := integrationConnection(t, "GetNextPublishSeqNoRace"); c != nil { + defer c.Close() + + ch, err := c.Channel() + if err != nil { + t.Fatalf("channel: %v", err) + } + + if err = ch.Confirm(false); err != nil { + t.Fatalf("could not confirm") + } + + ex := "test-get-next-pub" + if err = ch.ExchangeDeclare(ex, "direct", false, false, false, false, nil); err != nil { + t.Fatalf("cannot declare %v: got: %v", ex, err) + } + + n := ch.GetNextPublishSeqNo() + if n != 1 { + t.Fatalf("wrong next publish seqence number before any publish, expected: %d, got: %d", 1, n) + } + + wg := sync.WaitGroup{} + fail := false + + wg.Add(2) + + go func() { + defer wg.Done() + _ = ch.GetNextPublishSeqNo() + }() + + go func() { + defer wg.Done() + if err := ch.PublishWithContext(context.TODO(), "test-get-next-pub-seq", "", false, false, Publishing{}); err != nil { + t.Logf("publish error: %v", err) + fail = true + } + }() + + wg.Wait() + if fail { + t.FailNow() + } + + n = ch.GetNextPublishSeqNo() + if n != 2 { + t.Fatalf("wrong next publish seqence number after 15 publishing, expected: %d, got: %d", 2, n) + } + } +} + // https://github.com/rabbitmq/amqp091-go/pull/44 func TestShouldNotWaitAfterConnectionClosedIssue44(t *testing.T) { conn := integrationConnection(t, "TestShouldNotWaitAfterConnectionClosedIssue44")