This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 80ca6b09 Extract a rand sample helper and support negative sample
count in "Set" (#2113)
80ca6b09 is described below
commit 80ca6b091bdc8f649fd3e625cfe420e146461983
Author: mwish <[email protected]>
AuthorDate: Mon Feb 26 20:53:12 2024 +0800
Extract a rand sample helper and support negative sample count in "Set"
(#2113)
---
src/types/redis_hash.cc | 54 ++++++++++++-------------------
src/types/redis_set.cc | 72 +++++++++++++++++------------------------
src/types/redis_zset.cc | 37 ++++-----------------
src/types/sample_helper.h | 67 ++++++++++++++++++++++++++++++++++++++
tests/cppunit/types/set_test.cc | 15 +++------
5 files changed, 129 insertions(+), 116 deletions(-)
diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc
index d7a1ad8b..dcb1978e 100644
--- a/src/types/redis_hash.cc
+++ b/src/types/redis_hash.cc
@@ -30,6 +30,7 @@
#include "db_util.h"
#include "parse_util.h"
+#include "sample_helper.h"
namespace redis {
@@ -389,43 +390,30 @@ rocksdb::Status Hash::RandField(const Slice &user_key,
int64_t command_count, st
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s;
- uint64_t size = metadata.size;
std::vector<FieldValue> samples;
// TODO: Getting all values in Hash might be heavy, consider lazy-loading
these values later
if (count == 0) return rocksdb::Status::OK();
- s = GetAll(user_key, &samples, type);
- if (!s.ok()) return s;
- auto append_field_with_index = [field_values, &samples, type](uint64_t
index) {
- if (type == HashFetchType::kAll) {
- field_values->emplace_back(samples[index].field, samples[index].value);
- } else {
- field_values->emplace_back(samples[index].field, "");
- }
- };
- field_values->reserve(std::min(size, count));
- if (!unique || count == 1) {
- // Case 1: Negative count, randomly select elements or without parameter
- std::mt19937 gen(std::random_device{}());
- std::uniform_int_distribution<uint64_t> dis(0, size - 1);
- for (uint64_t i = 0; i < count; i++) {
- uint64_t index = dis(gen);
- append_field_with_index(index);
- }
- } else if (size <= count) {
- // Case 2: Requested count is greater than or equal to the number of
elements inside the hash
- for (uint64_t i = 0; i < size; i++) {
- append_field_with_index(i);
- }
- } else {
- // Case 3: Requested count is less than the number of elements inside the
hash
- std::vector<uint64_t> indices(size);
- std::iota(indices.begin(), indices.end(), 0);
- std::mt19937 gen(std::random_device{}());
- std::shuffle(indices.begin(), indices.end(), gen); // use Fisher-Yates
shuffle algorithm to randomize the order
- for (uint64_t i = 0; i < count; i++) {
- uint64_t index = indices[i];
- append_field_with_index(index);
+ s = ExtractRandMemberFromSet<FieldValue>(
+ unique, count,
+ [this, user_key, type](std::vector<FieldValue> *elements) { return
this->GetAll(user_key, elements, type); },
+ field_values);
+ if (!s.ok()) {
+ return s;
+ }
+ switch (type) {
+ case HashFetchType::kAll:
+ break;
+ case HashFetchType::kOnlyKey: {
+ // GetAll is expected to fetch only the fields here; check that every value is empty
+ for (const FieldValue &value : *field_values) {
+ DCHECK(value.value.empty());
+ }
+ break;
}
+ case HashFetchType::kOnlyValue:
+ // Unreachable.
+ DCHECK(false);
+ break;
}
return rocksdb::Status::OK();
}
diff --git a/src/types/redis_set.cc b/src/types/redis_set.cc
index 98677e27..35403b88 100644
--- a/src/types/redis_set.cc
+++ b/src/types/redis_set.cc
@@ -23,9 +23,9 @@
#include <map>
#include <memory>
#include <optional>
-#include <random>
#include "db_util.h"
+#include "sample_helper.h"
namespace redis {
@@ -197,9 +197,14 @@ rocksdb::Status Set::MIsMember(const Slice &user_key,
const std::vector<Slice> &
}
rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string>
*members, int count, bool pop) {
- int n = 0;
members->clear();
- if (count <= 0) return rocksdb::Status::OK();
+ bool unique = true;
+ if (count == 0) return rocksdb::Status::OK();
+ if (count < 0) {
+ DCHECK(!pop);
+ count = -count;
+ unique = false;
+ }
std::string ns_key = AppendNamespacePrefix(user_key);
@@ -210,49 +215,30 @@ rocksdb::Status Set::Take(const Slice &user_key,
std::vector<std::string> *membe
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
- auto batch = storage_->GetWriteBatchBase();
- WriteBatchLogData log_data(kRedisSet);
- batch->PutLogData(log_data.Encode());
-
- std::string prefix = InternalKey(ns_key, "", metadata.version,
storage_->IsSlotIdEncoded()).Encode();
- std::string next_version_prefix = InternalKey(ns_key, "", metadata.version +
1, storage_->IsSlotIdEncoded()).Encode();
-
- rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
- LatestSnapShot ss(storage_);
- read_options.snapshot = ss.GetSnapShot();
- rocksdb::Slice upper_bound(next_version_prefix);
- read_options.iterate_upper_bound = &upper_bound;
-
- std::vector<std::string> iter_keys;
- iter_keys.reserve(count);
- std::random_device rd;
- std::mt19937 gen(rd());
- auto iter = util::UniqueIterator(storage_, read_options);
- for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix);
iter->Next()) {
- ++n;
- if (n <= count) {
- iter_keys.push_back(iter->key().ToString());
- } else { // n > count
- std::uniform_int_distribution<> distrib(0, n - 1);
- int random = distrib(gen); // [0,n-1]
- if (random < count) {
- iter_keys[random] = iter->key().ToString();
- }
- }
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase> batch =
storage_->GetWriteBatchBase();
+ if (pop) {
+ WriteBatchLogData log_data(kRedisSet);
+ batch->PutLogData(log_data.Encode());
}
- for (Slice key : iter_keys) {
- InternalKey ikey(key, storage_->IsSlotIdEncoded());
- members->emplace_back(ikey.GetSubKey().ToString());
- if (pop) {
- batch->Delete(key);
- }
+ members->clear();
+ s = ExtractRandMemberFromSet<std::string>(
+ unique, count, [this, user_key](std::vector<std::string> *samples) {
return this->Members(user_key, samples); },
+ members);
+ if (!s.ok()) {
+ return s;
}
- if (pop && !iter_keys.empty()) {
- metadata.size -= iter_keys.size();
- std::string bytes;
- metadata.Encode(&bytes);
- batch->Put(metadata_cf_handle_, ns_key, bytes);
+ // Avoid writing an empty op-log when members are only sampled at random, not popped.
+ if (!pop) return rocksdb::Status::OK();
+ // Avoid writing an empty op-log if no members were selected (e.g. the set is empty).
+ if (members->empty()) return rocksdb::Status::OK();
+ for (std::string &user_sub_key : *members) {
+ std::string sub_key = InternalKey(ns_key, user_sub_key, metadata.version,
storage_->IsSlotIdEncoded()).Encode();
+ batch->Delete(sub_key);
}
+ metadata.size -= members->size();
+ std::string bytes;
+ metadata.Encode(&bytes);
+ batch->Put(metadata_cf_handle_, ns_key, bytes);
return storage_->Write(storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
}
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 328182b5..1dd2feb5 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -25,10 +25,10 @@
#include <map>
#include <memory>
#include <optional>
-#include <random>
#include <set>
#include "db_util.h"
+#include "sample_helper.h"
namespace redis {
@@ -900,35 +900,12 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key,
int64_t command_count,
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
if (metadata.size == 0) return rocksdb::Status::OK();
- std::vector<MemberScore> samples;
- s = GetAllMemberScores(user_key, &samples);
- if (!s.ok() || samples.empty()) return s;
-
- uint64_t size = samples.size();
- member_scores->reserve(std::min(size, count));
-
- if (!unique || count == 1) {
- std::mt19937 gen(std::random_device{}());
- std::uniform_int_distribution<uint64_t> dist(0, size - 1);
- for (uint64_t i = 0; i < count; i++) {
- uint64_t index = dist(gen);
- member_scores->emplace_back(samples[index]);
- }
- } else if (size <= count) {
- for (auto &sample : samples) {
- member_scores->push_back(std::move(sample));
- }
- } else {
- // first shuffle the samples
- std::mt19937 gen(std::random_device{}());
- std::shuffle(samples.begin(), samples.end(), gen);
- // then pick the first `count` ones.
- for (uint64_t i = 0; i < count; i++) {
- member_scores->emplace_back(std::move(samples[i]));
- }
- }
-
- return rocksdb::Status::OK();
+ return ExtractRandMemberFromSet<MemberScore>(
+ unique, count,
+ [this, user_key](std::vector<MemberScore> *scores) -> rocksdb::Status {
+ return this->GetAllMemberScores(user_key, scores);
+ },
+ member_scores);
}
rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores
*members) {
diff --git a/src/types/sample_helper.h b/src/types/sample_helper.h
new file mode 100644
index 00000000..1f161806
--- /dev/null
+++ b/src/types/sample_helper.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <rocksdb/status.h>
+
+#include <random>
+#include <vector>
+
+/// ExtractRandMemberFromSet is a helper function to extract random elements
from a kvrocks structure.
+///
+/// The complexity of the function is O(N) where N is the number of elements
inside the structure.
+template <typename ElementType, typename GetAllMemberFnType>
+rocksdb::Status ExtractRandMemberFromSet(bool unique, size_t count, const
GetAllMemberFnType &get_all_member_fn,
+ std::vector<ElementType> *elements) {
+ elements->clear();  // the output always starts empty, even on early return
+ std::vector<ElementType> samples;
+ rocksdb::Status s = get_all_member_fn(&samples);  // loads every element of the structure
+ if (!s.ok() || samples.empty()) return s;  // propagate errors; an empty source yields an empty result
+
+ size_t all_element_size = samples.size();
+ DCHECK_GE(all_element_size, 1U);
+ elements->reserve(std::min(all_element_size, count));  // capacity hint only; Case 1 may append more than this
+
+ if (!unique || count == 1) {
+ // Case 1: duplicates are allowed (!unique) or a single element is requested; draw `count` independent indices
+ std::mt19937 gen(std::random_device{}());
+ std::uniform_int_distribution<uint64_t> dist(0, all_element_size - 1);
+ for (uint64_t i = 0; i < count; i++) {
+ uint64_t index = dist(gen);
+ elements->emplace_back(samples[index]);  // copy, because the same index may be drawn again
+ }
+ } else if (all_element_size <= count) {
+ // Case 2: Requested count is greater than or equal to the number of
elements inside the structure
+ for (auto &sample : samples) {
+ elements->push_back(std::move(sample));  // every element is returned exactly once, so moving is safe
+ }
+ } else {
+ // Case 3: Requested count is less than the number of elements inside the
structure
+ std::mt19937 gen(std::random_device{}());
+ // randomize the order of all samples with std::shuffle
+ std::shuffle(samples.begin(), samples.end(), gen);
+ // then move the first `count` shuffled samples into the result.
+ for (uint64_t i = 0; i < count; i++) {
+ elements->emplace_back(std::move(samples[i]));
+ }
+ }
+ return rocksdb::Status::OK();
+}
diff --git a/tests/cppunit/types/set_test.cc b/tests/cppunit/types/set_test.cc
index 5e5c774b..de94a611 100644
--- a/tests/cppunit/types/set_test.cc
+++ b/tests/cppunit/types/set_test.cc
@@ -35,6 +35,8 @@ class RedisSetTest : public TestBase {
fields_ = {"set-key-1", "set-key-2", "set-key-3", "set-key-4"};
}
+ void TearDown() override { [[maybe_unused]] auto s = set_->Del(key_); }
+
std::unique_ptr<redis::Set> set_;
};
@@ -48,7 +50,6 @@ TEST_F(RedisSetTest, AddAndRemove) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Card(key_, &ret);
EXPECT_TRUE(s.ok() && ret == 0);
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, AddAndRemoveRepeated) {
@@ -65,8 +66,6 @@ TEST_F(RedisSetTest, AddAndRemoveRepeated) {
EXPECT_TRUE(s.ok() && (remembers.size() - 1) == ret);
set_->Card(key_, &card);
EXPECT_EQ(card, allmembers.size() - 1 - ret);
-
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, Members) {
@@ -82,7 +81,6 @@ TEST_F(RedisSetTest, Members) {
}
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, IsMember) {
@@ -98,7 +96,6 @@ TEST_F(RedisSetTest, IsMember) {
EXPECT_TRUE(s.ok() && !flag);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, MIsMember) {
@@ -118,7 +115,6 @@ TEST_F(RedisSetTest, MIsMember) {
for (size_t i = 1; i < fields_.size(); i++) {
EXPECT_TRUE(exists[i] == 1);
}
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, Move) {
@@ -139,7 +135,6 @@ TEST_F(RedisSetTest, Move) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Remove(dst, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
- s = set_->Del(key_);
s = set_->Del(dst);
}
@@ -157,7 +152,6 @@ TEST_F(RedisSetTest, TakeWithPop) {
s = set_->Take(key_, &members, 1, true);
EXPECT_TRUE(s.ok());
EXPECT_TRUE(s.ok() && members.size() == 0);
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, Diff) {
@@ -261,7 +255,6 @@ TEST_F(RedisSetTest, Overwrite) {
set_->Overwrite(key_, {"a"});
set_->Card(key_, &ret);
EXPECT_EQ(ret, 1);
- s = set_->Del(key_);
}
TEST_F(RedisSetTest, TakeWithoutPop) {
@@ -275,7 +268,9 @@ TEST_F(RedisSetTest, TakeWithoutPop) {
s = set_->Take(key_, &members, int(fields_.size() - 1), false);
EXPECT_TRUE(s.ok());
EXPECT_EQ(members.size(), fields_.size() - 1);
+ s = set_->Take(key_, &members, -int(fields_.size() - 1), false);
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(members.size(), fields_.size() - 1);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
- s = set_->Del(key_);
}