This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 69b054ed Add support of the command ZDIFF and ZDIFFSTORE (#2021)
69b054ed is described below

commit 69b054edcfba00f82f8b00634383bdb02e75d69e
Author: HashTagInclude <[email protected]>
AuthorDate: Wed Jan 17 15:47:02 2024 +0530

    Add support of the command ZDIFF and ZDIFFSTORE (#2021)
---
 src/commands/cmd_zset.cc                 |  97 ++++++++++++++++++-
 src/types/redis_zset.cc                  |  38 ++++++++
 src/types/redis_zset.h                   |   2 +
 tests/cppunit/types/zset_test.cc         |  78 +++++++++++++++
 tests/gocase/unit/type/zset/zset_test.go | 161 +++++++++++++++++++++++++++++++
 5 files changed, 375 insertions(+), 1 deletion(-)

diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index acddad82..05cdde2e 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -1419,6 +1419,99 @@ class CommandZRandMember : public Commander {
   bool no_parameters_ = true;
 };
 
+/// ZDIFF numkeys key [key ...] [WITHSCORES]
+///
+/// Replies with the members of the first sorted set that are absent from every
+/// following sorted set, optionally interleaved with their scores.
+class CommandZDiff : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    auto parse_result = ParseInt<int>(args[1], 10);
+    if (!parse_result) return {Status::RedisParseErr, errValueNotInteger};
+
+    // ZSet::Diff reads keys[0], so a zero or negative key count must be rejected
+    // here instead of producing an empty key list.
+    if (*parse_result <= 0) {
+      return {Status::RedisParseErr, "at least 1 input key is needed for 'zdiff' command"};
+    }
+    numkeys_ = static_cast<size_t>(*parse_result);
+    if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax};
+
+    keys_.reserve(numkeys_);
+    for (size_t i = 0; i < numkeys_; i++) {
+      keys_.emplace_back(args[i + 2]);
+    }
+
+    // At most one token may follow the keys, and it must be WITHSCORES; any other
+    // option or extra argument is a syntax error rather than being ignored.
+    const size_t option_index = 2 + numkeys_;
+    if (option_index < args.size()) {
+      if (option_index + 1 != args.size() || util::ToLower(args[option_index]) != "withscores") {
+        return {Status::RedisParseErr, errInvalidSyntax};
+      }
+      with_scores_ = true;
+    }
+
+    return Commander::Parse(args);
+  }
+
+  /// Replies with an array of members (member, score pairs with WITHSCORES).
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+
+    std::vector<MemberScore> members_with_scores;
+    auto s = zset_db.Diff(keys_, &members_with_scores);
+    if (!s.ok()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+
+    output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 2 : 1)));
+    for (const auto &ms : members_with_scores) {
+      output->append(redis::BulkString(ms.member));
+      if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score)));
+    }
+
+    return Status::OK();
+  }
+
+  /// Keys occupy args[2] .. args[1 + numkeys] (inclusive). The optional
+  /// WITHSCORES token that follows them must not be reported as a key.
+  static CommandKeyRange Range(const std::vector<std::string> &args) {
+    int num_key = *ParseInt<int>(args[1], 10);
+    return {2, 1 + num_key, 1};
+  }
+
+ protected:
+  size_t numkeys_ = 0;
+  std::vector<rocksdb::Slice> keys_;
+  bool with_scores_ = false;
+};
+
+/// ZDIFFSTORE destination numkeys key [key ...]
+///
+/// Stores the members of the first sorted set that are absent from every
+/// following sorted set into `destination` and replies with their count.
+class CommandZDiffStore : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    auto parse_result = ParseInt<int>(args[2], 10);
+    if (!parse_result) return {Status::RedisParseErr, errValueNotInteger};
+
+    // ZSet::Diff reads keys[0], so a zero or negative key count must be rejected
+    // here instead of producing an empty key list.
+    if (*parse_result <= 0) {
+      return {Status::RedisParseErr, "at least 1 input key is needed for 'zdiffstore' command"};
+    }
+    numkeys_ = static_cast<size_t>(*parse_result);
+    // ZDIFFSTORE takes no options: the source keys must be exactly the remaining
+    // arguments, so both too few and too many are syntax errors.
+    if (numkeys_ != args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax};
+
+    keys_.reserve(numkeys_);
+    for (size_t i = 0; i < numkeys_; i++) {
+      keys_.emplace_back(args[i + 3]);
+    }
+
+    return Commander::Parse(args);
+  }
+
+  /// args_[1] is the destination key; replies with the number of stored members.
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+
+    uint64_t stored_count = 0;
+    auto s = zset_db.DiffStore(args_[1], keys_, &stored_count);
+    if (!s.ok()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+    *output = redis::Integer(stored_count);
+    return Status::OK();
+  }
+
+  /// numkeys is args[2]; args[1] is the destination key, not a number. Source
+  /// keys occupy args[3] .. args[2 + numkeys] (inclusive).
+  static CommandKeyRange Range(const std::vector<std::string> &args) {
+    int num_key = *ParseInt<int>(args[2], 10);
+    return {3, 2 + num_key, 1};
+  }
+
+ protected:
+  size_t numkeys_ = 0;
+  std::vector<rocksdb::Slice> keys_;
+};
+
 REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandZCard>("zcard", 2, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZCount>("zcount", 4, "read-only", 1, 1, 1),
@@ -1451,6 +1544,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandZScan>("zscan", -3, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZUnionStore>("zunionstore", -4, "write", CommandZUnionStore::Range),
                         MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range),
-                        MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1))
+                        MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1),
+                        MakeCmdAttr<CommandZDiff>("zdiff", -3, "read-only", CommandZDiff::Range),
+                        MakeCmdAttr<CommandZDiffStore>("zdiffstore", -4, "write", CommandZDiffStore::Range))  // writes destination; needs dst + numkeys + >=1 key
 
 }  // namespace redis
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 9215d621..7532a5f3 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -931,4 +931,42 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, 
int64_t command_count,
   return rocksdb::Status::OK();
 }
 
