This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 505aeb6d Add support of new command: ssubscribe and sunsubscribe (#2003)
505aeb6d is described below
commit 505aeb6d7105090814e6fa914e8969c029968b23
Author: raffertyyu <[email protected]>
AuthorDate: Fri Jan 12 13:27:31 2024 +0800
Add support of new command: ssubscribe and sunsubscribe (#2003)
---
src/commands/cmd_pubsub.cc | 60 +++++++++-
src/commands/cmd_server.cc | 5 +-
src/server/redis_connection.cc | 39 +++++++
src/server/redis_connection.h | 5 +
src/server/server.cc | 61 ++++++++++
src/server/server.h | 7 ++
tests/gocase/unit/pubsub/pubsubshard_test.go | 164 +++++++++++++++++++++++++++
7 files changed, 333 insertions(+), 8 deletions(-)
diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc
index 45272eef..6ec61eea 100644
--- a/src/commands/cmd_pubsub.cc
+++ b/src/commands/cmd_pubsub.cc
@@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander {
}
};
+// SSUBSCRIBE channel [channel ...]
+// Subscribes the connection to one or more shard channels and emits one
+// "ssubscribe" reply per channel carrying the connection's shard-subscription count.
+class CommandSSubscribe : public Commander {
+ public:
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    // Without cluster mode every shard channel is stored under slot 0.
+    const bool cluster_mode = srv->GetConfig()->cluster_enabled;
+    const uint16_t slot = cluster_mode ? GetSlotIdFromKey(args_[1]) : 0;
+
+    if (cluster_mode) {
+      // All channels of a single SSUBSCRIBE must hash to the same slot.
+      for (size_t idx = 2; idx < args_.size(); ++idx) {
+        if (GetSlotIdFromKey(args_[idx]) != slot) {
+          return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"};
+        }
+      }
+    }
+
+    for (size_t idx = 1; idx < args_.size(); ++idx) {
+      const std::string &channel = args_[idx];
+      conn->SSubscribeChannel(channel, slot);
+      SubscribeCommandReply(output, "ssubscribe", channel, conn->SSubscriptionsCount());
+    }
+    return Status::OK();
+  }
+};
+
+// SUNSUBSCRIBE [channel ...]
+// With no arguments, drops every shard subscription of the connection; otherwise
+// drops each listed channel and replies with the remaining shard-subscription count.
+class CommandSUnSubscribe : public Commander {
+ public:
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    if (args_.size() == 1) {
+      conn->SUnsubscribeAll([output](const std::string &sub_name, int num) {
+        SubscribeCommandReply(output, "sunsubscribe", sub_name, num);
+      });
+      return Status::OK();
+    }
+
+    // Slot must match the one used at subscribe time: hashed in cluster mode, 0 otherwise.
+    const bool cluster_mode = srv->GetConfig()->cluster_enabled;
+    for (auto it = args_.begin() + 1; it != args_.end(); ++it) {
+      const uint16_t slot = cluster_mode ? GetSlotIdFromKey(*it) : 0;
+      conn->SUnsubscribeChannel(*it, slot);
+      SubscribeCommandReply(output, "sunsubscribe", *it, conn->SSubscriptionsCount());
+    }
+    return Status::OK();
+  }
+};
+
class CommandPubSub : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
@@ -146,14 +184,14 @@ class CommandPubSub : public Commander {
return Status::OK();
}
- if ((subcommand_ == "numsub") && args.size() >= 2) {
+ if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") &&
args.size() >= 2) {
if (args.size() > 2) {
channels_ = std::vector<std::string>(args.begin() + 2, args.end());
}
return Status::OK();
}
- if ((subcommand_ == "channels") && args.size() <= 3) {
+ if ((subcommand_ == "channels" || subcommand_ == "shardchannels") &&
args.size() <= 3) {
if (args.size() == 3) {
pattern_ = args[2];
}
@@ -169,9 +207,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}
- if (subcommand_ == "numsub") {
+ if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") {
std::vector<ChannelSubscribeNum> channel_subscribe_nums;
- srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
+ if (subcommand_ == "numsub") {
+ srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums);
+ } else {
+ srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums);
+ }
output->append(redis::MultiLen(channel_subscribe_nums.size() * 2));
for (const auto &chan_subscribe_num : channel_subscribe_nums) {
@@ -182,9 +224,13 @@ class CommandPubSub : public Commander {
return Status::OK();
}
- if (subcommand_ == "channels") {
+ if (subcommand_ == "channels" || subcommand_ == "shardchannels") {
std::vector<std::string> channels;
- srv->GetChannelsByPattern(pattern_, &channels);
+ if (subcommand_ == "channels") {
+ srv->GetChannelsByPattern(pattern_, &channels);
+ } else {
+ srv->GetSChannelsByPattern(pattern_, &channels);
+ }
*output = redis::MultiBulkString(channels);
return Status::OK();
}
@@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandUnSubscribe>("unsubscribe", -1, "read-only pub-sub
no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPSubscribe>("psubscribe", -2, "read-only pub-sub
no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPUnSubscribe>("punsubscribe", -1, "read-only pub-sub
no-multi no-script", 0, 0, 0),
+ MakeCmdAttr<CommandSSubscribe>("ssubscribe", -2, "read-only pub-sub
no-multi no-script", 0, 0, 0),
+ MakeCmdAttr<CommandSUnSubscribe>("sunsubscribe", -1, "read-only pub-sub
no-multi no-script", 0, 0, 0),
MakeCmdAttr<CommandPubSub>("pubsub", -2, "read-only pub-sub no-script", 0,
0, 0), )
} // namespace redis
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index a41094d1..d94d81e7 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax};
- for (int i = 1; i < args.size(); ++i) {
+ for (unsigned int i = 1; i < args.size(); ++i) {
command_args_.push_back(args[i]);
}
return Status::OK();
@@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander {
cmd->SetArgs(command_args_);
int arity = cmd->GetAttributes()->arity;
- if ((arity > 0 && command_args_.size() != arity) || (arity < 0 &&
command_args_.size() < -arity)) {
+ if ((arity > 0 && static_cast<int>(command_args_.size()) != arity) ||
+ (arity < 0 && static_cast<int>(command_args_.size()) < -arity)) {
*output = redis::Error("ERR wrong number of arguments");
return {Status::RedisExecErr, errWrongNumOfArguments};
}
diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index ae80e950..d6e0b5f6 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback
&reply) {
int Connection::PSubscriptionsCount() { return
static_cast<int>(subscribe_patterns_.size()); }
+// Records a shard-channel subscription for this connection and registers it with
+// the server under `slot`. Subscribing to an already-subscribed channel is a no-op.
+void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) {
+  bool already_subscribed = false;
+  for (const auto &subscribed : subscribe_shard_channels_) {
+    if (subscribed == channel) {
+      already_subscribed = true;
+      break;
+    }
+  }
+  if (already_subscribed) return;
+
+  subscribe_shard_channels_.emplace_back(channel);
+  owner_->srv->SSubscribeChannel(channel, this, slot);
+}
+
+// Removes `channel` from this connection's shard subscriptions and unregisters it
+// from the server under `slot`. Does nothing if the channel was not subscribed.
+void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) {
+  auto iter = subscribe_shard_channels_.begin();
+  while (iter != subscribe_shard_channels_.end() && *iter != channel) {
+    ++iter;
+  }
+  if (iter == subscribe_shard_channels_.end()) return;
+
+  subscribe_shard_channels_.erase(iter);
+  owner_->srv->SUnsubscribeChannel(channel, this, slot);
+}
+
+// Drops every shard subscription of this connection. `reply`, when set, is invoked
+// once per channel with the number of subscriptions still left after it; with no
+// subscriptions it is invoked once with an empty name and a count of 0.
+void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) {
+  if (subscribe_shard_channels_.empty()) {
+    if (reply) reply("", 0);
+    return;
+  }
+
+  // Slot is recomputed the same way the subscribe path chose it.
+  const bool cluster_mode = owner_->srv->GetConfig()->cluster_enabled;
+  int remaining = static_cast<int>(subscribe_shard_channels_.size());
+  for (const auto &chan : subscribe_shard_channels_) {
+    owner_->srv->SUnsubscribeChannel(chan, this, cluster_mode ? GetSlotIdFromKey(chan) : 0);
+    --remaining;
+    if (reply) reply(chan, remaining);
+  }
+  subscribe_shard_channels_.clear();
+}
+
+// Number of shard channels this connection is currently subscribed to.
+int Connection::SSubscriptionsCount() {
+  const size_t total = subscribe_shard_channels_.size();
+  return static_cast<int>(total);
+}
+
bool Connection::IsProfilingEnabled(const std::string &cmd) {
auto config = srv_->GetConfig();
if (config->profiling_sample_ratio == 0) return false;
diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h
index 25b522d8..34fbcbae 100644
--- a/src/server/redis_connection.h
+++ b/src/server/redis_connection.h
@@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase<Connection> {
void PUnsubscribeChannel(const std::string &pattern);
void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
int PSubscriptionsCount();
+ void SSubscribeChannel(const std::string &channel, uint16_t slot);
+ void SUnsubscribeChannel(const std::string &channel, uint16_t slot);
+ void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr);
+ int SSubscriptionsCount();
uint64_t GetAge() const;
uint64_t GetIdleTime() const;
@@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase<Connection> {
std::vector<std::string> subscribe_channels_;
std::vector<std::string> subscribe_patterns_;
+ std::vector<std::string> subscribe_shard_channels_;
Server *srv_;
bool in_exec_ = false;
diff --git a/src/server/server.cc b/src/server/server.cc
index f8f2fb94..efe721b2 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config)
// Init cluster
cluster = std::make_unique<Cluster>(this, config_->binds, config_->port);
+ // init shard pub/sub channels
+ pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1);
+
for (int i = 0; i < config->workers; i++) {
auto worker = std::make_unique<Worker>(this, config);
// multiple workers can't listen to the same unix socket, so
@@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string
&pattern, redis::Connection *
}
}
+// Registers `conn` as a subscriber of shard channel `channel` in bucket `slot`.
+void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
+  // pubsub_shard_channels_ has HASH_SLOTS_SIZE buckets in cluster mode and a single one otherwise.
+  assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  // operator[] creates an empty subscriber list the first time a channel is seen.
+  auto &subscribers = pubsub_shard_channels_[slot][channel];
+  subscribers.emplace_back(ConnContext(conn->Owner(), conn->GetFD()));
+}
+
+// Unregisters `conn` from shard channel `channel` in bucket `slot`. The channel
+// entry is dropped once its last subscriber is gone. Unknown channels or
+// connections that are not subscribed are ignored.
+void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) {
+  // pubsub_shard_channels_ has HASH_SLOTS_SIZE buckets in cluster mode and a single one otherwise.
+  assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  auto &slot_channels = pubsub_shard_channels_[slot];
+  auto iter = slot_channels.find(channel);
+  if (iter == slot_channels.end()) {
+    return;
+  }
+
+  // Erase the matched node through its own iterator rather than calling
+  // list::remove() with a reference to an element of the list being iterated.
+  auto &subscribers = iter->second;
+  for (auto ctx_iter = subscribers.begin(); ctx_iter != subscribers.end(); ++ctx_iter) {
+    if (conn->GetFD() == ctx_iter->fd && conn->Owner() == ctx_iter->owner) {
+      subscribers.erase(ctx_iter);
+      // No subscribers left: drop the channel so it is no longer listed.
+      if (subscribers.empty()) {
+        slot_channels.erase(iter);
+      }
+      break;  // `subscribers` / `ctx_iter` may be dangling now; do not touch them again.
+    }
+  }
+}
+
+// Appends to `channels` every shard channel (across all slot buckets) whose name
+// matches the glob `pattern`; an empty pattern matches everything.
+void Server::GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels) {
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  const bool match_all = pattern.empty();
+  for (const auto &slot_channels : pubsub_shard_channels_) {
+    for (const auto &entry : slot_channels) {
+      const std::string &name = entry.first;
+      if (!match_all && !util::StringMatch(pattern, name, 0)) continue;
+      channels->emplace_back(name);
+    }
+  }
+}
+
+// For each requested shard channel, appends {channel, subscriber count} to
+// `channel_subscribe_nums`; channels with no subscribers report 0.
+void Server::ListSChannelSubscribeNum(const std::vector<std::string> &channels,
+                                      std::vector<ChannelSubscribeNum> *channel_subscribe_nums) {
+  std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);
+
+  for (const auto &chan : channels) {
+    // Same bucket selection as the subscribe path: hashed slot in cluster mode, 0 otherwise.
+    const uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0;
+    const auto &slot_channels = pubsub_shard_channels_[slot];
+    const auto found = slot_channels.find(chan);
+    const size_t count = (found == slot_channels.end()) ? 0 : found->second.size();
+    channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, count});
+  }
+}
+
void Server::BlockOnKey(const std::string &key, redis::Connection *conn) {
std::lock_guard<std::mutex> guard(blocking_keys_mu_);
diff --git a/src/server/server.h b/src/server/server.h
index 2acd0f5d..a86eedf1 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -201,6 +201,11 @@ class Server {
void PSubscribeChannel(const std::string &pattern, redis::Connection *conn);
void PUnsubscribeChannel(const std::string &pattern, redis::Connection
*conn);
size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); }
+ void SSubscribeChannel(const std::string &channel, redis::Connection *conn,
uint16_t slot);
+ void SUnsubscribeChannel(const std::string &channel, redis::Connection
*conn, uint16_t slot);
+ void GetSChannelsByPattern(const std::string &pattern,
std::vector<std::string> *channels);
+ void ListSChannelSubscribeNum(const std::vector<std::string> &channels,
+ std::vector<ChannelSubscribeNum>
*channel_subscribe_nums);
void BlockOnKey(const std::string &key, redis::Connection *conn);
void UnblockOnKey(const std::string &key, redis::Connection *conn);
@@ -351,6 +356,8 @@ class Server {
std::map<std::string, std::list<ConnContext>> pubsub_channels_;
std::map<std::string, std::list<ConnContext>> pubsub_patterns_;
std::mutex pubsub_channels_mu_;
+ std::vector<std::map<std::string, std::list<ConnContext>>>
pubsub_shard_channels_;
+ std::mutex pubsub_shard_channels_mu_;
std::map<std::string, std::list<ConnContext>> blocking_keys_;
std::mutex blocking_keys_mu_;
diff --git a/tests/gocase/unit/pubsub/pubsubshard_test.go
b/tests/gocase/unit/pubsub/pubsubshard_test.go
new file mode 100644
index 00000000..9e8b04cf
--- /dev/null
+++ b/tests/gocase/unit/pubsub/pubsubshard_test.go
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package pubsub
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "github.com/apache/kvrocks/tests/gocase/util"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPubSubShard(t *testing.T) {
+ ctx := context.Background()
+
+ srv := util.StartServer(t, map[string]string{})
+ defer srv.Close()
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+
+ csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+ defer csrv.Close()
+ crdb := csrv.NewClient()
+ defer func() { require.NoError(t, crdb.Close()) }()
+
+ nodeID := "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY"
+ require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err())
+ clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID,
csrv.Host(), csrv.Port())
+ require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes,
"1").Err())
+
+ rdbs := []*redis.Client{rdb, crdb}
+
+ t.Run("SSUBSCRIBE PING", func(t *testing.T) {
+ pubsub := rdb.SSubscribe(ctx, "somechannel")
+ receiveType(t, pubsub, &redis.Subscription{})
+ require.NoError(t, pubsub.Ping(ctx))
+ require.NoError(t, pubsub.Ping(ctx))
+ require.NoError(t, pubsub.SUnsubscribe(ctx, "somechannel"))
+ require.Equal(t, "PONG", rdb.Ping(ctx).Val())
+ receiveType(t, pubsub, &redis.Pong{})
+ receiveType(t, pubsub, &redis.Pong{})
+ })
+
+ t.Run("SSUBSCRIBE/SUNSUBSCRIBE basic", func(t *testing.T) {
+ for _, c := range rdbs {
+ pubsub := c.SSubscribe(ctx, "singlechannel")
+ defer pubsub.Close()
+
+ msg := receiveType(t, pubsub, &redis.Subscription{})
+ require.EqualValues(t, 1, msg.Count)
+ require.EqualValues(t, "singlechannel", msg.Channel)
+ require.EqualValues(t, "ssubscribe", msg.Kind)
+
+ err := pubsub.SSubscribe(ctx, "multichannel1{tag1}",
"multichannel2{tag1}", "multichannel1{tag1}")
+ require.Nil(t, err)
+ require.EqualValues(t, 2, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+
+ err = pubsub.SSubscribe(ctx, "multichannel3{tag1}",
"multichannel4{tag2}")
+ require.Nil(t, err)
+ if c == rdb {
+ require.EqualValues(t, 4, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ require.EqualValues(t, 5, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ } else {
+ // note: when cluster enabled, shard channels
in single command must belong to the same slot
+ // reference:
https://redis.io/commands/ssubscribe
+ _, err = pubsub.Receive(ctx)
+ require.EqualError(t, err, "ERR CROSSSLOT Keys
in request don't hash to the same slot")
+ }
+
+ err = pubsub.SUnsubscribe(ctx, "multichannel3{tag1}",
"multichannel4{tag2}", "multichannel5{tag2}")
+ require.Nil(t, err)
+ if c == rdb {
+ require.EqualValues(t, 4, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ } else {
+ require.EqualValues(t, 3, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ require.EqualValues(t, 3, receiveType(t,
pubsub, &redis.Subscription{}).Count)
+ }
+
+ err = pubsub.SUnsubscribe(ctx)
+ require.Nil(t, err)
+ msg = receiveType(t, pubsub, &redis.Subscription{})
+ require.EqualValues(t, 2, msg.Count)
+ require.EqualValues(t, "sunsubscribe", msg.Kind)
+ require.EqualValues(t, 1, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ require.EqualValues(t, 0, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ }
+ })
+
+ t.Run("SSUBSCRIBE/SUNSUBSCRIBE with empty channel", func(t *testing.T) {
+ for _, c := range rdbs {
+ pubsub := c.SSubscribe(ctx)
+ defer pubsub.Close()
+
+ err := pubsub.SUnsubscribe(ctx, "foo", "bar")
+ require.Nil(t, err)
+ require.EqualValues(t, 0, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ require.EqualValues(t, 0, receiveType(t, pubsub,
&redis.Subscription{}).Count)
+ }
+ })
+
+ t.Run("SHARDNUMSUB returns numbers, not strings", func(t *testing.T) {
+ require.EqualValues(t, map[string]int64{
+ "abc": 0,
+ "def": 0,
+ }, rdb.PubSubShardNumSub(ctx, "abc", "def").Val())
+ })
+
+ t.Run("PUBSUB SHARDNUMSUB/SHARDCHANNELS", func(t *testing.T) {
+ for _, c := range rdbs {
+ pubsub := c.SSubscribe(ctx, "singlechannel")
+ defer pubsub.Close()
+ receiveType(t, pubsub, &redis.Subscription{})
+
+ err := pubsub.SSubscribe(ctx, "multichannel1{tag1}",
"multichannel2{tag1}", "multichannel3{tag1}")
+ require.Nil(t, err)
+ receiveType(t, pubsub, &redis.Subscription{})
+ receiveType(t, pubsub, &redis.Subscription{})
+ receiveType(t, pubsub, &redis.Subscription{})
+
+ pubsub1 := c.SSubscribe(ctx, "multichannel1{tag1}")
+ defer pubsub1.Close()
+
+ sc := c.PubSubShardChannels(ctx, "")
+ require.EqualValues(t, len(sc.Val()), 4)
+ sc = c.PubSubShardChannels(ctx, "multi*")
+ require.EqualValues(t, len(sc.Val()), 3)
+
+ sn := c.PubSubShardNumSub(ctx)
+ require.EqualValues(t, len(sn.Val()), 0)
+ sn = c.PubSubShardNumSub(ctx, "singlechannel",
"multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}")
+ for i, k := range sn.Val() {
+ if i == "multichannel1{tag1}" {
+ require.EqualValues(t, k, 2)
+ } else {
+ require.EqualValues(t, k, 1)
+ }
+ }
+ }
+ })
+}