This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new a3863f9c feat(hyperloglog): Add support of PFMERGE command (#2457)
a3863f9c is described below

commit a3863f9cb8f84c15bec97a5382cd4e1bb76aa04e
Author: mwish <[email protected]>
AuthorDate: Thu Aug 1 20:58:11 2024 +0800

    feat(hyperloglog): Add support of PFMERGE command (#2457)
---
 src/commands/cmd_hll.cc                 |  38 +++++++--
 src/types/hyperloglog.cc                |  25 ++++++
 src/types/hyperloglog.h                 |   6 ++
 src/types/redis_hyperloglog.cc          | 136 +++++++++++++++++++++++++++++---
 src/types/redis_hyperloglog.h           |  20 +++--
 tests/cppunit/types/hyperloglog_test.cc |  99 +++++++++++++++++++++++
 6 files changed, 298 insertions(+), 26 deletions(-)

diff --git a/src/commands/cmd_hll.cc b/src/commands/cmd_hll.cc
index 343aa322..88545427 100644
--- a/src/commands/cmd_hll.cc
+++ b/src/commands/cmd_hll.cc
@@ -24,12 +24,8 @@
 
 #include "commander.h"
 #include "commands/command_parser.h"
-#include "commands/error_constants.h"
-#include "error_constants.h"
-#include "parse_util.h"
 #include "server/redis_reply.h"
 #include "server/server.h"
-#include "storage/redis_metadata.h"
 
 namespace redis {
 
@@ -57,13 +53,17 @@ class CommandPfAdd final : public Commander {
 /// Complexity: O(1) with a very small average constant time when called with a single key.
 ///              O(N) with N being the number of keys, and much bigger constant times,
 ///              when called with multiple keys.
-///
-/// TODO(mwish): Currently we don't supports merge, so only one key is supported.
 class CommandPfCount final : public Commander {
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
     uint64_t ret{};
-    auto s = hll.Count(args_[0], &ret);
+    rocksdb::Status s;
+    if (args_.size() > 2) {  // args_[0] is the command name, so the keys start at args_[1]
+      std::vector<Slice> keys(args_.begin() + 1, args_.end());
+      s = hll.CountMultiple(keys, &ret);
+    } else {
+      s = hll.Count(args_[1], &ret);
+    }
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
@@ -75,7 +75,29 @@ class CommandPfCount final : public Commander {
   }
 };
 
+/// PFMERGE destkey [sourcekey [sourcekey ...]]
+///
+/// Merges the HyperLogLogs stored at the source keys, together with destkey itself if it exists,
+/// and stores the result at destkey.
+///
+/// complexity: O(N) to merge N HyperLogLogs, but with high constant times.
+class CommandPfMerge final : public Commander {
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
+    // args_[0] is the command name and args_[1] the destination key; the (possibly empty) list of
+    // source keys starts at args_[2]. The arity -2 registered below requires at least 2 arguments.
+    std::vector<Slice> src_user_keys(args_.begin() + 2, args_.end());
+    auto s = hll.Merge(/*dest_user_key=*/args_[1], src_user_keys);
+    if (!s.ok() && !s.IsNotFound()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+    *output = redis::SimpleString("OK");
+    return Status::OK();
+  }
+};
+
 REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandPfAdd>("pfadd", -2, "write", 1, 1, 1),
-                        MakeCmdAttr<CommandPfCount>("pfcount", 2, "read-only", 1, 1, 1), );
+                        MakeCmdAttr<CommandPfCount>("pfcount", -2, "read-only", 1, -1, 1),  // every arg is a key
+                        MakeCmdAttr<CommandPfMerge>("pfmerge", -2, "write", 1, -1, 1), );  // destkey, then sources
 
 }  // namespace redis
diff --git a/src/types/hyperloglog.cc b/src/types/hyperloglog.cc
index 80923181..831988de 100644
--- a/src/types/hyperloglog.cc
+++ b/src/types/hyperloglog.cc
@@ -157,6 +157,31 @@ void HllDenseRegHisto(nonstd::span<const uint8_t> 
registers, int *reghisto) {
   }
 }
 