+/// Computes the members of the sorted set at keys[0] that appear in none of the
+/// other sets in `keys`, storing them (with their keys[0] scores) in `members`
+/// in the order RangeByScore returned them for keys[0].
+rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores *members) {
+  members->clear();
+  // keys[0] is the source set; with no keys there is no source, so the result is
+  // empty (and indexing keys[0] would be out of bounds).
+  if (keys.empty()) return rocksdb::Status::OK();
+
+  // Default-constructed spec, used to read every member of each set.
+  // NOTE(review): relies on RangeScoreSpec defaulting to the full score range.
+  RangeScoreSpec spec;
+  MemberScores source_member_scores;
+  uint64_t source_size = 0;
+  auto s = RangeByScore(keys[0], spec, &source_member_scores, &source_size);
+  if (!s.ok()) return s;
+  if (source_size == 0) return rocksdb::Status::OK();
+
+  // Collect every member that occurs in any of the other sets.
+  std::map<std::string, bool> exclude_members;
+  MemberScores target_member_scores;
+  for (size_t i = 1; i < keys.size(); i++) {
+    // Reset the reused buffer so each iteration only sees keys[i]'s members.
+    target_member_scores.clear();
+    uint64_t target_size = 0;
+    s = RangeByScore(keys[i], spec, &target_member_scores, &target_size);
+    if (!s.ok()) return s;
+    for (const auto &member_score : target_member_scores) {
+      exclude_members[member_score.member] = true;
+    }
+  }
+
+  // Keep the source members that no other set contains.
+  for (const auto &member_score : source_member_scores) {
+    if (exclude_members.find(member_score.member) == exclude_members.end()) {
+      members->push_back(member_score);
+    }
+  }
+  return rocksdb::Status::OK();
+}
+
+/// Computes Diff(keys), reports its size through `stored_count`, and hands the
+/// result to Overwrite for `dst`. Nothing is written if the diff itself fails.
+rocksdb::Status ZSet::DiffStore(const Slice &dst, const std::vector<Slice> &keys, uint64_t *stored_count) {
+  MemberScores diff_result;
+  if (auto status = Diff(keys, &diff_result); !status.ok()) {
+    return status;
+  }
+
+  *stored_count = diff_result.size();
+  return Overwrite(dst, diff_result);
+}
+
 }  // namespace redis
diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h
index 397ca10b..d806d57e 100644
--- a/src/types/redis_zset.h
+++ b/src/types/redis_zset.h
@@ -116,6 +116,8 @@ class ZSet : public SubKeyScanner {
                              AggregateMethod aggregate_method, uint64_t 
*saved_cnt);
   rocksdb::Status Union(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
                         std::vector<MemberScore> *members);
