This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 505aeb6d Add support of new commands: ssubscribe and sunsubscribe (#2003)
505aeb6d is described below

commit 505aeb6d7105090814e6fa914e8969c029968b23
Author: raffertyyu <[email protected]>
AuthorDate: Fri Jan 12 13:27:31 2024 +0800

    Add support of new commands: ssubscribe and sunsubscribe (#2003)
---
 src/commands/cmd_pubsub.cc                   |  60 +++++++++-
 src/commands/cmd_server.cc                   |   5 +-
 src/server/redis_connection.cc               |  39 +++++++
 src/server/redis_connection.h                |   5 +
 src/server/server.cc                         |  61 ++++++++++
 src/server/server.h                          |   7 ++
 tests/gocase/unit/pubsub/pubsubshard_test.go | 164 +++++++++++++++++++++++++++
 7 files changed, 333 insertions(+), 8 deletions(-)
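
For context, a minimal client-side sketch of what this commit enables, written against the go-redis v9 client that the new test file below also uses. The server address and channel names are illustrative assumptions, not part of the commit.

    package main

    import (
        "context"
        "fmt"

        "github.com/redis/go-redis/v9"
    )

    func main() {
        ctx := context.Background()
        // Address is an illustrative assumption; point it at any kvrocks instance.
        rdb := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6666"})
        defer rdb.Close()

        // SSUBSCRIBE: on a cluster-enabled server, all channels in one call
        // must hash to the same slot, hence the shared {tag1} hash tag.
        pubsub := rdb.SSubscribe(ctx, "events{tag1}", "alerts{tag1}")
        defer pubsub.Close()

        // PUBSUB SHARDCHANNELS and PUBSUB SHARDNUMSUB report shard subscriptions.
        channels, err := rdb.PubSubShardChannels(ctx, "*").Result()
        if err == nil {
            fmt.Println("shard channels:", channels)
        }
        counts, err := rdb.PubSubShardNumSub(ctx, "events{tag1}", "alerts{tag1}").Result()
        if err == nil {
            fmt.Println("subscriber counts:", counts)
        }

        // SUNSUBSCRIBE: with no channels given, it drops every shard subscription.
        _ = pubsub.SUnsubscribe(ctx)
    }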

diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc
index 45272eef..6ec61eea 100644
--- a/src/commands/cmd_pubsub.cc
+++ b/src/commands/cmd_pubsub.cc
@@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander {
   }
 };
 
+class CommandSSubscribe : public Commander {
+ public:
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    uint16_t slot = 0;
+    if (srv->GetConfig()->cluster_enabled) {
+      slot = GetSlotIdFromKey(args_[1]);
+      for (unsigned int i = 2; i < args_.size(); i++) {
+        if (GetSlotIdFromKey(args_[i]) != slot) {
+          return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"};
+        }
+      }
+    }
+
+    for (unsigned int i = 1; i < args_.size(); i++) {
+      conn->SSubscribeChannel(args_[i], slot);
+      SubscribeCommandReply(output, "ssubscribe", args_[i], conn->SSubscriptionsCount());
+    }
+    return Status::OK();
+  }
+};
+
+class CommandSUnSubscribe : public Commander {
+ public:
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    if (args_.size() == 1) {
+      conn->SUnsubscribeAll([output](const std::string &sub_name, int num) {
+        SubscribeCommandReply(output, "sunsubscribe", sub_name, num);
+      });
+    } else {
+      for (size_t i = 1; i < args_.size(); i++) {
+        conn->SUnsubscribeChannel(args_[i], srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(args_[i]) : 0);
+        SubscribeCommandReply(output, "sunsubscribe", args_[i], conn->SSubscriptionsCount());
+      }
+    }
+    return Status::OK();
+  }
+};
+
 class CommandPubSub : public Commander {
  public:
   Status Parse(const std::vector<std::string> &args) override {
@@ -146,14 +184,14 @@ class CommandPubSub : public Commander {
       return Status::OK();
     }
 
-    if ((subcommand_ == "numsub") && args.size() >= 2) {
+    if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") && args.size() >= 2) {
       if (args.size() > 2) {
         channels_ = std::vector<std::string>(args.begin() + 2, args.end());
       }
       return Status::OK();
     }
 
-    if ((subcommand_ == "channels") && args.size() <= 3) {
+    if ((subcommand_ == "channels" || subcommand_ == "shardchannels") && args.size() <= 3) {
       if (args.size() == 3) {
         pattern_ = args[2];
       }
@@ -169,9 +207,13 @@ class CommandPubSub : public Commander {
       return Status::OK();
     }
 
-    if (subcommand_ == "numsub") {
+    if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") {
       std::vector<ChannelSubscribeNum> channel_subscribe_nums;
-      srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
+      if (subcommand_ == "numsub") {
+        srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
+      } else {
+        srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums);
+      }
 
       output->append(redis::MultiLen(channel_subscribe_nums.size() * 2));
       for (const auto &chan_subscribe_num : channel_subscribe_nums) {
@@ -182,9 +224,13 @@ class CommandPubSub : public Commander {
       return Status::OK();
     }
 
-    if (subcommand_ == "channels") {
+    if (subcommand_ == "channels" || subcommand_ == "shardchannels") {
       std::vector<std::string> channels;
-      srv->GetChannelsByPattern(pattern_, &channels);
+      if (subcommand_ == "channels") {
+        srv->GetChannelsByPattern(pattern_, &channels);
+      } else {
+        srv->GetSChannelsByPattern(pattern_, &channels);
+      }
       *output = redis::MultiBulkString(channels);
       return Status::OK();
     }
@@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS(
     MakeCmdAttr<CommandUnSubscribe>("unsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
     MakeCmdAttr<CommandPSubscribe>("psubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
     MakeCmdAttr<CommandPUnSubscribe>("punsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
+    MakeCmdAttr<CommandSSubscribe>("ssubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0),
+    MakeCmdAttr<CommandSUnSubscribe>("sunsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0),
     MakeCmdAttr<CommandPubSub>("pubsub", -2, "read-only pub-sub no-script", 0, 0, 0), )
 
 }  // namespace redis
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index a41094d1..d94d81e7 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander {
  public:
   Status Parse(const std::vector<std::string> &args) override {
     if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax};
-    for (int i = 1; i < args.size(); ++i) {
+    for (unsigned int i = 1; i < args.size(); ++i) {
       command_args_.push_back(args[i]);
     }
     return Status::OK();
@@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander {
     cmd->SetArgs(command_args_);
 
     int arity = cmd->GetAttributes()->arity;
-    if ((arity > 0 && command_args_.size() != arity) || (arity < 0 && command_args_.size() < -arity)) {
+    if ((arity > 0 && static_cast<int>(command_args_.size()) != arity) ||
+        (arity < 0 && static_cast<int>(command_args_.size()) < -arity)) {
       *output = redis::Error("ERR wrong number of arguments");
       return {Status::RedisExecErr, errWrongNumOfArguments};
     }
diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index ae80e950..d6e0b5f6 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback &reply) {
 
 int Connection::PSubscriptionsCount() { return static_cast<int>(subscribe_patterns_.size()); }
 
+void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) {
+  for (const auto &chan : subscribe_shard_channels_) {
+    if (channel == chan) return;
+  }
+
+  subscribe_shard_channels_.emplace_back(channel);
+  owner_->srv->SSubscribeChannel(channel, this, slot);
+}
+
+void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) {
+  for (auto iter = subscribe_shard_channels_.begin(); iter != subscribe_shard_channels_.end(); iter++) {
+    if (*iter == channel) {
+      subscribe_shard_channels_.erase(iter);
+      owner_->srv->SUnsubscribeChannel(channel, this, slot);
+      return;
+    }
+  }
+}
+
+void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) {
+  if (subscribe_shard_channels_.empty()) {
+    if (reply) reply("", 0);
+    return;
+  }
+
+  int removed = 0;
+  for (const auto &chan : subscribe_shard_channels_) {
+    owner_->srv->SUnsubscribeChannel(chan, this,
+                                     owner_->srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(chan) : 0);
+    removed++;
+    if (reply) {
+      reply(chan, static_cast<int>(subscribe_shard_channels_.size() - removed));
+    }
+  }
+  subscribe_shard_channels_.clear();
+}
+
+int Connection::SSubscriptionsCount() { return static_cast<int>(subscribe_shard_channels_.size()); }
+
 bool Connection::IsProfilingEnabled(const std::string &cmd) {
   auto config = srv_->GetConfig();
   if (config->profiling_sample_ratio == 0) return false;
diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h
index 25b522d8..34fbcbae 100644
--- a/src/server/redis_connection.h
+++ b/src/server/redis_connection.h
@@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase<Connection> {
   void PUnsubscribeChannel(const std::string &pattern);
   void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
   int PSubscriptionsCount();
+  void SSubscribeChannel(const std::string &channel, uint16_t slot);
+  void SUnsubscribeChannel(const std::string &channel, uint16_t slot);
+  void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
+  int SSubscriptionsCount();
 
   uint64_t GetAge() const;
   uint64_t GetIdleTime() const;
@@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase<Connection> {
 
   std::vector<std::string> subscribe_channels_;
   std::vector<std::string> subscribe_patterns_;
+  std::vector<std::string> subscribe_shard_channels_;
 
   Server *srv_;
   bool in_exec_ = false;
diff --git a/src/server/server.cc b/src/server/server.cc
index f8f2fb94..efe721b2 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config)
   // Init cluster
   cluster = std::make_unique<Cluster>(this, config_->binds, config_->port);
 
+  // init shard pub/sub channels
+  pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1);
+
   for (int i = 0; i < config->workers; i++) {
     auto worker = std::make_unique<Worker>(this, config);
     // multiple workers can't listen to the same unix socket, so
@@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string &pattern, redis::Connection *
   }
 }
 
+void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
+  assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  auto conn_ctx = ConnContext(conn->Owner(), conn->GetFD());
+  if (auto iter = pubsub_shard_channels_[slot].find(channel); iter == pubsub_shard_channels_[slot].end()) {
+    pubsub_shard_channels_[slot].emplace(channel, std::list<ConnContext>{conn_ctx});
+  } else {
+    iter->second.emplace_back(conn_ctx);
+  }
+}
+
+void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
+  assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  auto iter = pubsub_shard_channels_[slot].find(channel);
+  if (iter == pubsub_shard_channels_[slot].end()) {
+    return;
+  }
+
+  for (const auto &conn_ctx : iter->second) {
+    if (conn->GetFD() == conn_ctx.fd && conn->Owner() == conn_ctx.owner) {
+      iter->second.remove(conn_ctx);
+      if (iter->second.empty()) {
+        pubsub_shard_channels_[slot].erase(iter);
+      }
+      break;
+    }
+  }
+}
+
+void Server::GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels) {
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  for (const auto &shard_channels : pubsub_shard_channels_) {
+    for (const auto &iter : shard_channels) {
+      if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
+        channels->emplace_back(iter.first);
+      }
+    }
+  }
+}
+
+void Server::ListSChannelSubscribeNum(const std::vector<std::string> &channels,
+                                      std::vector<ChannelSubscribeNum> *channel_subscribe_nums) {
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  for (const auto &chan : channels) {
+    uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0;
+    if (auto iter = pubsub_shard_channels_[slot].find(chan); iter != pubsub_shard_channels_[slot].end()) {
+      channel_subscribe_nums->emplace_back(ChannelSubscribeNum{iter->first, iter->second.size()});
+    } else {
+      channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, 0});
+    }
+  }
+}
+
 void Server::BlockOnKey(const std::string &key, redis::Connection *conn) {
   std::lock_guard<std::mutex> guard(blocking_keys_mu_);
 
diff --git a/src/server/server.h b/src/server/server.h
index 2acd0f5d..a86eedf1 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -201,6 +201,11 @@ class Server {
   void PSubscribeChannel(const std::string &pattern, redis::Connection *conn);
   void PUnsubscribeChannel(const std::string &pattern, redis::Connection *conn);
   size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); }
+  void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
+  void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
+  void GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels);
+  void ListSChannelSubscribeNum(const std::vector<std::string> &channels,
+                                std::vector<ChannelSubscribeNum> *channel_subscribe_nums);
 
   void BlockOnKey(const std::string &key, redis::Connection *conn);
   void UnblockOnKey(const std::string &key, redis::Connection *conn);
@@ -351,6 +356,8 @@ class Server {
   std::map<std::string, std::list<ConnContext>> pubsub_channels_;
   std::map<std::string, std::list<ConnContext>> pubsub_patterns_;
   std::mutex pubsub_channels_mu_;
+  std::vector<std::map<std::string, std::list<ConnContext>>> pubsub_shard_channels_;
+  std::mutex pubsub_shard_channels_mu_;
   std::map<std::string, std::list<ConnContext>> blocking_keys_;
   std::mutex blocking_keys_mu_;
 
diff --git a/tests/gocase/unit/pubsub/pubsubshard_test.go b/tests/gocase/unit/pubsub/pubsubshard_test.go
new file mode 100644
index 00000000..9e8b04cf
--- /dev/null
+++ b/tests/gocase/unit/pubsub/pubsubshard_test.go
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package pubsub
+
+import (
+       "context"
+       "fmt"
+       "testing"
+
+       "github.com/apache/kvrocks/tests/gocase/util"
+       "github.com/redis/go-redis/v9"
+       "github.com/stretchr/testify/require"
+)
+
+func TestPubSubShard(t *testing.T) {
+       ctx := context.Background()
+
+       srv := util.StartServer(t, map[string]string{})
+       defer srv.Close()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+       defer csrv.Close()
+       crdb := csrv.NewClient()
+       defer func() { require.NoError(t, crdb.Close()) }()
+
+       nodeID := "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY"
+       require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err())
+       clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID, csrv.Host(), csrv.Port())
+       require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err())
+
+       rdbs := []*redis.Client{rdb, crdb}
+
+       t.Run("SSUBSCRIBE PING", func(t *testing.T) {
+               pubsub := rdb.SSubscribe(ctx, "somechannel")
+               receiveType(t, pubsub, &redis.Subscription{})
+               require.NoError(t, pubsub.Ping(ctx))
+               require.NoError(t, pubsub.Ping(ctx))
+               require.NoError(t, pubsub.SUnsubscribe(ctx, "somechannel"))
+               require.Equal(t, "PONG", rdb.Ping(ctx).Val())
+               receiveType(t, pubsub, &redis.Pong{})
+               receiveType(t, pubsub, &redis.Pong{})
+       })
+
+       t.Run("SSUBSCRIBE/SUNSUBSCRIBE basic", func(t *testing.T) {
+               for _, c := range rdbs {
+                       pubsub := c.SSubscribe(ctx, "singlechannel")
+                       defer pubsub.Close()
+
+                       msg := receiveType(t, pubsub, &redis.Subscription{})
+                       require.EqualValues(t, 1, msg.Count)
+                       require.EqualValues(t, "singlechannel", msg.Channel)
+                       require.EqualValues(t, "ssubscribe", msg.Kind)
+
+                       err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel1{tag1}")
+                       require.Nil(t, err)
+                       require.EqualValues(t, 2, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+
+                       err = pubsub.SSubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}")
+                       require.Nil(t, err)
+                       if c == rdb {
+                               require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                               require.EqualValues(t, 5, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       } else {
+                               // note: when cluster enabled, shard channels in single command must belong to the same slot
+                               // reference: https://redis.io/commands/ssubscribe
+                               _, err = pubsub.Receive(ctx)
+                               require.EqualError(t, err, "ERR CROSSSLOT Keys in request don't hash to the same slot")
+                       }
+
+                       err = pubsub.SUnsubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}", "multichannel5{tag2}")
+                       require.Nil(t, err)
+                       if c == rdb {
+                               require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                               require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                               require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       } else {
+                               require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                               require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                               require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       }
+
+                       err = pubsub.SUnsubscribe(ctx)
+                       require.Nil(t, err)
+                       msg = receiveType(t, pubsub, &redis.Subscription{})
+                       require.EqualValues(t, 2, msg.Count)
+                       require.EqualValues(t, "sunsubscribe", msg.Kind)
+                       require.EqualValues(t, 1, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count)
+               }
+       })
+
+       t.Run("SSUBSCRIBE/SUNSUBSCRIBE with empty channel", func(t *testing.T) {
+               for _, c := range rdbs {
+                       pubsub := c.SSubscribe(ctx)
+                       defer pubsub.Close()
+
+                       err := pubsub.SUnsubscribe(ctx, "foo", "bar")
+                       require.Nil(t, err)
+                       require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count)
+                       require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count)
+               }
+       })
+
+       t.Run("SHARDNUMSUB returns numbers, not strings", func(t *testing.T) {
+               require.EqualValues(t, map[string]int64{
+                       "abc": 0,
+                       "def": 0,
+               }, rdb.PubSubShardNumSub(ctx, "abc", "def").Val())
+       })
+
+       t.Run("PUBSUB SHARDNUMSUB/SHARDCHANNELS", func(t *testing.T) {
+               for _, c := range rdbs {
+                       pubsub := c.SSubscribe(ctx, "singlechannel")
+                       defer pubsub.Close()
+                       receiveType(t, pubsub, &redis.Subscription{})
+
+                       err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}")
+                       require.Nil(t, err)
+                       receiveType(t, pubsub, &redis.Subscription{})
+                       receiveType(t, pubsub, &redis.Subscription{})
+                       receiveType(t, pubsub, &redis.Subscription{})
+
+                       pubsub1 := c.SSubscribe(ctx, "multichannel1{tag1}")
+                       defer pubsub1.Close()
+
+                       sc := c.PubSubShardChannels(ctx, "")
+                       require.EqualValues(t, len(sc.Val()), 4)
+                       sc = c.PubSubShardChannels(ctx, "multi*")
+                       require.EqualValues(t, len(sc.Val()), 3)
+
+                       sn := c.PubSubShardNumSub(ctx)
+                       require.EqualValues(t, len(sn.Val()), 0)
+                       sn = c.PubSubShardNumSub(ctx, "singlechannel", "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}")
+                       for i, k := range sn.Val() {
+                               if i == "multichannel1{tag1}" {
+                                       require.EqualValues(t, k, 2)
+                               } else {
+                                       require.EqualValues(t, k, 1)
+                               }
+                       }
+               }
+       })
+}
