This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 80ca6b09 Extract a rand sample helper and support negative sample 
count in "Set" (#2113)
80ca6b09 is described below

commit 80ca6b091bdc8f649fd3e625cfe420e146461983
Author: mwish <[email protected]>
AuthorDate: Mon Feb 26 20:53:12 2024 +0800

    Extract a rand sample helper and support negative sample count in "Set" 
(#2113)
---
 src/types/redis_hash.cc         | 54 ++++++++++++-------------------
 src/types/redis_set.cc          | 72 +++++++++++++++++------------------------
 src/types/redis_zset.cc         | 37 ++++-----------------
 src/types/sample_helper.h       | 67 ++++++++++++++++++++++++++++++++++++++
 tests/cppunit/types/set_test.cc | 15 +++------
 5 files changed, 129 insertions(+), 116 deletions(-)

diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc
index d7a1ad8b..dcb1978e 100644
--- a/src/types/redis_hash.cc
+++ b/src/types/redis_hash.cc
@@ -30,6 +30,7 @@
 
 #include "db_util.h"
 #include "parse_util.h"
+#include "sample_helper.h"
 
 namespace redis {
 
@@ -389,43 +390,30 @@ rocksdb::Status Hash::RandField(const Slice &user_key, 
int64_t command_count, st
   rocksdb::Status s = GetMetadata(ns_key, &metadata);
   if (!s.ok()) return s;
 
-  uint64_t size = metadata.size;
   std::vector<FieldValue> samples;
   // TODO: Getting all values in Hash might be heavy, consider lazy-loading 
these values later
   if (count == 0) return rocksdb::Status::OK();
-  s = GetAll(user_key, &samples, type);
-  if (!s.ok()) return s;
-  auto append_field_with_index = [field_values, &samples, type](uint64_t 
index) {
-    if (type == HashFetchType::kAll) {
-      field_values->emplace_back(samples[index].field, samples[index].value);
-    } else {
-      field_values->emplace_back(samples[index].field, "");
-    }
-  };
-  field_values->reserve(std::min(size, count));
-  if (!unique || count == 1) {
-    // Case 1: Negative count, randomly select elements or without parameter
-    std::mt19937 gen(std::random_device{}());
-    std::uniform_int_distribution<uint64_t> dis(0, size - 1);
-    for (uint64_t i = 0; i < count; i++) {
-      uint64_t index = dis(gen);
-      append_field_with_index(index);
-    }
-  } else if (size <= count) {
-    // Case 2: Requested count is greater than or equal to the number of 
elements inside the hash
-    for (uint64_t i = 0; i < size; i++) {
-      append_field_with_index(i);
-    }
-  } else {
-    // Case 3: Requested count is less than the number of elements inside the 
hash
-    std::vector<uint64_t> indices(size);
-    std::iota(indices.begin(), indices.end(), 0);
-    std::mt19937 gen(std::random_device{}());
-    std::shuffle(indices.begin(), indices.end(), gen);  // use Fisher-Yates 
shuffle algorithm to randomize the order
-    for (uint64_t i = 0; i < count; i++) {
-      uint64_t index = indices[i];
-      append_field_with_index(index);
+  s = ExtractRandMemberFromSet<FieldValue>(
+      unique, count,
+      [this, user_key, type](std::vector<FieldValue> *elements) { return 
this->GetAll(user_key, elements, type); },
+      field_values);
+  if (!s.ok()) {
+    return s;
+  }
+  switch (type) {
+    case HashFetchType::kAll:
+      break;
+    case HashFetchType::kOnlyKey: {
+      // GetAll should only fetch the keys here; check that all the values are empty
+      for (const FieldValue &value : *field_values) {
+        DCHECK(value.value.empty());
+      }
+      break;
     }
+    case HashFetchType::kOnlyValue:
+      // Unreachable.
+      DCHECK(false);
+      break;
   }
   return rocksdb::Status::OK();
 }
diff --git a/src/types/redis_set.cc b/src/types/redis_set.cc
index 98677e27..35403b88 100644
--- a/src/types/redis_set.cc
+++ b/src/types/redis_set.cc
@@ -23,9 +23,9 @@
 #include <map>
 #include <memory>
 #include <optional>
-#include <random>
 
 #include "db_util.h"
+#include "sample_helper.h"
 
 namespace redis {
 
@@ -197,9 +197,14 @@ rocksdb::Status Set::MIsMember(const Slice &user_key, 
const std::vector<Slice> &
 }
 
 rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string> 
*members, int count, bool pop) {
-  int n = 0;
   members->clear();
-  if (count <= 0) return rocksdb::Status::OK();
+  bool unique = true;
+  if (count == 0) return rocksdb::Status::OK();
+  if (count < 0) {
+    DCHECK(!pop);
+    count = -count;
+    unique = false;
+  }
 
   std::string ns_key = AppendNamespacePrefix(user_key);
 
@@ -210,49 +215,30 @@ rocksdb::Status Set::Take(const Slice &user_key, 
std::vector<std::string> *membe
   rocksdb::Status s = GetMetadata(ns_key, &metadata);
   if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
 
-  auto batch = storage_->GetWriteBatchBase();
-  WriteBatchLogData log_data(kRedisSet);
-  batch->PutLogData(log_data.Encode());
-
-  std::string prefix = InternalKey(ns_key, "", metadata.version, 
storage_->IsSlotIdEncoded()).Encode();
-  std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 
1, storage_->IsSlotIdEncoded()).Encode();
-
-  rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
-  LatestSnapShot ss(storage_);
-  read_options.snapshot = ss.GetSnapShot();
-  rocksdb::Slice upper_bound(next_version_prefix);
-  read_options.iterate_upper_bound = &upper_bound;
-
-  std::vector<std::string> iter_keys;
-  iter_keys.reserve(count);
-  std::random_device rd;
-  std::mt19937 gen(rd());
-  auto iter = util::UniqueIterator(storage_, read_options);
-  for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); 
iter->Next()) {
-    ++n;
-    if (n <= count) {
-      iter_keys.push_back(iter->key().ToString());
-    } else {  // n > count
-      std::uniform_int_distribution<> distrib(0, n - 1);
-      int random = distrib(gen);  // [0,n-1]
-      if (random < count) {
-        iter_keys[random] = iter->key().ToString();
-      }
-    }
+  ObserverOrUniquePtr<rocksdb::WriteBatchBase> batch = 
storage_->GetWriteBatchBase();
+  if (pop) {
+    WriteBatchLogData log_data(kRedisSet);
+    batch->PutLogData(log_data.Encode());
   }
-  for (Slice key : iter_keys) {
-    InternalKey ikey(key, storage_->IsSlotIdEncoded());
-    members->emplace_back(ikey.GetSubKey().ToString());
-    if (pop) {
-      batch->Delete(key);
-    }
+  members->clear();
+  s = ExtractRandMemberFromSet<std::string>(
+      unique, count, [this, user_key](std::vector<std::string> *samples) { 
return this->Members(user_key, samples); },
+      members);
+  if (!s.ok()) {
+    return s;
   }
-  if (pop && !iter_keys.empty()) {
-    metadata.size -= iter_keys.size();
-    std::string bytes;
-    metadata.Encode(&bytes);
-    batch->Put(metadata_cf_handle_, ns_key, bytes);
+  // Avoid writing an empty op-log when only randomly selecting members (no pop).
+  if (!pop) return rocksdb::Status::OK();
+  // Avoid writing an empty op-log if the set is empty.
+  if (members->empty()) return rocksdb::Status::OK();
+  for (std::string &user_sub_key : *members) {
+    std::string sub_key = InternalKey(ns_key, user_sub_key, metadata.version, 
storage_->IsSlotIdEncoded()).Encode();
+    batch->Delete(sub_key);
   }
+  metadata.size -= members->size();
+  std::string bytes;
+  metadata.Encode(&bytes);
+  batch->Put(metadata_cf_handle_, ns_key, bytes);
   return storage_->Write(storage_->DefaultWriteOptions(), 
batch->GetWriteBatch());
 }
 
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 328182b5..1dd2feb5 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -25,10 +25,10 @@
 #include <map>
 #include <memory>
 #include <optional>
-#include <random>
 #include <set>
 
 #include "db_util.h"
+#include "sample_helper.h"
 
 namespace redis {
 
@@ -900,35 +900,12 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, 
int64_t command_count,
   if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
   if (metadata.size == 0) return rocksdb::Status::OK();
 
-  std::vector<MemberScore> samples;
-  s = GetAllMemberScores(user_key, &samples);
-  if (!s.ok() || samples.empty()) return s;
-
-  uint64_t size = samples.size();
-  member_scores->reserve(std::min(size, count));
-
-  if (!unique || count == 1) {
-    std::mt19937 gen(std::random_device{}());
-    std::uniform_int_distribution<uint64_t> dist(0, size - 1);
-    for (uint64_t i = 0; i < count; i++) {
-      uint64_t index = dist(gen);
-      member_scores->emplace_back(samples[index]);
-    }
-  } else if (size <= count) {
-    for (auto &sample : samples) {
-      member_scores->push_back(std::move(sample));
-    }
-  } else {
-    // first shuffle the samples
-    std::mt19937 gen(std::random_device{}());
-    std::shuffle(samples.begin(), samples.end(), gen);
-    // then pick the first `count` ones.
-    for (uint64_t i = 0; i < count; i++) {
-      member_scores->emplace_back(std::move(samples[i]));
-    }
-  }
-
-  return rocksdb::Status::OK();
+  return ExtractRandMemberFromSet<MemberScore>(
+      unique, count,
+      [this, user_key](std::vector<MemberScore> *scores) -> rocksdb::Status {
+        return this->GetAllMemberScores(user_key, scores);
+      },
+      member_scores);
 }
 
 rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores 
*members) {
diff --git a/src/types/sample_helper.h b/src/types/sample_helper.h
new file mode 100644
index 00000000..1f161806
--- /dev/null
+++ b/src/types/sample_helper.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <rocksdb/status.h>
+
+#include <random>
+#include <vector>
+
+/// ExtractRandMemberFromSet is a helper function to extract random elements 
from a kvrocks structure.
+///
+/// The complexity of the function is O(N) where N is the number of elements 
inside the structure.
+template <typename ElementType, typename GetAllMemberFnType>
+rocksdb::Status ExtractRandMemberFromSet(bool unique, size_t count, const 
GetAllMemberFnType &get_all_member_fn,
+                                         std::vector<ElementType> *elements) {
+  elements->clear();
+  std::vector<ElementType> samples;
+  rocksdb::Status s = get_all_member_fn(&samples);
+  if (!s.ok() || samples.empty()) return s;
+
+  size_t all_element_size = samples.size();
+  DCHECK_GE(all_element_size, 1U);
+  elements->reserve(std::min(all_element_size, count));
+
+  if (!unique || count == 1) {
+    // Case 1: Duplicates are allowed (negative count) or only one element is requested; pick random elements independently
+    std::mt19937 gen(std::random_device{}());
+    std::uniform_int_distribution<uint64_t> dist(0, all_element_size - 1);
+    for (uint64_t i = 0; i < count; i++) {
+      uint64_t index = dist(gen);
+      elements->emplace_back(samples[index]);
+    }
+  } else if (all_element_size <= count) {
+    // Case 2: Requested count is greater than or equal to the number of 
elements inside the structure
+    for (auto &sample : samples) {
+      elements->push_back(std::move(sample));
+    }
+  } else {
+    // Case 3: Requested count is less than the number of elements inside the 
structure
+    std::mt19937 gen(std::random_device{}());
+    // use Fisher-Yates shuffle algorithm to randomize the order
+    std::shuffle(samples.begin(), samples.end(), gen);
+    // then pick the first `count` ones.
+    for (uint64_t i = 0; i < count; i++) {
+      elements->emplace_back(std::move(samples[i]));
+    }
+  }
+  return rocksdb::Status::OK();
+}
diff --git a/tests/cppunit/types/set_test.cc b/tests/cppunit/types/set_test.cc
index 5e5c774b..de94a611 100644
--- a/tests/cppunit/types/set_test.cc
+++ b/tests/cppunit/types/set_test.cc
@@ -35,6 +35,8 @@ class RedisSetTest : public TestBase {
     fields_ = {"set-key-1", "set-key-2", "set-key-3", "set-key-4"};
   }
 
+  void TearDown() override { [[maybe_unused]] auto s = set_->Del(key_); }
+
   std::unique_ptr<redis::Set> set_;
 };
 
@@ -48,7 +50,6 @@ TEST_F(RedisSetTest, AddAndRemove) {
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
   s = set_->Card(key_, &ret);
   EXPECT_TRUE(s.ok() && ret == 0);
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, AddAndRemoveRepeated) {
@@ -65,8 +66,6 @@ TEST_F(RedisSetTest, AddAndRemoveRepeated) {
   EXPECT_TRUE(s.ok() && (remembers.size() - 1) == ret);
   set_->Card(key_, &card);
   EXPECT_EQ(card, allmembers.size() - 1 - ret);
-
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, Members) {
@@ -82,7 +81,6 @@ TEST_F(RedisSetTest, Members) {
   }
   s = set_->Remove(key_, fields_, &ret);
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, IsMember) {
@@ -98,7 +96,6 @@ TEST_F(RedisSetTest, IsMember) {
   EXPECT_TRUE(s.ok() && !flag);
   s = set_->Remove(key_, fields_, &ret);
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, MIsMember) {
@@ -118,7 +115,6 @@ TEST_F(RedisSetTest, MIsMember) {
   for (size_t i = 1; i < fields_.size(); i++) {
     EXPECT_TRUE(exists[i] == 1);
   }
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, Move) {
@@ -139,7 +135,6 @@ TEST_F(RedisSetTest, Move) {
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
   s = set_->Remove(dst, fields_, &ret);
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
-  s = set_->Del(key_);
   s = set_->Del(dst);
 }
 
@@ -157,7 +152,6 @@ TEST_F(RedisSetTest, TakeWithPop) {
   s = set_->Take(key_, &members, 1, true);
   EXPECT_TRUE(s.ok());
   EXPECT_TRUE(s.ok() && members.size() == 0);
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, Diff) {
@@ -261,7 +255,6 @@ TEST_F(RedisSetTest, Overwrite) {
   set_->Overwrite(key_, {"a"});
   set_->Card(key_, &ret);
   EXPECT_EQ(ret, 1);
-  s = set_->Del(key_);
 }
 
 TEST_F(RedisSetTest, TakeWithoutPop) {
@@ -275,7 +268,9 @@ TEST_F(RedisSetTest, TakeWithoutPop) {
   s = set_->Take(key_, &members, int(fields_.size() - 1), false);
   EXPECT_TRUE(s.ok());
   EXPECT_EQ(members.size(), fields_.size() - 1);
+  s = set_->Take(key_, &members, -int(fields_.size() - 1), false);
+  EXPECT_TRUE(s.ok());
+  EXPECT_EQ(members.size(), fields_.size() - 1);
   s = set_->Remove(key_, fields_, &ret);
   EXPECT_TRUE(s.ok() && fields_.size() == ret);
-  s = set_->Del(key_);
 }

Reply via email to