This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new dad7494f Add support of new command: ZRANDMEMBER  (#2016)
dad7494f is described below

commit dad7494f0028e51bd3265459196dc4049b9cc615
Author: jxlikar <[email protected]>
AuthorDate: Mon Jan 15 15:57:09 2024 +0800

    Add support of new command: ZRANDMEMBER  (#2016)
---
 src/commands/cmd_zset.cc                 |  65 ++++++++++++++-
 src/types/redis_zset.cc                  |  79 ++++++++++++++++++
 src/types/redis_zset.h                   |   2 +
 tests/cppunit/types/zset_test.cc         | 102 ++++++++++++++++++++++++
 tests/gocase/unit/type/zset/zset_test.go | 132 +++++++++++++++++++++++++++++++
 5 files changed, 379 insertions(+), 1 deletion(-)

diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index 1fa51ab2..f1160274 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -25,8 +25,11 @@
 #include "commands/blocking_commander.h"
 #include "commands/scan_base.h"
 #include "error_constants.h"
+#include "parse_util.h"
+#include "rocksdb/env.h"
 #include "server/redis_reply.h"
 #include "server/server.h"
+#include "string_util.h"
 #include "types/redis_zset.h"
 
 namespace redis {
@@ -1357,6 +1360,65 @@ class CommandZScan : public CommandSubkeyScanBase {
   }
 };
 
+// ZRANDMEMBER key [count [WITHSCORES]]
+//
+// Without `count` the reply is one random member as a bulk string (nil if the key is missing).
+// With `count` the reply is an array; a positive count asks for distinct members, a negative count
+// allows repeats (see ZSet::RandMember). WITHSCORES interleaves every member with its score.
+class CommandZRandMember : public Commander {
+ public:
+  CommandZRandMember() = default;
+
+  Status Parse(const std::vector<std::string> &args) override {
+    if (args.size() > 4) {
+      return {Status::RedisParseErr, errWrongNumOfArguments};
+    }
+
+    if (args.size() >= 3) {
+      no_parameters_ = false;
+      auto parse_result = ParseInt<int64_t>(args[2], 10);
+      if (!parse_result) {
+        return {Status::RedisParseErr, errValueNotInteger};
+      }
+      count_ = *parse_result;
+    }
+
+    // WITHSCORES is only accepted after an explicit count.
+    if (args.size() == 4) {
+      if (util::ToLower(args[3]) == "withscores") {
+        with_scores_ = true;
+      } else {
+        return {Status::RedisParseErr, errInvalidSyntax};
+      }
+    }
+
+    return Commander::Parse(args);
+  }
+
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+    std::vector<MemberScore> member_scores;
+    auto s = zset_db.RandMember(args_[1], count_, &member_scores);
+
+    // A missing key is not an error: it simply produces an empty (or nil) reply.
+    if (!s.ok() && !s.IsNotFound()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+
+    if (no_parameters_) {
+      // Decide on the members themselves rather than on `s` alone, so an OK status that comes
+      // with an empty result can never index past the end of `member_scores`.
+      *output = member_scores.empty() ? redis::NilString() : redis::BulkString(member_scores[0].member);
+      return Status::OK();
+    }
+
+    std::vector<std::string> result_entries;
+    result_entries.reserve(with_scores_ ? member_scores.size() * 2 : member_scores.size());
+    for (const auto &[member, score] : member_scores) {
+      result_entries.emplace_back(member);
+      if (with_scores_) result_entries.emplace_back(util::Float2String(score));
+    }
+
+    *output = redis::MultiBulkString(result_entries, false);
+    return Status::OK();
+  }
+
+ private:
+  int64_t count_ = 1;          // a single member when no count is given
+  bool with_scores_ = false;
+  bool no_parameters_ = true;  // true when the command was called without `count`
+};
+
 REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandZCard>("zcard", 2, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZCount>("zcount", 4, "read-only", 1, 1, 1),
@@ -1388,6 +1450,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandZMScore>("zmscore", -3, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZScan>("zscan", -3, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZUnionStore>("zunionstore", -4, "write", CommandZUnionStore::Range),
-                        MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range), )
+                        MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range),
+                        MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1))
 
 }  // namespace redis
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index f2976544..57c154ac 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -25,6 +25,7 @@
 #include <map>
 #include <memory>
 #include <optional>
