This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new a3863f9c feat(hyperloglog): Add support of PFMERGE command (#2457)
a3863f9c is described below
commit a3863f9cb8f84c15bec97a5382cd4e1bb76aa04e
Author: mwish <[email protected]>
AuthorDate: Thu Aug 1 20:58:11 2024 +0800
feat(hyperloglog): Add support of PFMERGE command (#2457)
---
src/commands/cmd_hll.cc | 38 +++++++--
src/types/hyperloglog.cc | 25 ++++++
src/types/hyperloglog.h | 6 ++
src/types/redis_hyperloglog.cc | 136 +++++++++++++++++++++++++++++---
src/types/redis_hyperloglog.h | 20 +++--
tests/cppunit/types/hyperloglog_test.cc | 99 +++++++++++++++++++++++
6 files changed, 298 insertions(+), 26 deletions(-)
diff --git a/src/commands/cmd_hll.cc b/src/commands/cmd_hll.cc
index 343aa322..88545427 100644
--- a/src/commands/cmd_hll.cc
+++ b/src/commands/cmd_hll.cc
@@ -24,12 +24,8 @@
#include "commander.h"
#include "commands/command_parser.h"
-#include "commands/error_constants.h"
-#include "error_constants.h"
-#include "parse_util.h"
#include "server/redis_reply.h"
#include "server/server.h"
-#include "storage/redis_metadata.h"
namespace redis {
@@ -57,13 +53,17 @@ class CommandPfAdd final : public Commander {
/// Complexity: O(1) with a very small average constant time when called with a single key.
/// O(N) with N being the number of keys, and much bigger constant times,
/// when called with multiple keys.
-///
-/// TODO(mwish): Currently we don't supports merge, so only one key is
supported.
class CommandPfCount final : public Commander {
Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
uint64_t ret{};
- auto s = hll.Count(args_[0], &ret);
+ rocksdb::Status s;
+ if (args_.size() > 1) {
+ std::vector<Slice> keys(args_.begin(), args_.end());
+ s = hll.CountMultiple(keys, &ret);
+ } else {
+ s = hll.Count(args_[0], &ret);
+ }
if (!s.ok() && !s.IsNotFound()) {
return {Status::RedisExecErr, s.ToString()};
}
@@ -75,7 +75,29 @@ class CommandPfCount final : public Commander {
}
};
+/// PFMERGE destkey [sourcekey [sourcekey ...]]
+///
+/// Merges the source HyperLogLogs into `destkey` and replies with the simple string "OK".
+/// A NotFound status returned by the merge is not reported as an error.
+///
+/// complexity: O(N) to merge N HyperLogLogs, but with high constant times.
+class CommandPfMerge final : public Commander {
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
+    // args_[0] is the command name (the command is registered with its first key at index 1),
+    // so the destination key is args_[1] and the source keys are everything after it.
+    // The registered arity (-2) guarantees at least two arguments, so begin() + 2 never passes end().
+    std::vector<Slice> src_user_keys(args_.begin() + 2, args_.end());
+    auto s = hll.Merge(/*dest_user_key=*/args_[1], src_user_keys);
+    if (!s.ok() && !s.IsNotFound()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+    *output = redis::SimpleString("OK");
+    return Status::OK();
+  }
+};
+
REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandPfAdd>("pfadd", -2, "write", 1, 1,
1),
- MakeCmdAttr<CommandPfCount>("pfcount", 2, "read-only",
1, 1, 1), );
+ MakeCmdAttr<CommandPfCount>("pfcount", -2,
"read-only", 1, -1, 1),
+ MakeCmdAttr<CommandPfMerge>("pfmerge", -2, "write", 1,
-1, 1), );
} // namespace redis
diff --git a/src/types/hyperloglog.cc b/src/types/hyperloglog.cc
index 80923181..831988de 100644
--- a/src/types/hyperloglog.cc
+++ b/src/types/hyperloglog.cc
@@ -157,6 +157,31 @@ void HllDenseRegHisto(nonstd::span<const uint8_t>
registers, int *reghisto) {
}
}
+/// Merges `registers` into `dest_registers`, segment by segment, keeping the larger value of each register.
+///
+/// Both arguments are indexed by segment id and must hold at least kHyperLogLogSegmentCount entries
+/// (`dest_registers` is bounds-checked through at(), `registers` is not).
+///  - An empty source segment leaves the destination segment untouched.
+///  - An empty destination segment becomes a byte-for-byte copy of the source segment.
+///  - Otherwise every one of the kHyperLogLogSegmentRegisters registers is set to
+///    MAX(destination, source).
+void HllMerge(std::vector<std::string> *dest_registers, const std::vector<nonstd::span<const uint8_t>> &registers) {
+  for (size_t segment_id = 0; segment_id < kHyperLogLogSegmentCount; segment_id++) {
+    std::string *dest_segment = &dest_registers->at(segment_id);
+    nonstd::span<const uint8_t> src_segment = registers[segment_id];
+    if (src_segment.empty()) {
+      // Nothing to merge from this segment
+      continue;
+    }
+    if (dest_segment->empty()) {
+      // Destination has no data yet: adopt the source segment as-is
+      dest_segment->resize(src_segment.size());
+      memcpy(dest_segment->data(), src_segment.data(), src_segment.size());
+      continue;
+    }
+    // Do physical merge
+    // NOLINTNEXTLINE
+    uint8_t *dest_segment_data = reinterpret_cast<uint8_t *>(dest_segment->data());
+    for (size_t register_idx = 0; register_idx < kHyperLogLogSegmentRegisters; register_idx++) {
+      uint8_t val = HllDenseGetRegister(src_segment.data(), register_idx);
+      uint8_t previous_val = HllDenseGetRegister(dest_segment_data, register_idx);
+      if (val > previous_val) {
+        HllDenseSetRegister(dest_segment_data, register_idx, val);
+      }
+    }
+  }
+}
+
/* ========================= HyperLogLog Count ==============================
* This is the core of the algorithm where the approximated count is computed.
* The function uses the lower level HllDenseRegHisto()
diff --git a/src/types/hyperloglog.h b/src/types/hyperloglog.h
index c99efe5c..2071f900 100644
--- a/src/types/hyperloglog.h
+++ b/src/types/hyperloglog.h
@@ -70,3 +70,9 @@ void HllDenseSetRegister(uint8_t *registers, uint32_t
register_index, uint8_t va
* or a kHyperLogLogSegmentBytes sized array.
*/
uint64_t HllDenseEstimate(const std::vector<nonstd::span<const uint8_t>> &registers);
+
+/**
+ * Merge the HyperLogLog 'registers' into 'dest_registers' by computing
+ * MAX(dest_registers[i], registers[i]) for every register.
+ *
+ * Both arguments are split into segments indexed by segment id. An empty
+ * source segment is skipped, and an empty destination segment is replaced
+ * by a copy of the corresponding source segment.
+ */
+void HllMerge(std::vector<std::string> *dest_registers, const std::vector<nonstd::span<const uint8_t>> &registers);
diff --git a/src/types/redis_hyperloglog.cc b/src/types/redis_hyperloglog.cc
index aef0cc66..f83c7e93 100644
--- a/src/types/redis_hyperloglog.cc
+++ b/src/types/redis_hyperloglog.cc
@@ -28,6 +28,25 @@
namespace redis {
+namespace {
+/// Builds one non-owning byte span per register segment.
+///
+/// ElementType only needs empty(), data() and size(); it is used in this file with
+/// std::string and rocksdb::PinnableSlice. An empty segment maps to an empty span.
+/// The returned spans point into `registers`, so they must not outlive it.
+template <typename ElementType>
+std::vector<nonstd::span<const uint8_t>> TransformToSpan(const std::vector<ElementType> &registers) {
+  std::vector<nonstd::span<const uint8_t>> register_segments;
+  register_segments.reserve(kHyperLogLogSegmentCount);
+  for (const auto &register_segment : registers) {
+    if (register_segment.empty()) {
+      // Empty segment
+      register_segments.emplace_back();
+      continue;
+    }
+    // NOLINTNEXTLINE
+    const uint8_t *segment_data_ptr = reinterpret_cast<const uint8_t *>(register_segment.data());
+    register_segments.emplace_back(segment_data_ptr, register_segment.size());
+  }
+  return register_segments;
+}
+}  // namespace
+
/// Cache for writing to a HyperLogLog.
///
/// This is a bit like Bitmap::SegmentCacheStore, but simpler because
@@ -163,22 +182,101 @@ rocksdb::Status HyperLogLog::Count(const Slice
&user_key, uint64_t *ret) {
if (!s.ok()) return s;
}
DCHECK_EQ(kHyperLogLogSegmentCount, registers.size());
- std::vector<nonstd::span<const uint8_t>> register_segments;
- register_segments.reserve(kHyperLogLogSegmentCount);
- for (const auto &register_segment : registers) {
- if (register_segment.empty()) {
- // Empty segment
- register_segments.emplace_back();
+ std::vector<nonstd::span<const uint8_t>> register_segments =
TransformToSpan(registers);
+ *ret = HllDenseEstimate(register_segments);
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status HyperLogLog::mergeUserKeys(Database::GetOptions get_options,  // Loads user_keys[0] into *register_segments, then merges every other distinct key into it.
const std::vector<Slice> &user_keys,
+ std::vector<std::string>
*register_segments) {
+ DCHECK_GE(user_keys.size(), static_cast<size_t>(1));  // at least one key is required: user_keys[0] seeds the result
+
+ std::string first_ns_key = AppendNamespacePrefix(user_keys[0]);
+ rocksdb::Status s = getRegisters(get_options, first_ns_key,
register_segments);
+ if (!s.ok()) return s;  // the first non-OK read status is returned unchanged
+ // The set of keys that have been seen so far
+ std::unordered_set<std::string_view> seend_user_keys;  // string_views are non-owning: no key bytes are copied
+ seend_user_keys.emplace(user_keys[0].ToStringView());
+
+ for (size_t idx = 1; idx < user_keys.size(); idx++) {
+ rocksdb::Slice source_user_key = user_keys[idx];
+ if (!seend_user_keys.emplace(source_user_key.ToStringView()).second) {
+ // Skip duplicate keys
continue;
}
- // NOLINTNEXTLINE
- const uint8_t *segment_data_ptr = reinterpret_cast<const uint8_t
*>(register_segment.data());
- register_segments.emplace_back(segment_data_ptr, register_segment.size());
+ std::string source_key = AppendNamespacePrefix(source_user_key);
+ std::vector<rocksdb::PinnableSlice> source_registers;  // sources are read as PinnableSlices; only the accumulated result is held as std::string
+ s = getRegisters(get_options, source_key, &source_registers);
+ if (!s.ok()) return s;
+ DCHECK_EQ(kHyperLogLogSegmentCount, source_registers.size());
+ DCHECK_EQ(kHyperLogLogSegmentCount, register_segments->size());
+ std::vector<nonstd::span<const uint8_t>> source_register_span =
TransformToSpan(source_registers);
+ HllMerge(register_segments, source_register_span);  // per-register maximum of source and accumulated result
}
- *ret = HllDenseEstimate(register_segments);
return rocksdb::Status::OK();
}
+/// Estimates the cardinality of the union of several HyperLogLogs and stores it in `*ret`.
+///
+/// All keys are read through a single snapshot and merged in memory with mergeUserKeys; the merged
+/// registers are then passed to HllDenseEstimate. This function itself writes nothing to storage.
+/// Callers are expected to pass more than one key (checked with DCHECK_GT).
+/// Returns the first non-OK status reported while reading the keys, OK otherwise.
+rocksdb::Status HyperLogLog::CountMultiple(const std::vector<Slice> &user_key, uint64_t *ret) {
+  DCHECK_GT(user_key.size(), static_cast<size_t>(1));
+  std::vector<std::string> register_segments;
+  // Using same snapshot for all get operations
+  LatestSnapShot ss(storage_);
+  Database::GetOptions get_options(ss.GetSnapShot());
+  auto s = mergeUserKeys(get_options, user_key, &register_segments);
+  if (!s.ok()) return s;
+  std::vector<nonstd::span<const uint8_t>> register_segment_span = TransformToSpan(register_segments);
+  *ret = HllDenseEstimate(register_segment_span);
+  return rocksdb::Status::OK();
+}
+
+/// Merges `source_user_keys` (together with the current content of `dest_user_key`) into `dest_user_key`.
+///
+/// With no source keys this is a no-op that returns OK. Otherwise the destination key is locked, all
+/// keys are read through one snapshot, merged in memory, and the non-empty segments plus the
+/// destination metadata (encode type DENSE) are written in a single batch.
+/// Returns the first non-OK status from reading or writing; a NotFound status for the destination
+/// metadata is tolerated.
+rocksdb::Status HyperLogLog::Merge(const Slice &dest_user_key, const std::vector<Slice> &source_user_keys) {
+  if (source_user_keys.empty()) {
+    return rocksdb::Status::OK();
+  }
+
+  std::string dest_key = AppendNamespacePrefix(dest_user_key);
+  LockGuard guard(storage_->GetLockManager(), dest_key);
+  // Using same snapshot for all get operations
+  LatestSnapShot ss(storage_);
+  Database::GetOptions get_options(ss.GetSnapShot());
+  HyperLogLogMetadata metadata;
+  // GetMetadata is declared to take the namespace-prefixed key, so look up dest_key rather than
+  // the raw user key; the sub keys below are built from the same dest_key and metadata.version.
+  rocksdb::Status s = GetMetadata(get_options, dest_key, &metadata);
+  if (!s.ok() && !s.IsNotFound()) return s;
+  std::vector<std::string> registers;
+  {
+    // The destination goes first so its existing registers seed the merge result
+    std::vector<Slice> all_user_keys;
+    all_user_keys.reserve(source_user_keys.size() + 1);
+    all_user_keys.push_back(dest_user_key);
+    all_user_keys.insert(all_user_keys.end(), source_user_keys.begin(), source_user_keys.end());
+    s = mergeUserKeys(get_options, all_user_keys, &registers);
+  }
+  // A failed merge may leave `registers` without its segments; never index into it in that case.
+  if (!s.ok()) return s;
+  DCHECK_EQ(kHyperLogLogSegmentCount, registers.size());
+
+  auto batch = storage_->GetWriteBatchBase();
+  WriteBatchLogData log_data(kRedisHyperLogLog);
+  batch->PutLogData(log_data.Encode());
+  for (uint32_t i = 0; i < kHyperLogLogSegmentCount; i++) {
+    if (registers[i].empty()) {
+      // Empty segments are not stored
+      continue;
+    }
+    std::string sub_key =
+        InternalKey(dest_key, std::to_string(i), metadata.version, storage_->IsSlotIdEncoded()).Encode();
+    batch->Put(sub_key, registers[i]);
+    // Release memory after batch is written
+    registers[i].clear();
+  }
+  // Metadata
+  {
+    metadata.encode_type = HyperLogLogMetadata::EncodeType::DENSE;
+    std::string bytes;
+    metadata.Encode(&bytes);
+    batch->Put(metadata_cf_handle_, dest_key, bytes);
+  }
+
+  return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+}
+
rocksdb::Status HyperLogLog::getRegisters(Database::GetOptions get_options,
const Slice &ns_key,
std::vector<rocksdb::PinnableSlice>
*register_segments) {
HyperLogLogMetadata metadata;
@@ -207,15 +305,27 @@ rocksdb::Status
HyperLogLog::getRegisters(Database::GetOptions get_options, cons
for (const auto &sub_key : sub_segment_keys) {
sub_segment_slices.emplace_back(sub_key);
}
- std::vector<rocksdb::PinnableSlice> values(kHyperLogLogSegmentCount);
+ register_segments->resize(kHyperLogLogSegmentCount);
std::vector<rocksdb::Status> statuses(kHyperLogLogSegmentCount);
storage_->MultiGet(read_options, storage_->GetDB()->DefaultColumnFamily(),
kHyperLogLogSegmentCount,
- sub_segment_slices.data(), values.data(),
statuses.data());
+ sub_segment_slices.data(), register_segments->data(),
statuses.data());
for (size_t i = 0; i < kHyperLogLogSegmentCount; i++) {
if (!statuses[i].ok() && !statuses[i].IsNotFound()) {
+ register_segments->at(i).clear();
return statuses[i];
}
- register_segments->push_back(std::move(values[i]));
+ }
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status HyperLogLog::getRegisters(Database::GetOptions get_options,  // std::string flavour: reads through the PinnableSlice overload, then copies each segment.
const Slice &ns_key,
+ std::vector<std::string>
*register_segments) {
+ std::vector<rocksdb::PinnableSlice> pinnable_slices;
+ rocksdb::Status s = getRegisters(get_options, ns_key, &pinnable_slices);
+ if (!s.ok()) return s;  // on failure *register_segments has not been modified
+ register_segments->reserve(kHyperLogLogSegmentCount);
+ for (auto &pinnable_slice : pinnable_slices) {
+ register_segments->push_back(pinnable_slice.ToString());  // appends: an already-populated output vector grows rather than being overwritten
}
return rocksdb::Status::OK();
}
diff --git a/src/types/redis_hyperloglog.h b/src/types/redis_hyperloglog.h
index d18e0335..6b2e441b 100644
--- a/src/types/redis_hyperloglog.h
+++ b/src/types/redis_hyperloglog.h
@@ -30,16 +30,26 @@ class HyperLogLog : public Database {
explicit HyperLogLog(engine::Storage *storage, const std::string &ns) :
Database(storage, ns) {}
rocksdb::Status Add(const Slice &user_key, const std::vector<uint64_t>
&element_hashes, uint64_t *ret);
rocksdb::Status Count(const Slice &user_key, uint64_t *ret);
- // TODO(mwish): Supports merge operation and related commands
- // rocksdb::Status Merge(const std::vector<Slice> &user_keys);
+ /// The count when user_keys.size() is greater than 1.
+ rocksdb::Status CountMultiple(const std::vector<Slice> &user_key, uint64_t
*ret);
+ rocksdb::Status Merge(const Slice &dest_user_key, const std::vector<Slice>
&source_user_keys);
static uint64_t HllHash(std::string_view);
private:
- rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice
&ns_key, HyperLogLogMetadata *metadata);
+ [[nodiscard]] rocksdb::Status GetMetadata(Database::GetOptions get_options,
const Slice &ns_key,
+ HyperLogLogMetadata *metadata);
+
+ [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions
get_options, const std::vector<Slice> &user_keys,
+ std::vector<std::string>
*register_segments);
/// Using multi-get to acquire the register_segments
- rocksdb::Status getRegisters(Database::GetOptions get_options, const Slice
&ns_key,
- std::vector<rocksdb::PinnableSlice>
*register_segments);
+ ///
+ /// If the metadata is not found, register_segments will be initialized with
16 empty slices.
+ [[nodiscard]] rocksdb::Status getRegisters(Database::GetOptions get_options,
const Slice &ns_key,
+
std::vector<rocksdb::PinnableSlice> *register_segments);
+ /// Same with getRegisters, but the result is stored in a vector of strings.
+ [[nodiscard]] rocksdb::Status getRegisters(Database::GetOptions get_options,
const Slice &ns_key,
+ std::vector<std::string>
*register_segments);
};
} // namespace redis
diff --git a/tests/cppunit/types/hyperloglog_test.cc
b/tests/cppunit/types/hyperloglog_test.cc
index bf7c4914..234b688e 100644
--- a/tests/cppunit/types/hyperloglog_test.cc
+++ b/tests/cppunit/types/hyperloglog_test.cc
@@ -32,6 +32,22 @@ class RedisHyperLogLogTest : public TestBase {
}
~RedisHyperLogLogTest() override = default;
+  /// Runs the base fixture set-up, then drops "hll" followed by "hll1".."hll3".
+  /// The deletion results are deliberately ignored.
+  void SetUp() override {
+    TestBase::SetUp();
+    [[maybe_unused]] auto status = hll_->Del("hll");
+    int suffix = 1;
+    while (suffix <= 3) {
+      status = hll_->Del("hll" + std::to_string(suffix));
+      ++suffix;
+    }
+  }
+
+  /// Drops "hll" and "hll1".."hll3" (results are deliberately ignored), then runs the base
+  /// fixture tear-down.
+  void TearDown() override {
+    [[maybe_unused]] auto s = hll_->Del("hll");
+    for (int x = 1; x <= 3; x++) {
+      s = hll_->Del("hll" + std::to_string(x));
+    }
+    // Chain to the base class tear-down (not SetUp) once our own cleanup is done.
+    TestBase::TearDown();
+  }
+
std::unique_ptr<redis::HyperLogLog> hll_;
static std::vector<uint64_t> computeHashes(const
std::vector<std::string_view> &elements) {
@@ -75,3 +91,86 @@ TEST_F(RedisHyperLogLogTest,
PFCOUNT_returns_approximated_cardinality_of_set) {
// pf count is 10
ASSERT_TRUE(hll_->Count("hll", &ret).ok() && ret == 10);
}
+
+// Merging {a,b,c}, {b,c,d} and {c,d,e} into "hll" must yield a count of 5.
+TEST_F(RedisHyperLogLogTest, PFMERGE_results_on_the_cardinality_of_union_of_sets) {
+  uint64_t result = 0;
+  // pf add hll1 a b c
+  const bool first_added = hll_->Add("hll1", computeHashes({"a", "b", "c"}), &result).ok() && result == 1;
+  ASSERT_TRUE(first_added);
+  // pf add hll2 b c d
+  const bool second_added = hll_->Add("hll2", computeHashes({"b", "c", "d"}), &result).ok() && result == 1;
+  ASSERT_TRUE(second_added);
+  // pf add hll3 c d e
+  const bool third_added = hll_->Add("hll3", computeHashes({"c", "d", "e"}), &result).ok() && result == 1;
+  ASSERT_TRUE(third_added);
+  // pf merge hll hll1 hll2 hll3
+  const bool merged = hll_->Merge("hll", {"hll1", "hll2", "hll3"}).ok();
+  ASSERT_TRUE(merged);
+  // pf count hll is 5
+  const bool counted = hll_->Count("hll", &result).ok();
+  ASSERT_TRUE(counted);
+  ASSERT_EQ(5, result);
+}
+
+// Multi-key counting must track the union of hll1..hll3 as they are filled, and must agree
+// with the count of the merged key afterwards.
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple) {
+  uint64_t result = 0;
+  // Counts the union of the three source keys into *out and reports whether the call succeeded.
+  const auto count_sources = [this](uint64_t *out) {
+    return hll_->CountMultiple({"hll1", "hll2", "hll3"}, out).ok();
+  };
+
+  // Nothing has been added yet
+  ASSERT_TRUE(count_sources(&result));
+  ASSERT_EQ(0, result);
+  // pf add hll1 a b c
+  ASSERT_TRUE(hll_->Add("hll1", computeHashes({"a", "b", "c"}), &result).ok() && result == 1);
+  ASSERT_TRUE(hll_->Count("hll1", &result).ok());
+  ASSERT_EQ(3, result);
+  ASSERT_TRUE(count_sources(&result));
+  ASSERT_EQ(3, result);
+  // pf add hll2 b c d
+  ASSERT_TRUE(hll_->Add("hll2", computeHashes({"b", "c", "d"}), &result).ok() && result == 1);
+  ASSERT_TRUE(count_sources(&result));
+  ASSERT_EQ(4, result);
+  // pf add hll3 c d e
+  ASSERT_TRUE(hll_->Add("hll3", computeHashes({"c", "d", "e"}), &result).ok() && result == 1);
+  ASSERT_TRUE(count_sources(&result));
+  ASSERT_EQ(5, result);
+  // pf merge hll hll1 hll2 hll3
+  ASSERT_TRUE(hll_->Merge("hll", {"hll1", "hll2", "hll3"}).ok());
+  // pf count hll is 5
+  ASSERT_TRUE(hll_->Count("hll", &result).ok());
+  ASSERT_EQ(5, result);
+  // Adding the merged key to the union does not change the count
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3", "hll"}, &result).ok());
+  ASSERT_EQ(5, result);
+}
+
+// Adds one new, key-specific element to each of hll0..hll2 per iteration and checks that the
+// summed per-key counts stay within 5% of the number of elements added so far.
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple_keys_merge_returns_cardinality_of_union_1) {
+  for (int x = 1; x < 1000; x++) {
+    uint64_t ret = 0;
+    const std::string suffix = std::to_string(x);
+    ASSERT_TRUE(hll_->Add("hll0", computeHashes({"foo-" + suffix}), &ret).ok());
+    ASSERT_TRUE(hll_->Add("hll1", computeHashes({"bar-" + suffix}), &ret).ok());
+    ASSERT_TRUE(hll_->Add("hll2", computeHashes({"zap-" + suffix}), &ret).ok());
+
+    uint64_t card0 = 0;
+    uint64_t card1 = 0;
+    uint64_t card2 = 0;
+    ASSERT_TRUE(hll_->Count("hll0", &card0).ok());
+    ASSERT_TRUE(hll_->Count("hll1", &card1).ok());
+    ASSERT_TRUE(hll_->Count("hll2", &card2).ok());
+
+    const auto card = static_cast<double>(card0 + card1 + card2);
+    const double realcard = x * 3;
+    // the absolute difference of 'card' and 'realcard' must be below 5% of 'card'
+    const double left = std::abs(card - realcard);
+    const double right = card / 100 * 5;
+    ASSERT_LT(left, right) << "left : " << left << ", right: " << right;
+  }
+}
+
+// Adds random values from [0, 20000) to hll0..hll2 and checks that the multi-key count is within
+// 5% of the true cardinality of the union, i.e. the number of DISTINCT values added.
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple_keys_merge_returns_cardinality_of_union_2) {
+  constexpr int kValueRange = 20000;
+  std::srand(time(nullptr));
+  // seen[v] becomes true once value v has been added to any of the keys, so repeated draws
+  // (within one key or across keys) are counted only once.
+  std::vector<bool> seen(kValueRange, false);
+  uint64_t distinct_values = 0;
+  for (auto i = 1; i < 1000; i++) {
+    for (auto j = 0; j < 3; j++) {
+      uint64_t ret = 0;
+      int rint = std::rand() % kValueRange;
+      ASSERT_TRUE(hll_->Add("hll" + std::to_string(j), computeHashes({std::to_string(rint)}), &ret).ok());
+      if (!seen[rint]) {
+        seen[rint] = true;
+        distinct_values++;
+      }
+    }
+  }
+  // Count the union of the three keys in one call, as the test name states.
+  uint64_t union_card = 0;
+  ASSERT_TRUE(hll_->CountMultiple({"hll0", "hll1", "hll2"}, &union_card).ok());
+  auto card = static_cast<double>(union_card);
+  auto realcard = static_cast<double>(distinct_values);
+  double left = std::abs(card - realcard);
+  double right = card / 100 * 5;
+  ASSERT_LT(left, right) << "left : " << left << ", right: " << right;
+}