+  /// Computes the members of the sorted set at keys[0] that appear in none of
+  /// the other sets in `keys`, storing them with their scores in `members`.
+  rocksdb::Status Diff(const std::vector<Slice> &keys, MemberScores *members);
+  /// Computes Diff(keys), writes that result to `dst` via Overwrite, and sets
+  /// `*stored_count` to the number of members in the result.
+  rocksdb::Status DiffStore(const Slice &dst, const std::vector<Slice> &keys, 
uint64_t *stored_count);
   rocksdb::Status MGet(const Slice &user_key, const std::vector<Slice> 
&members, std::map<std::string, double> *scores);
   rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata);
 
diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc
index 34c71d78..da2ce714 100644
--- a/tests/cppunit/types/zset_test.cc
+++ b/tests/cppunit/types/zset_test.cc
@@ -535,3 +535,81 @@ TEST_F(RedisZSetTest, RandMember) {
   auto s = zset_->Del(key_);
   EXPECT_TRUE(s.ok());
 }
+
+// Diff keeps only the members of key1 that appear in neither key2 nor key3,
+// with their key1 scores.
+TEST_F(RedisZSetTest, Diff) {
+  uint64_t ret = 0;
+
+  std::string k1 = "key1";
+  std::vector<MemberScore> k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}};
+
+  std::string k2 = "key2";
+  std::vector<MemberScore> k2_mscores = {{"c", -150.1}};
+
+  std::string k3 = "key3";
+  std::vector<MemberScore> k3_mscores = {{"a", -1000.1}, {"c", -100.1}, {"e", 8000.9}};
+
+  auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(ret, 4);
+  s = zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(ret, 1);
+  s = zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(ret, 3);
+
+  std::vector<MemberScore> mscores;
+  s = zset_->Diff({k1, k2, k3}, &mscores);
+  EXPECT_TRUE(s.ok());
+
+  // ASSERT (fatal) so a size mismatch stops the test before the indexed
+  // comparisons below can read past the end of mscores.
+  std::vector<MemberScore> expected_mscores = {{"b", -100.1}, {"d", 1.234}};
+  ASSERT_EQ(expected_mscores.size(), mscores.size());
+  for (size_t i = 0; i < expected_mscores.size(); i++) {
+    EXPECT_EQ(expected_mscores[i].member, mscores[i].member);
+    EXPECT_EQ(expected_mscores[i].score, mscores[i].score);
+  }
+
+  s = zset_->Del(k1);
+  EXPECT_TRUE(s.ok());
+  s = zset_->Del(k2);
+  EXPECT_TRUE(s.ok());
+  s = zset_->Del(k3);
+  EXPECT_TRUE(s.ok());
+}
+
+// DiffStore writes key1 \ key2 into "zsetdiff" and reports how many members it
+// stored; the stored set is then read back in full.
+TEST_F(RedisZSetTest, DiffStore) {
+  uint64_t ret = 0;
+
+  std::string k1 = "key1";
+  std::vector<MemberScore> k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}};
+
+  std::string k2 = "key2";
+  std::vector<MemberScore> k2_mscores = {{"c", -150.1}};
+
+  auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(ret, 4);
+  s = zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(ret, 1);
+
+  uint64_t stored_count = 0;
+  s = zset_->DiffStore("zsetdiff", {k1, k2}, &stored_count);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(stored_count, 3);
+
+  RangeScoreSpec spec;
+  std::vector<MemberScore> mscores;
+  s = zset_->RangeByScore("zsetdiff", spec, &mscores, nullptr);
+  EXPECT_TRUE(s.ok());
+
+  // ASSERT (fatal) so a size mismatch stops the test before the indexed
+  // comparisons below can read past the end of mscores.
+  std::vector<MemberScore> expected_mscores = {{"a", -100.1}, {"b", -100.1}, {"d", 1.234}};
+  ASSERT_EQ(expected_mscores.size(), mscores.size());
+  for (size_t i = 0; i < expected_mscores.size(); i++) {
+    EXPECT_EQ(expected_mscores[i].member, mscores[i].member);
+    EXPECT_EQ(expected_mscores[i].score, mscores[i].score);
+  }
+
+  s = zset_->Del(k1);
+  EXPECT_TRUE(s.ok());
+  s = zset_->Del(k2);
+  EXPECT_TRUE(s.ok());
+  s = zset_->Del("zsetdiff");
+  EXPECT_TRUE(s.ok());
+}
diff --git a/tests/gocase/unit/type/zset/zset_test.go 
b/tests/gocase/unit/type/zset/zset_test.go
index 86adceda..d7bc434e 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -1463,6 +1463,167 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding s
                        ).Err(), ".*weight.*not.*double.*")
                })
        }