+// Folds each source segment into the destination segment with the same index, keeping the
+// larger value of every register; empty source segments are skipped.
+void HllMerge(std::vector<std::string> *dest_registers, const std::vector<nonstd::span<const uint8_t>> &registers) {
+  for (size_t seg = 0; seg < kHyperLogLogSegmentCount; ++seg) {
+    std::string &target = dest_registers->at(seg);
+    const nonstd::span<const uint8_t> &source = registers[seg];
+    if (source.empty()) continue;
+    if (target.empty()) {
+      // Nothing to combine with yet: adopt a byte-for-byte copy of the source segment.
+      target.resize(source.size());
+      memcpy(target.data(), source.data(), source.size());
+    } else {
+      // Both sides are populated: take the register-wise maximum in place.
+      // NOLINTNEXTLINE
+      auto *target_bytes = reinterpret_cast<uint8_t *>(target.data());
+      for (size_t reg = 0; reg < kHyperLogLogSegmentRegisters; ++reg) {
+        const uint8_t incoming = HllDenseGetRegister(source.data(), reg);
+        if (incoming > HllDenseGetRegister(target_bytes, reg)) {
+          HllDenseSetRegister(target_bytes, reg, incoming);
+        }
+      }
+    }
+  }
+}
+
 /* ========================= HyperLogLog Count ==============================
  * This is the core of the algorithm where the approximated count is computed.
  * The function uses the lower level HllDenseRegHisto()
diff --git a/src/types/hyperloglog.h b/src/types/hyperloglog.h
index c99efe5c..2071f900 100644
--- a/src/types/hyperloglog.h
+++ b/src/types/hyperloglog.h
@@ -70,3 +70,9 @@ void HllDenseSetRegister(uint8_t *registers, uint32_t 
register_index, uint8_t va
  *                  or a kHyperLogLogSegmentBytes sized array.
  */
 uint64_t HllDenseEstimate(const std::vector<nonstd::span<const uint8_t>> 
&registers);
+
+/**
+ * Merge 'registers' into 'dest_registers' per segment (indices 0..kHyperLogLogSegmentCount-1), keeping MAX per register.
+ * Empty source segments are skipped; an empty destination segment receives a copy of the source segment.
+ */
+void HllMerge(std::vector<std::string> *dest_registers, const std::vector<nonstd::span<const uint8_t>> &registers);
diff --git a/src/types/redis_hyperloglog.cc b/src/types/redis_hyperloglog.cc
index aef0cc66..f83c7e93 100644
--- a/src/types/redis_hyperloglog.cc
+++ b/src/types/redis_hyperloglog.cc
@@ -28,6 +28,25 @@
 
 namespace redis {
 
+namespace {
+/// Views every register segment as a read-only byte span borrowing from `registers`;
+/// empty segments become default-constructed (empty) spans.
+template <typename ElementType>
+std::vector<nonstd::span<const uint8_t>> TransformToSpan(const std::vector<ElementType> &registers) {
+  std::vector<nonstd::span<const uint8_t>> spans;
+  spans.reserve(kHyperLogLogSegmentCount);
+  for (const auto &segment : registers) {
+    if (!segment.empty()) {
+      // NOLINTNEXTLINE
+      spans.emplace_back(reinterpret_cast<const uint8_t *>(segment.data()), segment.size());
+    } else {
+      spans.emplace_back();
+    }
+  }
+  return spans;
+}
+}  // namespace
+
 /// Cache for writing to a HyperLogLog.
 ///
 /// This is a bit like Bitmap::SegmentCacheStore, but simpler because
@@ -163,22 +182,101 @@ rocksdb::Status HyperLogLog::Count(const Slice 
&user_key, uint64_t *ret) {
     if (!s.ok()) return s;
   }
   DCHECK_EQ(kHyperLogLogSegmentCount, registers.size());
-  std::vector<nonstd::span<const uint8_t>> register_segments;
-  register_segments.reserve(kHyperLogLogSegmentCount);
-  for (const auto &register_segment : registers) {
-    if (register_segment.empty()) {
-      // Empty segment
-      register_segments.emplace_back();
+  std::vector<nonstd::span<const uint8_t>> register_segments = 
TransformToSpan(registers);
+  *ret = HllDenseEstimate(register_segments);
+  return rocksdb::Status::OK();
+}
+
+rocksdb::Status HyperLogLog::mergeUserKeys(Database::GetOptions get_options, const std::vector<Slice> &user_keys,
+                                           std::vector<std::string> *register_segments) {
+  // Unions the HyperLogLogs of every distinct key in `user_keys` into `register_segments`,
+  // reading all of them through the same `get_options`.
+  DCHECK_GE(user_keys.size(), static_cast<size_t>(1));
+  const std::string seed_ns_key = AppendNamespacePrefix(user_keys[0]);
+  // The first key seeds the output; every later key is max-merged into it.
+  rocksdb::Status s = getRegisters(get_options, seed_ns_key, register_segments);
+  if (!s.ok()) {
+    return s;
+  }
+  std::unordered_set<std::string_view> merged_keys{user_keys[0].ToStringView()};
+  for (size_t i = 1; i < user_keys.size(); ++i) {
+    const auto &user_key = user_keys[i];
+    if (!merged_keys.insert(user_key.ToStringView()).second) {  // already merged: skip duplicates
       continue;
     }
-    // NOLINTNEXTLINE
-    const uint8_t *segment_data_ptr = reinterpret_cast<const uint8_t *>(register_segment.data());
-    register_segments.emplace_back(segment_data_ptr, register_segment.size());
+    const std::string ns_key = AppendNamespacePrefix(user_key);
+    std::vector<rocksdb::PinnableSlice> pinned_segments;
+    s = getRegisters(get_options, ns_key, &pinned_segments);
+    if (!s.ok()) return s;
+    DCHECK_EQ(kHyperLogLogSegmentCount, register_segments->size());
+    DCHECK_EQ(kHyperLogLogSegmentCount, pinned_segments.size());
+    const std::vector<nonstd::span<const uint8_t>> pinned_spans = TransformToSpan(pinned_segments);
+    HllMerge(register_segments, pinned_spans);
   }
-  *ret = HllDenseEstimate(register_segments);
   return rocksdb::Status::OK();
 }
 
+rocksdb::Status HyperLogLog::CountMultiple(const std::vector<Slice> &user_key, uint64_t *ret) {
+  DCHECK_GT(user_key.size(), static_cast<size_t>(1));
+  // Take a single snapshot and read every key through it, so the estimated union
+  // reflects one consistent point in time.
+  LatestSnapShot snapshot(storage_);
+  Database::GetOptions snapshot_options(snapshot.GetSnapShot());
+  std::vector<std::string> merged_segments;
+  if (auto s = mergeUserKeys(snapshot_options, user_key, &merged_segments); !s.ok()) return s;
+  // Estimate the cardinality of the union from the register-wise maximum of all keys.
+  *ret = HllDenseEstimate(TransformToSpan(merged_segments));
+  return rocksdb::Status::OK();
+}
+
+rocksdb::Status HyperLogLog::Merge(const Slice &dest_user_key, const std::vector<Slice> &source_user_keys) {
+  // Stores at `dest_user_key` the union of the HyperLogLogs at `source_user_keys` and `dest_user_key` itself.
+  // NOTE(review): with no source keys this returns early, so a missing destination is not created (Redis does).
+  if (source_user_keys.empty()) return rocksdb::Status::OK();
+
+  std::string dest_key = AppendNamespacePrefix(dest_user_key);
+  LockGuard guard(storage_->GetLockManager(), dest_key);
+  // Using same snapshot for all get operations
+  LatestSnapShot ss(storage_);
+  Database::GetOptions get_options(ss.GetSnapShot());
+  HyperLogLogMetadata metadata;
+  // GetMetadata takes the namespaced key (as getRegisters does), not the raw user key.
+  rocksdb::Status s = GetMetadata(get_options, dest_key, &metadata);
+  if (!s.ok() && !s.IsNotFound()) return s;
+  std::vector<std::string> registers;
+  {
+    std::vector<Slice> all_user_keys;
+    all_user_keys.reserve(source_user_keys.size() + 1);
+    all_user_keys.push_back(dest_user_key);
+    all_user_keys.insert(all_user_keys.end(), source_user_keys.begin(), source_user_keys.end());
+    s = mergeUserKeys(get_options, all_user_keys, &registers);
+  }
+  // On failure `registers` may not hold kHyperLogLogSegmentCount segments, so it must not be indexed below.
+  if (!s.ok()) return s;
+
+  auto batch = storage_->GetWriteBatchBase();
+  WriteBatchLogData log_data(kRedisHyperLogLog);
+  batch->PutLogData(log_data.Encode());
+  for (uint32_t i = 0; i < kHyperLogLogSegmentCount; i++) {
+    if (registers[i].empty()) {
+      continue;
+    }
+    std::string sub_key =
+        InternalKey(dest_key, std::to_string(i), metadata.version, storage_->IsSlotIdEncoded()).Encode();
+    batch->Put(sub_key, registers[i]);
+    // The batch keeps its own copy of the value; swap with an empty string to actually free the
+    // buffer now (clear() would keep the capacity allocated until the function returns).
+    std::string().swap(registers[i]);
+  }
+  // Metadata: the merged registers are always stored with the dense encoding.
+  metadata.encode_type = HyperLogLogMetadata::EncodeType::DENSE;
+  std::string bytes;
+  metadata.Encode(&bytes);
+  batch->Put(metadata_cf_handle_, dest_key, bytes);
+
+  return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+}
+
 rocksdb::Status HyperLogLog::getRegisters(Database::GetOptions get_options, 
const Slice &ns_key,
                                           std::vector<rocksdb::PinnableSlice> 
*register_segments) {
   HyperLogLogMetadata metadata;
@@ -207,15 +305,27 @@ rocksdb::Status 
HyperLogLog::getRegisters(Database::GetOptions get_options, cons
   for (const auto &sub_key : sub_segment_keys) {
     sub_segment_slices.emplace_back(sub_key);
   }
-  std::vector<rocksdb::PinnableSlice> values(kHyperLogLogSegmentCount);
+  register_segments->resize(kHyperLogLogSegmentCount);
   std::vector<rocksdb::Status> statuses(kHyperLogLogSegmentCount);
   storage_->MultiGet(read_options, storage_->GetDB()->DefaultColumnFamily(), 
kHyperLogLogSegmentCount,
-                     sub_segment_slices.data(), values.data(), 
statuses.data());
+                     sub_segment_slices.data(), register_segments->data(), 
statuses.data());
   for (size_t i = 0; i < kHyperLogLogSegmentCount; i++) {
     if (!statuses[i].ok() && !statuses[i].IsNotFound()) {
+      register_segments->at(i).clear();
       return statuses[i];
     }
-    register_segments->push_back(std::move(values[i]));
+  }
+  return rocksdb::Status::OK();
+}
+
+rocksdb::Status HyperLogLog::getRegisters(Database::GetOptions get_options, const Slice &ns_key,
+                                          std::vector<std::string> *register_segments) {
+  // Read the segments as pinned slices, then append an owned copy of each one to `register_segments`.
+  std::vector<rocksdb::PinnableSlice> pinned;
+  if (auto s = getRegisters(get_options, ns_key, &pinned); !s.ok()) return s;
+  register_segments->reserve(kHyperLogLogSegmentCount);
+  for (const auto &segment : pinned) {
+    register_segments->emplace_back(segment.data(), segment.size());
   }
   return rocksdb::Status::OK();
 }
diff --git a/src/types/redis_hyperloglog.h b/src/types/redis_hyperloglog.h
index d18e0335..6b2e441b 100644
--- a/src/types/redis_hyperloglog.h
+++ b/src/types/redis_hyperloglog.h
@@ -30,16 +30,26 @@ class HyperLogLog : public Database {
   explicit HyperLogLog(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {}
   rocksdb::Status Add(const Slice &user_key, const std::vector<uint64_t> &element_hashes, uint64_t *ret);
   rocksdb::Status Count(const Slice &user_key, uint64_t *ret);
-  // TODO(mwish): Supports merge operation and related commands
-  // rocksdb::Status Merge(const std::vector<Slice> &user_keys);
+  /// Estimates the cardinality of the union of the HyperLogLogs at `user_key` (used for more than one key).
+  rocksdb::Status CountMultiple(const std::vector<Slice> &user_key, uint64_t *ret);
+  rocksdb::Status Merge(const Slice &dest_user_key, const std::vector<Slice> &source_user_keys);  // PFMERGE
 
   static uint64_t HllHash(std::string_view);
 
  private:
-  rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, HyperLogLogMetadata *metadata);
+  [[nodiscard]] rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key,
+                                            HyperLogLogMetadata *metadata);
+  /// Unions the registers of every distinct key in `user_keys` (at least one) into `register_segments`.
+  [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions get_options, const std::vector<Slice> &user_keys,
+                                              std::vector<std::string> *register_segments);
   /// Using multi-get to acquire the register_segments
-  rocksdb::Status getRegisters(Database::GetOptions get_options, const Slice &ns_key,
-                               std::vector<rocksdb::PinnableSlice> *register_segments);
+  ///
+  /// If the metadata is not found, register_segments will be initialized with 16 empty slices.
+  [[nodiscard]] rocksdb::Status getRegisters(Database::GetOptions get_options, const Slice &ns_key,
+                                             std::vector<rocksdb::PinnableSlice> *register_segments);
+  /// Same as the overload above, but each segment is appended to `register_segments` as an owned std::string.
+  [[nodiscard]] rocksdb::Status getRegisters(Database::GetOptions get_options, const Slice &ns_key,
+                                             std::vector<std::string> *register_segments);
 };
 
 }  // namespace redis