+#include <random>
 #include <set>
 
 #include "db_util.h"
@@ -851,4 +852,82 @@ rocksdb::Status ZSet::MGet(const Slice &user_key, const std::vector<Slice> &memb
   return rocksdb::Status::OK();
 }
 
+// Collects every (member, score) pair of the zset stored at `user_key` into `member_scores`,
+// which is cleared first. If the metadata lookup reports NotFound, this returns OK with an empty
+// result; any other lookup error is returned unchanged.
+rocksdb::Status ZSet::GetAllMemberScores(const Slice &user_key, std::vector<MemberScore> *member_scores) {
+  member_scores->clear();
+  std::string ns_key = AppendNamespacePrefix(user_key);
+  ZSetMetadata metadata(false);
+  rocksdb::Status s = GetMetadata(ns_key, &metadata);
+  if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
+
+  // metadata.size is expected to hold the member count; it is only a capacity hint here, the
+  // scan below is what determines the result.
+  member_scores->reserve(metadata.size);
+
+  // Subkeys of this version all start with `prefix_key`; the prefix of version + 1 serves as the
+  // exclusive upper bound of the scan.
+  std::string prefix_key = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode();
+  std::string next_version_prefix_key =
+      InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode();
+
+  rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
+  LatestSnapShot ss(storage_);
+  read_options.snapshot = ss.GetSnapShot();
+
+  rocksdb::Slice upper_bound(next_version_prefix_key);
+  rocksdb::Slice lower_bound(prefix_key);
+  read_options.iterate_upper_bound = &upper_bound;
+  read_options.iterate_lower_bound = &lower_bound;
+
+  auto iter = util::UniqueIterator(storage_, read_options, score_cf_handle_);
+
+  for (iter->Seek(prefix_key); iter->Valid() && iter->key().starts_with(prefix_key); iter->Next()) {
+    InternalKey ikey(iter->key(), storage_->IsSlotIdEncoded());
+    // A score-CF subkey is the encoded score followed by the member: GetDouble consumes the score
+    // and leaves the member in `score_key`.
+    Slice score_key = ikey.GetSubKey();
+    double score = NAN;
+    GetDouble(&score_key, &score);
+    member_scores->emplace_back(MemberScore{score_key.ToString(), score});
+  }
+
+  return rocksdb::Status::OK();
+}
+
+// Samples random members (with their scores) from the zset stored at `user_key`, ZRANDMEMBER-style:
+//   - command_count == 0: returns OK with no members, without looking the key up;
+//   - command_count > 0:  min(command_count, number of members) distinct members;
+//   - command_count < 0:  exactly -command_count members drawn with replacement (repeats allowed).
+// Results are appended to `member_scores`. Returns NotFound when the key does not exist.
+rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count,
+                                 std::vector<MemberScore> *member_scores) {
+  if (command_count == 0) {
+    return rocksdb::Status::OK();
+  }
+
+  // Take the magnitude in unsigned arithmetic: `-command_count` is signed overflow (UB) for INT64_MIN.
+  const uint64_t count = command_count > 0 ? static_cast<uint64_t>(command_count)
+                                           : uint64_t{0} - static_cast<uint64_t>(command_count);
+  const bool unique = command_count > 0;
+
+  std::string ns_key = AppendNamespacePrefix(user_key);
+  ZSetMetadata metadata(false);
+  rocksdb::Status s = GetMetadata(ns_key, &metadata);
+  if (!s.ok() || metadata.size == 0) return s;
+
+  std::vector<MemberScore> samples;
+  s = GetAllMemberScores(user_key, &samples);
+  if (!s.ok()) return s;
+  // GetAllMemberScores reports a missing key as OK with no members; if the zset disappeared after
+  // the metadata read above, report NotFound just as for a key that was absent from the start.
+  if (samples.empty()) return rocksdb::Status::NotFound();
+
+  const auto size = static_cast<uint64_t>(samples.size());
+  // Seed one engine and reuse it for every draw, instead of pulling each random number from
+  // std::random_device, which may be backed by a comparatively slow OS entropy source.
+  std::mt19937 gen(std::random_device{}());
+
+  if (!unique || count == 1) {
+    // With replacement (or a single pick): every draw is an independent uniform index.
+    member_scores->reserve(count);
+    std::uniform_int_distribution<uint64_t> dist(0, size - 1);
+    for (uint64_t i = 0; i < count; i++) {
+      member_scores->emplace_back(samples[dist(gen)]);
+    }
+  } else if (size <= count) {
+    // At least as many distinct members were requested as exist: return all of them.
+    member_scores->insert(member_scores->end(), samples.begin(), samples.end());
+  } else {
+    // Shuffle, then keep the first `count` entries: a uniformly random subset in random order.
+    std::shuffle(samples.begin(), samples.end(), gen);
+    member_scores->reserve(count);
+    for (uint64_t i = 0; i < count; i++) {
+      member_scores->emplace_back(samples[i]);
+    }
+  }
+
+  return rocksdb::Status::OK();
+}
+
 }  // namespace redis
diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h
index 3cd81622..397ca10b 100644
--- a/src/types/redis_zset.h
+++ b/src/types/redis_zset.h
@@ -126,6 +126,8 @@ class ZSet : public SubKeyScanner {
                                uint64_t *removed_cnt);
   rocksdb::Status RangeByLex(const Slice &user_key, const RangeLexSpec &spec, MemberScores *mscores,
                              uint64_t *removed_cnt);
+  // Replaces the contents of `member_scores` with every (member, score) pair of the zset;
+  // a missing key yields OK and an empty vector.
+  rocksdb::Status GetAllMemberScores(const Slice &user_key, std::vector<MemberScore> *member_scores);
+  // ZRANDMEMBER sampling: command_count > 0 picks up to command_count distinct members, < 0 picks
+  // -command_count members with repeats allowed, 0 picks none. Returns NotFound for a missing key
+  // (except when command_count == 0, which returns OK without a lookup).
+  rocksdb::Status RandMember(const Slice &user_key, int64_t command_count, std::vector<MemberScore> *member_scores);
 
  private:
   rocksdb::ColumnFamilyHandle *score_cf_handle_;
diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc
index 230aa400..34c71d78 100644
--- a/tests/cppunit/types/zset_test.cc
+++ b/tests/cppunit/types/zset_test.cc
@@ -433,3 +433,105 @@ TEST_F(RedisZSetTest, Rank) {
   }
   auto s = zset_->Del(key_);
 }
