From 7c2fc0cd0fe39b89e4cd03712225d54ff6ef36df Mon Sep 17 00:00:00 2001 From: chyezh Date: Thu, 23 Jan 2025 15:17:55 +0800 Subject: [PATCH] fix: unsafe concurrent consuming api of rocksmq Signed-off-by: chyezh --- internal/rootcoord/dml_channels_test.go | 2 +- pkg/mq/mqimpl/rocksmq/client/client_impl.go | 4 +++- pkg/mq/mqimpl/rocksmq/server/rocksmq_impl.go | 9 ++++++--- pkg/mq/mqimpl/rocksmq/server/rocksmq_retention.go | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index 7d7f3284835f2..2a8e4b6f4deb0 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -165,7 +165,7 @@ func TestDmlChannels(t *testing.T) { defer paramtable.Get().Reset(Params.CommonCfg.PreCreatedTopicEnabled.Key) defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) - assert.Panics(t, func() { newDmlChannels(ctx, factory, dmlChanPrefix, totalDmlChannelNum) }) + newDmlChannels(ctx, factory, dmlChanPrefix, totalDmlChannelNum) } func TestDmChannelsFailure(t *testing.T) { diff --git a/pkg/mq/mqimpl/rocksmq/client/client_impl.go b/pkg/mq/mqimpl/rocksmq/client/client_impl.go index 95f9bcf897e93..a278d51f9fd0a 100644 --- a/pkg/mq/mqimpl/rocksmq/client/client_impl.go +++ b/pkg/mq/mqimpl/rocksmq/client/client_impl.go @@ -111,7 +111,9 @@ func (c *client) Subscribe(options ConsumerOptions) (Consumer, error) { GroupName: consumer.consumerName, MsgMutex: consumer.msgMutex, } - c.server.RegisterConsumer(cons) + if err := c.server.RegisterConsumer(cons); err != nil { + return nil, err + } if options.SubscriptionInitialPosition == common.SubscriptionPositionLatest { err = c.server.SeekToLatest(options.Topic, options.SubscriptionName) diff --git a/pkg/mq/mqimpl/rocksmq/server/rocksmq_impl.go b/pkg/mq/mqimpl/rocksmq/server/rocksmq_impl.go index a7c778403443f..ee15161571705 100644 --- a/pkg/mq/mqimpl/rocksmq/server/rocksmq_impl.go +++ b/pkg/mq/mqimpl/rocksmq/server/rocksmq_impl.go @@ -420,9 +420,7 @@ func (rmq *rocksmq) CreateTopic(topicName string) error { return nil } - if _, ok := topicMu.Load(topicName); !ok { - topicMu.Store(topicName, new(sync.Mutex)) - } + topicMu.LoadOrStore(topicName, new(sync.Mutex)) // msgSizeKey -> msgSize // topicIDKey -> topic creating time @@ -550,6 +548,11 @@ func (rmq *rocksmq) RegisterConsumer(consumer *Consumer) error { if rmq.isClosed() { return errors.New(RmqNotServingErrMsg) } + ll, _ := topicMu.LoadOrStore(consumer.Topic, new(sync.Mutex)) + mu, _ := ll.(*sync.Mutex) + mu.Lock() + defer mu.Unlock() + start := time.Now() if vals, ok := rmq.consumers.Load(consumer.Topic); ok { for _, v := range vals.([]*Consumer) { diff --git a/pkg/mq/mqimpl/rocksmq/server/rocksmq_retention.go b/pkg/mq/mqimpl/rocksmq/server/rocksmq_retention.go index e96e2659e5ed9..c221e9834d247 100644 --- a/pkg/mq/mqimpl/rocksmq/server/rocksmq_retention.go +++ b/pkg/mq/mqimpl/rocksmq/server/rocksmq_retention.go @@ -63,7 +63,7 @@ func initRetentionInfo(kv *rocksdbkv.RocksdbKV, db *gorocksdb.DB) (*retentionInf for _, key := range topicKeys { topic := key[len(TopicIDTitle):] ri.topicRetetionTime.Insert(topic, time.Now().Unix()) - topicMu.Store(topic, new(sync.Mutex)) + topicMu.LoadOrStore(topic, new(sync.Mutex)) } return ri, nil }