This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new f9d72976 chore(hyperloglog): add go test cases and minor optimizations/bugfixes (#2463)
f9d72976 is described below

commit f9d7297630ecbf342bbed492061acb31fbfa3fd2
Author: mwish <[email protected]>
AuthorDate: Sun Aug 4 14:00:41 2024 +0800

    chore(hyperloglog): add go test cases and minor optimizations/bugfixes (#2463)
---
 src/commands/cmd_hll.cc                           |  27 ++-
 src/types/hyperloglog.cc                          |   6 +-
 src/types/redis_hyperloglog.cc                    |  39 ++--
 tests/gocase/unit/hyperloglog/hyperloglog_test.go | 216 ++++++++++++++++++++++
 4 files changed, 256 insertions(+), 32 deletions(-)

diff --git a/src/commands/cmd_hll.cc b/src/commands/cmd_hll.cc
index 88545427..6cde3e61 100644
--- a/src/commands/cmd_hll.cc
+++ b/src/commands/cmd_hll.cc
@@ -35,12 +35,13 @@ class CommandPfAdd final : public Commander {
  public:
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
-    std::vector<uint64_t> hashes(args_.size() - 1);
-    for (size_t i = 1; i < args_.size(); i++) {
-      hashes[i - 1] = redis::HyperLogLog::HllHash(args_[i]);
+    DCHECK_GE(args_.size(), 2);
+    std::vector<uint64_t> hashes(args_.size() - 2);
+    for (size_t i = 2; i < args_.size(); i++) {
+      hashes[i - 2] = redis::HyperLogLog::HllHash(args_[i]);
     }
     uint64_t ret{};
-    auto s = hll.Add(args_[0], hashes, &ret);
+    auto s = hll.Add(args_[1], hashes, &ret);
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
@@ -58,11 +59,13 @@ class CommandPfCount final : public Commander {
     redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
     uint64_t ret{};
     rocksdb::Status s;
-    if (args_.size() > 1) {
-      std::vector<Slice> keys(args_.begin(), args_.end());
+    // The first argument is the command name, so we need to skip it.
+    DCHECK_GE(args_.size(), 2);
+    if (args_.size() > 2) {
+      std::vector<Slice> keys(args_.begin() + 1, args_.end());
       s = hll.CountMultiple(keys, &ret);
     } else {
-      s = hll.Count(args_[0], &ret);
+      s = hll.Count(args_[1], &ret);
     }
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
@@ -81,13 +84,9 @@ class CommandPfCount final : public Commander {
 class CommandPfMerge final : public Commander {
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     redis::HyperLogLog hll(srv->storage, conn->GetNamespace());
-    std::vector<std::string> keys(args_.begin() + 1, args_.end());
-    std::vector<Slice> src_user_keys;
-    src_user_keys.reserve(args_.size() - 1);
-    for (size_t i = 1; i < args_.size(); i++) {
-      src_user_keys.emplace_back(args_[i]);
-    }
-    auto s = hll.Merge(/*dest_user_key=*/args_[0], src_user_keys);
+    DCHECK_GT(args_.size(), 1);
+    std::vector<Slice> src_user_keys(args_.begin() + 2, args_.end());
+    auto s = hll.Merge(/*dest_user_key=*/args_[1], src_user_keys);
     if (!s.ok() && !s.IsNotFound()) {
       return {Status::RedisExecErr, s.ToString()};
     }
diff --git a/src/types/hyperloglog.cc b/src/types/hyperloglog.cc
index 831988de..e2d0b46c 100644
--- a/src/types/hyperloglog.cc
+++ b/src/types/hyperloglog.cc
@@ -165,11 +165,11 @@ void HllMerge(std::vector<std::string> *dest_registers, const std::vector<nonstd
       continue;
     }
     if (dest_segment->empty()) {
-      dest_segment->resize(src_segment.size());
-      memcpy(dest_segment->data(), src_segment.data(), src_segment.size());
+      DCHECK_EQ(kHyperLogLogSegmentBytes, src_segment.size());
+      *dest_segment = std::string(src_segment.begin(), src_segment.end());
       continue;
     }
-    // Do physical merge
+    // Do physical merge for this segment.
     // NOLINTNEXTLINE
     uint8_t *dest_segment_data = reinterpret_cast<uint8_t 
*>(dest_segment->data());
     for (size_t register_idx = 0; register_idx < kHyperLogLogSegmentRegisters; 
register_idx++) {
diff --git a/src/types/redis_hyperloglog.cc b/src/types/redis_hyperloglog.cc
index f83c7e93..e6c0e420 100644
--- a/src/types/redis_hyperloglog.cc
+++ b/src/types/redis_hyperloglog.cc
@@ -115,7 +115,9 @@ rocksdb::Status HyperLogLog::Add(const Slice &user_key, const std::vector<uint64
   LockGuard guard(storage_->GetLockManager(), ns_key);
   HyperLogLogMetadata metadata{};
   rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata);
-  if (!s.ok() && !s.IsNotFound()) return s;
+  if (!s.ok() && !s.IsNotFound()) {
+    return s;
+  }
 
   auto batch = storage_->GetWriteBatchBase();
   WriteBatchLogData log_data(kRedisHyperLogLog);
@@ -148,7 +150,9 @@ rocksdb::Status HyperLogLog::Add(const Slice &user_key, const std::vector<uint64
     }
   }
   // Nothing changed, no need to flush the segments
-  if (*ret == 0) return rocksdb::Status::OK();
+  if (*ret == 0) {
+    return rocksdb::Status::OK();
+  }
 
   // Flush dirty segments
   // Release memory after batch is written
@@ -179,7 +183,9 @@ rocksdb::Status HyperLogLog::Count(const Slice &user_key, uint64_t *ret) {
     LatestSnapShot ss(storage_);
     Database::GetOptions get_options(ss.GetSnapShot());
     auto s = getRegisters(get_options, ns_key, &registers);
-    if (!s.ok()) return s;
+    if (!s.ok()) {
+      return s;
+    }
   }
   DCHECK_EQ(kHyperLogLogSegmentCount, registers.size());
   std::vector<nonstd::span<const uint8_t>> register_segments = 
TransformToSpan(registers);
@@ -236,21 +242,24 @@ rocksdb::Status HyperLogLog::Merge(const Slice &dest_user_key, const std::vector
 
   std::string dest_key = AppendNamespacePrefix(dest_user_key);
   LockGuard guard(storage_->GetLockManager(), dest_key);
-  // Using same snapshot for all get operations
-  LatestSnapShot ss(storage_);
-  Database::GetOptions get_options(ss.GetSnapShot());
-  HyperLogLogMetadata metadata;
-  rocksdb::Status s = GetMetadata(get_options, dest_user_key, &metadata);
-  if (!s.ok() && !s.IsNotFound()) return s;
   std::vector<std::string> registers;
+  HyperLogLogMetadata metadata;
   {
-    std::vector<Slice> all_user_keys;
-    all_user_keys.reserve(source_user_keys.size() + 1);
-    all_user_keys.push_back(dest_user_key);
-    for (const auto &source_user_key : source_user_keys) {
-      all_user_keys.push_back(source_user_key);
+    // Using same snapshot for all get operations and release it after
+    // finishing the merge operation
+    LatestSnapShot ss(storage_);
+    Database::GetOptions get_options(ss.GetSnapShot());
+    rocksdb::Status s = GetMetadata(get_options, dest_user_key, &metadata);
+    if (!s.ok() && !s.IsNotFound()) return s;
+    {
+      std::vector<Slice> all_user_keys;
+      all_user_keys.reserve(source_user_keys.size() + 1);
+      all_user_keys.push_back(dest_user_key);
+      for (const auto &source_user_key : source_user_keys) {
+        all_user_keys.push_back(source_user_key);
+      }
+      s = mergeUserKeys(get_options, all_user_keys, &registers);
     }
-    s = mergeUserKeys(get_options, all_user_keys, &registers);
   }
 
   auto batch = storage_->GetWriteBatchBase();
diff --git a/tests/gocase/unit/hyperloglog/hyperloglog_test.go b/tests/gocase/unit/hyperloglog/hyperloglog_test.go
new file mode 100644
index 00000000..62fdff40
--- /dev/null
+++ b/tests/gocase/unit/hyperloglog/hyperloglog_test.go
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package hyperloglog
+
+import (
+       "context"
+       "fmt"
+       "testing"
+
+       "github.com/apache/kvrocks/tests/gocase/util"
+       "github.com/stretchr/testify/require"
+)
+
+func TestHyperLogLog(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{})
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       t.Run("basic add", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+
+               card, err := rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, card)
+               addCnt, err := rdb.PFAdd(ctx, "hll", "foo").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+               card, err = rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, card)
+       })
+
+       t.Run("duplicate add", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+
+               card, err := rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, card)
+               addCnt, err := rdb.PFAdd(ctx, "hll", "foo").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+               addCnt, err = rdb.PFAdd(ctx, "hll", "foo").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, addCnt)
+       })
+
+       t.Run("empty add", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+
+               addCnt, err := rdb.PFAdd(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, addCnt)
+
+               card, err := rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, card)
+       })
+
+       t.Run("multiple add", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+
+               addCnt, err := rdb.PFAdd(ctx, "hll", "a", "b", "c", "d").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err := rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               addCnt, err = rdb.PFAdd(ctx, "hll", "a", "b", "c").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, addCnt)
+
+               addCnt, err = rdb.PFAdd(ctx, "hll", "a", "f", "c").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err = rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 5, card)
+       })
+
+       t.Run("multiple count", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+               // Delete hll1, hll2, hll3 (all used below) so the subtest starts clean
+               for i := 1; i <= 3; i++ {
+                       require.NoError(t, rdb.Do(ctx, "DEL", fmt.Sprintf("hll%d", i)).Err())
+               }
+
+               addCnt, err := rdb.PFAdd(ctx, "hll", "a", "b", "c", "d").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               addCnt, err = rdb.PFAdd(ctx, "hll1", "a", "b", "c", "d").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err := rdb.PFCount(ctx, "hll", "hll1").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               // Order doesn't matter
+               card, err = rdb.PFCount(ctx, "hll1", "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               // Count non-exist key
+               card, err = rdb.PFCount(ctx, "hll1", "hll2", "hll3").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               addCnt, err = rdb.PFAdd(ctx, "hll2", "1", "2", "3").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err = rdb.PFCount(ctx, "hll", "hll1", "hll2").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 7, card)
+
+               // add with overlap
+               addCnt, err = rdb.PFAdd(ctx, "hll3", "a", "3", "5").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err = rdb.PFCount(ctx, "hll", "hll1", "hll2", "hll3").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 8, card)
+       })
+
+       t.Run("basic merge", func(t *testing.T) {
+               require.NoError(t, rdb.Do(ctx, "DEL", "hll").Err())
+               // Delete hll1, hll2, hll3, hll4 (all used below) so the subtest starts clean
+               for i := 1; i <= 4; i++ {
+                       require.NoError(t, rdb.Do(ctx, "DEL", fmt.Sprintf("hll%d", i)).Err())
+               }
+
+               addCnt, err := rdb.PFAdd(ctx, "hll", "a", "b", "c", "d").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               // Empty merge
+               mergeCmd, err := rdb.PFMerge(ctx, "hll").Result()
+               require.NoError(t, err)
+               // mergeCmd result is always "OK"
+               require.EqualValues(t, "OK", mergeCmd)
+
+               // Count the merged key
+               card, err := rdb.PFCount(ctx, "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               // Merge to hll1
+               mergeCmd, err = rdb.PFMerge(ctx, "hll1", "hll").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, "OK", mergeCmd)
+
+               // Count the merged key
+               card, err = rdb.PFCount(ctx, "hll1").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 4, card)
+
+               // Add more elements to hll2
+               addCnt, err = rdb.PFAdd(ctx, "hll2", "e", "f", "g").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               card, err = rdb.PFCount(ctx, "hll2").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 3, card)
+
+               // merge to hll3
+               mergeCmd, err = rdb.PFMerge(ctx, "hll3", "hll", "hll1", "hll2").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, "OK", mergeCmd)
+
+               // Count the merged key
+               card, err = rdb.PFCount(ctx, "hll3").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 7, card)
+
+               // Add more elements to hll4
+               addCnt, err = rdb.PFAdd(ctx, "hll4", "h", "i", "j").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, addCnt)
+
+               // Merge all to existing hll4
+               mergeCmd, err = rdb.PFMerge(ctx, "hll4", "hll", "hll1", "hll2", "hll3").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, "OK", mergeCmd)
+
+               // Count the merged key
+               card, err = rdb.PFCount(ctx, "hll4").Result()
+               require.NoError(t, err)
+               require.EqualValues(t, 10, card)
+       })
+}

Reply via email to