+
+TEST_F(RedisZSetTest, RandMember) {
+  uint64_t ret = 0;
+  {
+    std::vector<MemberScore> in_mscores;
+    in_mscores.reserve(fields_.size());
+    for (size_t i = 0; i < fields_.size(); i++) {
+      in_mscores.emplace_back(MemberScore{fields_[i].ToString(), scores_[i]});
+    }
+    rocksdb::Status s = zset_->Add(key_, ZAddFlags::Default(), &in_mscores, &ret);
+    ASSERT_TRUE(s.ok());
+    EXPECT_EQ(fields_.size(), ret);
+  }
+
+  std::unordered_map<std::string, double> member_map;
+  for (size_t i = 0; i < fields_.size(); i++) {
+    member_map[fields_[i].ToString()] = scores_[i];
+  }
+
+  // true if no member appears twice in `mscores`
+  auto no_duplicate_members = [](const std::vector<MemberScore> &mscores) {
+    std::unordered_set<std::string> member_set;
+    for (const auto &mscore : mscores) {
+      if (!member_set.insert(mscore.member).second) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  // true if every returned member exists in the zset and carries its stored score
+  auto all_members_exist = [&member_map](const std::vector<MemberScore> &mscores) {
+    for (const auto &mscore : mscores) {
+      const auto it = member_map.find(mscore.member);
+      if (it == member_map.end() || it->second != mscore.score) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  // count = 0
+  {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember(key_, 0, &mscores);
+    EXPECT_TRUE(s.ok());
+    EXPECT_TRUE(mscores.empty());
+  }
+
+  // count = 1/-1
+  for (int64_t count : {1, -1}) {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
+    EXPECT_TRUE(s.ok());
+    // ASSERT rather than EXPECT: mscores[0] is read below
+    ASSERT_EQ(1U, mscores.size());
+    EXPECT_NE(member_map.find(mscores[0].member), member_map.end());
+    EXPECT_TRUE(all_members_exist(mscores));
+  }
+
+  // 1 < count <= fields_.size(): exactly `count` distinct members
+  for (int64_t count : {
+           static_cast<int64_t>(fields_.size()),
+           static_cast<int64_t>(fields_.size() / 2),
+       }) {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
+    EXPECT_TRUE(s.ok());
+    EXPECT_EQ(static_cast<size_t>(count), mscores.size());
+    ASSERT_TRUE(all_members_exist(mscores));
+    ASSERT_TRUE(no_duplicate_members(mscores));
+  }
+
+  // -fields_.size() <= count < -1: exactly `-count` members, repeats allowed
+  for (int64_t count : {
+           -static_cast<int64_t>(fields_.size()),
+           -static_cast<int64_t>(fields_.size() / 2),
+       }) {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
+    EXPECT_TRUE(s.ok());
+    EXPECT_EQ(static_cast<size_t>(-count), mscores.size());
+    ASSERT_TRUE(all_members_exist(mscores));
+  }
+
+  // count < -fields_.size() or count > fields_.size()
+  for (int64_t count : {
+           static_cast<int64_t>(fields_.size() + 10),
+           -static_cast<int64_t>(fields_.size() + 10),
+       }) {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
+    EXPECT_TRUE(s.ok());
+    ASSERT_TRUE(all_members_exist(mscores));
+    if (count > 0) {
+      // a positive count is capped at the zset size, without repeats
+      EXPECT_EQ(fields_.size(), mscores.size());
+      ASSERT_TRUE(no_duplicate_members(mscores));
+    } else {
+      EXPECT_EQ(static_cast<size_t>(-count), mscores.size());
+    }
+  }
+
+  // a key that does not exist
+  {
+    std::vector<MemberScore> mscores;
+    rocksdb::Status s = zset_->RandMember("zset_randmember_no_such_key", 5, &mscores);
+    EXPECT_TRUE(s.IsNotFound());
+    EXPECT_TRUE(mscores.empty());
+  }
+
+  auto s = zset_->Del(key_);
+  EXPECT_TRUE(s.ok());
+}
diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go
index 860316b2..86adceda 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -1288,6 +1288,138 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s
 
        })
 
+	t.Run(fmt.Sprintf("ZRANDMEMBER without scores - %s", encoding), func(t *testing.T) {
+		// create a zset with 6 elements; members are listed in sorted order so they can be
+		// compared against a sorted reply
+		members := []string{"a", "b", "c", "d", "e", "f"}
+		scores := []float64{1, 2, 3, 4, 5, 6}
+
+		z := make([]redis.Z, len(members))
+		for i := range members {
+			z[i] = redis.Z{Score: scores[i], Member: members[i]}
+		}
+		createZset(rdb, ctx, "zset", z)
+
+		// ZRANDMEMBER zset (no count): a single member as a bulk string
+		member, err := rdb.Do(ctx, "ZRANDMEMBER", "zset").Text()
+		require.NoError(t, err)
+		require.Contains(t, members, member)
+
+		// ZRANDMEMBER zset len(members): every member exactly once
+		res := rdb.ZRandMember(ctx, "zset", len(members)).Val()
+		sort.Strings(res)
+		require.Equal(t, members, res)
+
+		// ZRANDMEMBER zset len(members)+10: still every member exactly once
+		res = rdb.ZRandMember(ctx, "zset", len(members)+10).Val()
+		sort.Strings(res)
+		require.Equal(t, members, res)
+
+		// ZRANDMEMBER zset -len(members): exactly len(members) entries, repeats allowed
+		res = rdb.ZRandMember(ctx, "zset", -len(members)).Val()
+		require.Len(t, res, len(members))
+		for _, v := range res {
+			require.Contains(t, members, v)
+		}
+
+		// ZRANDMEMBER zset -len(members)-10
+		res = rdb.ZRandMember(ctx, "zset", -len(members)-10).Val()
+		require.Len(t, res, len(members)+10)
+		for _, v := range res {
+			require.Contains(t, members, v)
+		}
+
+		// ZRANDMEMBER zset 0
+		require.Equal(t, []string{}, rdb.ZRandMember(ctx, "zset", 0).Val())
+
+		// ZRANDMEMBER zset 1
+		res = rdb.ZRandMember(ctx, "zset", 1).Val()
+		require.Len(t, res, 1)
+		require.Contains(t, members, res[0])
+
+		// ZRANDMEMBER zset 3: distinct members
+		res = rdb.ZRandMember(ctx, "zset", 3).Val()
+		require.Len(t, res, 3)
+		memberMap := make(map[string]struct{})
+		for _, v := range res {
+			require.Contains(t, members, v)
+			memberMap[v] = struct{}{}
+		}
+		require.Len(t, memberMap, len(res))
+
+		// ZRANDMEMBER zset -3
+		res = rdb.ZRandMember(ctx, "zset", -3).Val()
+		require.Len(t, res, 3)
+		for _, v := range res {
+			require.Contains(t, members, v)
+		}
+
+		// missing key: nil without a count, an empty array with one
+		require.NoError(t, rdb.Del(ctx, "zrandmember-missing").Err())
+		require.ErrorIs(t, rdb.Do(ctx, "ZRANDMEMBER", "zrandmember-missing").Err(), redis.Nil)
+		require.Empty(t, rdb.ZRandMember(ctx, "zrandmember-missing", 5).Val())
+	})
+
+	t.Run(fmt.Sprintf("ZRANDMEMBER with scores - %s", encoding), func(t *testing.T) {
+		// create a zset with 6 elements; members are listed in sorted order so `z` can be
+		// compared against a reply sorted by member
+		members := []string{"a", "b", "c", "d", "e", "f"}
+		scores := []float64{1, 2, 3, 4, 5, 6}
+
+		z := make([]redis.Z, len(members))
+		for i := range members {
+			z[i] = redis.Z{Score: scores[i], Member: members[i]}
+		}
+		createZset(rdb, ctx, "zset", z)
+
+		// redis.Z.Member is an interface{}, so compare it as the string it holds
+		sortByMember := func(res []redis.Z) {
+			sort.Slice(res, func(i, j int) bool {
+				return res[i].Member.(string) < res[j].Member.(string)
+			})
+		}
+
+		// ZRANDMEMBER zset len(members) WITHSCORES: every member exactly once
+		res := rdb.ZRandMemberWithScores(ctx, "zset", len(members)).Val()
+		sortByMember(res)
+		require.Equal(t, z, res)
+
+		// ZRANDMEMBER zset len(members)+10 WITHSCORES
+		res = rdb.ZRandMemberWithScores(ctx, "zset", len(members)+10).Val()
+		sortByMember(res)
+		require.Equal(t, z, res)
+
+		// ZRANDMEMBER zset -len(members) WITHSCORES: exactly len(members) entries, repeats allowed
+		res = rdb.ZRandMemberWithScores(ctx, "zset", -len(members)).Val()
+		require.Len(t, res, len(members))
+		for _, v := range res {
+			require.Contains(t, z, v)
+		}
+
+		// ZRANDMEMBER zset -len(members)-10 WITHSCORES
+		res = rdb.ZRandMemberWithScores(ctx, "zset", -len(members)-10).Val()
+		require.Len(t, res, len(members)+10)
+		for _, v := range res {
+			require.Contains(t, z, v)
+		}
+
+		// ZRANDMEMBER zset 0 WITHSCORES
+		require.Equal(t, []redis.Z{}, rdb.ZRandMemberWithScores(ctx, "zset", 0).Val())
+
+		// ZRANDMEMBER zset 1 WITHSCORES
+		res = rdb.ZRandMemberWithScores(ctx, "zset", 1).Val()
+		require.Len(t, res, 1)
+		require.Contains(t, z, res[0])
+
+		// ZRANDMEMBER zset 3 WITHSCORES: distinct members
+		res = rdb.ZRandMemberWithScores(ctx, "zset", 3).Val()
+		require.Len(t, res, 3)
+		memberMap := make(map[string]struct{})
+		for _, v := range res {
+			require.Contains(t, z, v)
+			memberMap[v.Member.(string)] = struct{}{}
+		}
+		require.Len(t, memberMap, len(res))
+
+		// ZRANDMEMBER zset -3 WITHSCORES
+		res = rdb.ZRandMemberWithScores(ctx, "zset", -3).Val()
+		require.Len(t, res, 3)
+		for _, v := range res {
+			require.Contains(t, z, v)
+		}
+	})
+
        for i, cmd := range []func(ctx context.Context, dest string, store 
*redis.ZStore) *redis.IntCmd{rdb.ZInterStore, rdb.ZUnionStore} {
                var funcName string
                switch i {

Reply via email to