diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index 0d4fc61c7..e36fa68cd 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -93,8 +93,12 @@ func (g *GoChannel) Publish(topic string, messages ...*message.Message) error { g.subscribersLock.RLock() defer g.subscribersLock.RUnlock() - subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) + subLock, loaded := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) subLock.(*sync.Mutex).Lock() + + if !loaded { + defer g.subscribersByTopicLock.Delete(topic) + } defer subLock.(*sync.Mutex).Unlock() if g.config.Persistent { @@ -205,7 +209,14 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag s.Close() g.subscribersLock.Lock() - defer g.subscribersLock.Unlock() + defer func() { + // if there are no subscribers, clean up any resources related to the topic + if len(g.subscribers[topic]) == 0 { + delete(g.subscribers, topic) + g.subscribersByTopicLock.Delete(topic) + } + g.subscribersLock.Unlock() + }() subLock, _ := g.subscribersByTopicLock.Load(topic) subLock.(*sync.Mutex).Lock() diff --git a/pubsub/gochannel/pubsub_internal_test.go b/pubsub/gochannel/pubsub_internal_test.go new file mode 100644 index 000000000..fdf047f61 --- /dev/null +++ b/pubsub/gochannel/pubsub_internal_test.go @@ -0,0 +1,77 @@ +package gochannel + +import ( + "context" + "log" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ThreeDotsLabs/watermill" + "github.com/ThreeDotsLabs/watermill/message" +) + +func TestSubscribe_clean_subscriber_data(t *testing.T) { + subCount := 100 + pubSub := NewGoChannel( + Config{OutputChannelBuffer: int64(subCount)}, + watermill.NewStdLogger(false, false), + ) + topicName := "test_topic" + + allClosed := sync.WaitGroup{} + + for i := 0; i < subCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + _, err := pubSub.Subscribe(ctx, topicName+"_index_"+strconv.Itoa(i)) + require.NoError(t, err) + + allClosed.Add(1) + go func() { + cancel() + allClosed.Done() + }() + } + + log.Println("waiting for all closed") + allClosed.Wait() + + assert.Len(t, pubSub.subscribers, 0) + lockCount := 0 + pubSub.subscribersByTopicLock.Range(func(_, _ any) bool { + lockCount++ + return true + }) + assert.Equal(t, 0, lockCount) + + assert.NoError(t, pubSub.Close()) +} + +func TestPublish_clean_lock_data(t *testing.T) { + messageCount := 100 + pubSub := NewGoChannel( + Config{OutputChannelBuffer: int64(messageCount)}, + watermill.NewStdLogger(false, false), + ) + topicName := "test_topic" + + _, err := pubSub.Subscribe(context.Background(), topicName+"_index_"+strconv.Itoa(0)) + require.NoError(t, err) + + for i := 0; i < messageCount; i++ { + err := pubSub.Publish(topicName+"_index_"+strconv.Itoa(i), message.NewMessage(watermill.NewShortUUID(), nil)) + require.NoError(t, err) + } + + lockCount := 0 + pubSub.subscribersByTopicLock.Range(func(_, _ any) bool { + lockCount++ + return true + }) + assert.Equal(t, 1, lockCount) + + assert.NoError(t, pubSub.Close()) +}