diff --git a/tests/cppunit/types/hyperloglog_test.cc 
b/tests/cppunit/types/hyperloglog_test.cc
index bf7c4914..234b688e 100644
--- a/tests/cppunit/types/hyperloglog_test.cc
+++ b/tests/cppunit/types/hyperloglog_test.cc
@@ -32,6 +32,22 @@ class RedisHyperLogLogTest : public TestBase {
   }
   ~RedisHyperLogLogTest() override = default;
 
+  void SetUp() override {
+    TestBase::SetUp();
+    [[maybe_unused]] auto s = hll_->Del("hll");
+    for (int x = 0; x <= 3; x++) {  // "hll0".."hll3": every per-set key the tests below write
+      s = hll_->Del("hll" + std::to_string(x));
+    }
+  }
+
+  void TearDown() override {
+    [[maybe_unused]] auto s = hll_->Del("hll");
+    for (int x = 0; x <= 3; x++) {
+      s = hll_->Del("hll" + std::to_string(x));
+    }
+    TestBase::TearDown();  // remove our keys first, then let the base fixture tear down
+  }
+
   std::unique_ptr<redis::HyperLogLog> hll_;
 
   static std::vector<uint64_t> computeHashes(const 
std::vector<std::string_view> &elements) {
@@ -75,3 +91,86 @@ TEST_F(RedisHyperLogLogTest, 
PFCOUNT_returns_approximated_cardinality_of_set) {
   // pf count is 10
   ASSERT_TRUE(hll_->Count("hll", &ret).ok() && ret == 10);
 }
+
+TEST_F(RedisHyperLogLogTest, PFMERGE_results_on_the_cardinality_of_union_of_sets) {
+  uint64_t ret = 0;
+  // PFADD hll1 a b c
+  ASSERT_TRUE(hll_->Add("hll1", computeHashes({"a", "b", "c"}), &ret).ok() && ret == 1);
+  // PFADD hll2 b c d (overlaps hll1 on b and c)
+  ASSERT_TRUE(hll_->Add("hll2", computeHashes({"b", "c", "d"}), &ret).ok() && ret == 1);
+  // PFADD hll3 c d e (overlaps the other sets on c and d)
+  ASSERT_TRUE(hll_->Add("hll3", computeHashes({"c", "d", "e"}), &ret).ok() && ret == 1);
+  // PFMERGE hll hll1 hll2 hll3 -- "hll" was deleted in SetUp(), so only the three sources contribute
+  ASSERT_TRUE(hll_->Merge("hll", {"hll1", "hll2", "hll3"}).ok());
+  // PFCOUNT hll is 5: the union of the three sets is {a, b, c, d, e}
+  ASSERT_TRUE(hll_->Count("hll", &ret).ok());
+  ASSERT_EQ(5, ret);
+}
+
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple) {
+  uint64_t ret = 0;
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3"}, &ret).ok());  // no key exists yet
+  ASSERT_EQ(0, ret);
+  // PFADD hll1 a b c
+  ASSERT_TRUE(hll_->Add("hll1", computeHashes({"a", "b", "c"}), &ret).ok() && ret == 1);
+  ASSERT_TRUE(hll_->Count("hll1", &ret).ok());
+  ASSERT_EQ(3, ret);
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3"}, &ret).ok());  // missing keys add nothing
+  ASSERT_EQ(3, ret);
+  // PFADD hll2 b c d -> union is {a, b, c, d}
+  ASSERT_TRUE(hll_->Add("hll2", computeHashes({"b", "c", "d"}), &ret).ok() && ret == 1);
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3"}, &ret).ok());
+  ASSERT_EQ(4, ret);
+  // PFADD hll3 c d e -> union is {a, b, c, d, e}
+  ASSERT_TRUE(hll_->Add("hll3", computeHashes({"c", "d", "e"}), &ret).ok() && ret == 1);
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3"}, &ret).ok());
+  ASSERT_EQ(5, ret);
+  // PFMERGE hll hll1 hll2 hll3
+  ASSERT_TRUE(hll_->Merge("hll", {"hll1", "hll2", "hll3"}).ok());
+  // PFCOUNT hll is 5
+  ASSERT_TRUE(hll_->Count("hll", &ret).ok());
+  ASSERT_EQ(5, ret);
+  ASSERT_TRUE(hll_->CountMultiple({"hll1", "hll2", "hll3", "hll"}, &ret).ok());  // hll adds no new elements
+  ASSERT_EQ(5, ret);
+}
+
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple_keys_merge_returns_cardinality_of_union_1) {
+  for (int x = 1; x < 1000; x++) {
+    uint64_t ret = 0;
+    ASSERT_TRUE(hll_->Add("hll0", computeHashes({"foo-" + std::to_string(x)}), &ret).ok());
+    ASSERT_TRUE(hll_->Add("hll1", computeHashes({"bar-" + std::to_string(x)}), &ret).ok());
+    ASSERT_TRUE(hll_->Add("hll2", computeHashes({"zap-" + std::to_string(x)}), &ret).ok());
+    // The three sets are disjoint, so the union estimated by a multi-key PFCOUNT must
+    // approximate x * 3 (mirrors Redis' "PFCOUNT multiple-keys merge ... union #1" test).
+    uint64_t union_card = 0;
+    ASSERT_TRUE(hll_->CountMultiple({"hll0", "hll1", "hll2"}, &union_card).ok());
+    auto card = static_cast<double>(union_card);
+    double realcard = x * 3;
+    // assert the ABS of 'card' and 'realcard' is within 5% of the cardinality
+    double left = std::abs(card - realcard);
+    double right = card / 100 * 5;
+    ASSERT_LT(left, right) << "left : " << left << ", right: " << right;
+  }
+}
+
+TEST_F(RedisHyperLogLogTest, PFCOUNT_multiple_keys_merge_returns_cardinality_of_union_2) {
+  std::srand(time(nullptr));
+  std::vector<char> seen(20000, 0);
+  size_t distinct = 0;  // real cardinality of the union: distinct values added to any of the sets
+  for (auto i = 1; i < 1000; i++) {
+    for (auto j = 0; j < 3; j++) {
+      uint64_t ret = 0;
+      int rint = std::rand() % 20000;
+      ASSERT_TRUE(hll_->Add("hll" + std::to_string(j), computeHashes({std::to_string(rint)}), &ret).ok());
+      distinct += seen[rint] ? 0 : 1;
+      seen[rint] = 1;
+    }
+  }
+  uint64_t union_card = 0;
+  ASSERT_TRUE(hll_->CountMultiple({"hll0", "hll1", "hll2"}, &union_card).ok());
+  auto card = static_cast<double>(union_card);
+  auto realcard = static_cast<double>(distinct);
+  double left = std::abs(card - realcard);
+  double right = card / 100 * 5;
+  ASSERT_LT(left, right) << "left : " << left << ", right: " << right;
+}

Reply via email to