From 698f963315e80d673bd812c7cc3eece1261dcee8 Mon Sep 17 00:00:00 2001 From: Taishi Kasuga Date: Mon, 2 Oct 2023 22:19:08 +0900 Subject: [PATCH] Add sharded Pub/Sub support for cluster --- cluster/test/commands_on_pub_sub_test.rb | 35 ++++++++++++++++++++++++ lib/redis/commands/pubsub.rb | 21 ++++++++++++++ lib/redis/subscribe.rb | 30 +++++++++++++++++++- 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/cluster/test/commands_on_pub_sub_test.rb b/cluster/test/commands_on_pub_sub_test.rb index 1db625006..107f35a7b 100644 --- a/cluster/test/commands_on_pub_sub_test.rb +++ b/cluster/test/commands_on_pub_sub_test.rb @@ -70,4 +70,39 @@ def test_publish_psubscribe_punsubscribe_pubsub assert_equal('two', messages['gucci2']) assert_equal('three', messages['hermes3']) end + + def test_spublish_ssubscribe_sunsubscribe_pubsub + omit_version('7.0.0') + + sub_cnt = 0 + messages = {} + + thread = Thread.new do + redis.ssubscribe('channel1', 'channel2') do |on| + on.ssubscribe { sub_cnt += 1 } + on.smessage do |c, msg| + messages[c] = msg + redis.sunsubscribe if messages.size == 2 + end + end + end + + Thread.pass until sub_cnt == 2 + + publisher = build_another_client + + assert_equal %w[channel1 channel2], publisher.pubsub(:shardchannels, 'channel*') + assert_equal({ 'channel1' => 1, 'channel2' => 1, 'channel3' => 0 }, + publisher.pubsub(:shardnumsub, 'channel1', 'channel2', 'channel3')) + + publisher.spublish('channel1', 'one') + publisher.spublish('channel2', 'two') + publisher.spublish('channel3', 'three') + + thread.join + + assert_equal(2, messages.size) + assert_equal('one', messages['channel1']) + assert_equal('two', messages['channel2']) + end end diff --git a/lib/redis/commands/pubsub.rb b/lib/redis/commands/pubsub.rb index 5d84e3fab..37702935c 100644 --- a/lib/redis/commands/pubsub.rb +++ b/lib/redis/commands/pubsub.rb @@ -49,6 +49,27 @@ def punsubscribe(*channels) def pubsub(subcommand, *args) send_command([:pubsub, subcommand] + args) end + + # Post a message to a channel in shard. + def spublish(channel, message) + send_command([:spublish, channel, message]) + end + + # Listen for messages published to the given channels in shard. + def ssubscribe(*channels, &block) + _subscription(:ssubscribe, 0, channels, block) + end + + # Listen for messages published to the given channels in shard. + # Throw a timeout error if there is no messages for a timeout period. + def ssubscribe_with_timeout(timeout, *channels, &block) + _subscription(:ssubscribe_with_timeout, timeout, channels, block) + end + + # Stop listening for messages posted to the given channels in shard. + def sunsubscribe(*channels) + _subscription(:sunsubscribe, 0, channels, nil) + end end end end diff --git a/lib/redis/subscribe.rb b/lib/redis/subscribe.rb index 94f0f0267..f95e29dfb 100644 --- a/lib/redis/subscribe.rb +++ b/lib/redis/subscribe.rb @@ -29,6 +29,14 @@ def psubscribe_with_timeout(timeout, *channels, &block) subscription("psubscribe", "punsubscribe", channels, block, timeout) end + def ssubscribe(*channels, &block) + subscription("ssubscribe", "sunsubscribe", channels, block) + end + + def ssubscribe_with_timeout(timeout, *channels, &block) + subscription("ssubscribe", "sunsubscribe", channels, block, timeout) + end + def unsubscribe(*channels) call_v([:unsubscribe, *channels]) end @@ -37,6 +45,10 @@ def punsubscribe(*channels) call_v([:punsubscribe, *channels]) end + def sunsubscribe(*channels) + call_v([:sunsubscribe, *channels]) + end + def close @client.close end @@ -46,7 +58,11 @@ def close def subscription(start, stop, channels, block, timeout = 0) sub = Subscription.new(&block) - call_v([start, *channels]) + case start + when "ssubscribe" then channels.each { |c| call_v([start, c]) } # avoid cross-slot keys + else call_v([start, *channels]) + end + while event = @client.next_event(timeout) if event.is_a?(::RedisClient::CommandError) raise Client::ERROR_MAPPING.fetch(event.class), event.message @@ -94,5 +110,17 @@ def punsubscribe(&block) def pmessage(&block) @callbacks["pmessage"] = block end + + def ssubscribe(&block) + @callbacks["ssubscribe"] = block + end + + def sunsubscribe(&block) + @callbacks["sunsubscribe"] = block + end + + def smessage(&block) + @callbacks["smessage"] = block + end end end