+
+       // ZDIFF subtests. They share the keys zseta/zsetb/zsetc; each subtest
+       // (re)creates or deletes every key it reads before issuing the command.
+       t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 4, Member: "e"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+               })
+               reply := rdb.ZDiff(ctx, "zseta", "zsetb")
+               require.NoError(t, reply.Err())
+               // Sort so the assertion does not depend on reply order.
+               sort.Strings(reply.Val())
+               require.EqualValues(t, []string{"a", "d", "e"}, reply.Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 4, Member: "e"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+               })
+               createZset(rdb, ctx, "zsetc", []redis.Z{
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 5, Member: "e"},
+               })
+               reply := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc")
+               require.NoError(t, reply.Err())
+               sort.Strings(reply.Val())
+               require.EqualValues(t, []string{"a"}, reply.Val())
+       })
+
+       // The WITHSCORES variants compare the reply without sorting, so they also
+       // pin the order in which members are returned.
+       t.Run(fmt.Sprintf("ZDIFF with three sets with scores - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 4, Member: "e"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+               })
+               createZset(rdb, ctx, "zsetc", []redis.Z{
+                       {Score: 4, Member: "c"},
+                       {Score: 5, Member: "e"},
+               })
+               reply := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc")
+               require.NoError(t, reply.Err())
+               require.EqualValues(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}, reply.Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFF with empty sets - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{})
+               createZset(rdb, ctx, "zsetb", []redis.Z{})
+               reply := rdb.ZDiff(ctx, "zseta", "zsetb")
+               require.NoError(t, reply.Err())
+               require.EqualValues(t, []string{}, reply.Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFF with non existing sets - %s", encoding), func(t *testing.T) {
+               rdb.Del(ctx, "zseta")
+               rdb.Del(ctx, "zsetb")
+               reply := rdb.ZDiff(ctx, "zseta", "zsetb")
+               require.NoError(t, reply.Err())
+               require.EqualValues(t, []string{}, reply.Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFF with missing set with scores - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+               })
+               rdb.Del(ctx, "zsetc")
+               reply := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc")
+               require.NoError(t, reply.Err())
+               require.EqualValues(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}, reply.Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFF with empty sets with scores - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{})
+               createZset(rdb, ctx, "zsetb", []redis.Z{})
+               reply := rdb.ZDiffWithScores(ctx, "zseta", "zsetb")
+               require.NoError(t, reply.Err())
+               require.EqualValues(t, []redis.Z{}, reply.Val())
+       })
+
+       // ZDIFFSTORE subtests. Names interpolate the encoding like the ZDIFF
+       // subtests do, and each subtest has a distinct name.
+       t.Run(fmt.Sprintf("ZDIFFSTORE with three sets - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 4, Member: "e"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+               })
+               createZset(rdb, ctx, "zsetc", []redis.Z{
+                       {Score: 4, Member: "c"},
+                       {Score: 5, Member: "e"},
+               })
+               cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc")
+               require.NoError(t, cmd.Err())
+               require.EqualValues(t, int64(2), cmd.Val())
+               require.Equal(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}, rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val())
+       })
+
+       t.Run(fmt.Sprintf("ZDIFFSTORE with missing sets - %s", encoding), func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+                       {Score: 3, Member: "d"},
+                       {Score: 4, Member: "e"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 4, Member: "f"},
+                       {Score: 4, Member: "e"},
+               })
+               rdb.Del(ctx, "zsetc")
+               cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc")
+               require.NoError(t, cmd.Err())
+               require.EqualValues(t, int64(2), cmd.Val())
+               require.Equal(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}, rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val())
+       })
+
+       // When run after the previous subtest, "setres" still holds its two
+       // members, so this also checks that an empty diff replaces them.
+       t.Run(fmt.Sprintf("ZDIFFSTORE with non existing sets - %s", encoding), func(t *testing.T) {
+               rdb.Del(ctx, "zseta")
+               rdb.Del(ctx, "zsetb")
+               rdb.Del(ctx, "zsetc")
+               cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc")
+               require.NoError(t, cmd.Err())
+               require.EqualValues(t, int64(0), cmd.Val())
+               require.Equal(t, []redis.Z{}, rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val())
+       })
 }
 
 func stressTests(t *testing.T, rdb *redis.Client, ctx context.Context, 
encoding string) {

Reply via email to