This is an automated email from the ASF dual-hosted git repository.

edwardxu pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 65db00a86 refactor(tdigest): refactor the `TDigest::Rank` `reverse` 
flag into a compile-time argument (#3268)
65db00a86 is described below

commit 65db00a86bc9031f4d974eb0b6f1893bd126c524
Author: Hao Dong <[email protected]>
AuthorDate: Sat Dec 13 21:03:13 2025 +0800

    refactor(tdigest): refactor the `TDigest::Rank` `reverse` flag into a 
compile-time argument (#3268)
    
    Co-authored-by: Twice <[email protected]>
    Co-authored-by: Edward Xu <[email protected]>
---
 src/commands/cmd_tdigest.cc         |  13 ++-
 src/types/redis_tdigest.cc          | 166 +++++++++++++++++-------------------
 src/types/redis_tdigest.h           |   7 +-
 src/types/tdigest.h                 |  15 +---
 tests/cppunit/types/tdigest_test.cc |  20 ++---
 5 files changed, 109 insertions(+), 112 deletions(-)

diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc
index 64dfafcd7..a5479cb12 100644
--- a/src/commands/cmd_tdigest.cc
+++ b/src/commands/cmd_tdigest.cc
@@ -176,7 +176,7 @@ class CommandTDigestAdd : public Commander {
   std::vector<double> values_;
 };
 
-template <bool reverse>
+template <bool Reverse>
 class TDigestRankCommand : public Commander {
  public:
   Status Parse(const std::vector<std::string> &args) override {
@@ -202,7 +202,16 @@ class TDigestRankCommand : public Commander {
     TDigest tdigest(srv->storage, conn->GetNamespace());
     std::vector<int> result;
     result.reserve(origin_inputs_.size());
-    if (const auto s = tdigest.Rank(ctx, key_name_, unique_inputs_, reverse, 
result); !s.ok()) {
+
+    if (const auto s =
+            [&]() {
+              if constexpr (Reverse) {
+                return tdigest.RevRank(ctx, key_name_, unique_inputs_, result);
+              } else {
+                return tdigest.Rank(ctx, key_name_, unique_inputs_, result);
+              }
+            }();
+        !s.ok()) {
       if (s.IsNotFound()) {
         return {Status::RedisExecErr, errKeyNotFound};
       }
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
index fec7aef1c..d9a7d425a 100644
--- a/src/types/redis_tdigest.cc
+++ b/src/types/redis_tdigest.cc
@@ -47,114 +47,81 @@
 
 namespace redis {
 
+namespace {
+template <bool Reverse, typename Container>
+inline decltype(auto) GetCbeginIter(const Container& centroids) {
+  if constexpr (Reverse) {
+    return centroids.crbegin();
+  } else {
+    return centroids.cbegin();
+  }
+}
+
+template <bool Reverse, typename Container>
+inline decltype(auto) GetCendIter(const Container& centroids) {
+  if constexpr (Reverse) {
+    return centroids.crend();
+  } else {
+    return centroids.cend();
+  }
+}
+}  // namespace
+
 // TODO: It should be replaced by a iteration of the rocksdb iterator
+template <bool Reverse>
 class DummyCentroids {
  public:
-  class BaseIterator {
-   public:
-    virtual ~BaseIterator() = default;
-    virtual bool Next() = 0;
-    virtual bool Prev() = 0;
-    virtual bool Valid() const = 0;
-    virtual std::unique_ptr<BaseIterator> Clone() const = 0;
-    virtual StatusOr<Centroid> GetCentroid() const = 0;
-  };
-
   DummyCentroids(const TDigestMetadata& meta_data, const 
std::vector<Centroid>& centroids)
       : meta_data_(meta_data), centroids_(centroids) {}
-  class Iterator : public BaseIterator {
+  class Iterator {
    public:
-    Iterator(std::vector<Centroid>::const_iterator&& iter, const 
std::vector<Centroid>& centroids)
-        : iter_(iter), centroids_(centroids) {}
-    std::unique_ptr<BaseIterator> Clone() const override {
-      if (iter_ != centroids_.cend()) {
-        return std::make_unique<Iterator>(std::next(centroids_.cbegin(), 
std::distance(centroids_.cbegin(), iter_)),
-                                          centroids_);
+    using IterType = std::conditional_t<Reverse, 
std::vector<Centroid>::const_reverse_iterator,
+                                        std::vector<Centroid>::const_iterator>;
+    Iterator(IterType iter, const std::vector<Centroid>& centroids) : 
iter_(iter), centroids_(centroids) {}
+    std::unique_ptr<Iterator> Clone() const {
+      if (iter_ != GetCendIter<Reverse>(centroids_)) {
+        return std::make_unique<Iterator>(
+            std::next(GetCbeginIter<Reverse>(centroids_), 
std::distance(GetCbeginIter<Reverse>(centroids_), iter_)),
+            centroids_);
       }
-      return std::make_unique<Iterator>(centroids_.cend(), centroids_);
+      return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_), 
centroids_);
     }
-    bool Next() override {
+    bool Next() {
       if (Valid()) {
         std::advance(iter_, 1);
       }
-      return iter_ != centroids_.cend();
+      return iter_ != GetCendIter<Reverse>(centroids_);
     }
 
     // The Prev function can only be called for item is not cend,
     // because we must guarantee the iterator to be inside the valid range 
before iteration.
-    bool Prev() override {
-      if (Valid() && iter_ != centroids_.cbegin()) {
+    bool Prev() {
+      if (Valid() && iter_ != GetCbeginIter<Reverse>(centroids_)) {
         std::advance(iter_, -1);
       }
       return Valid();
     }
-    bool Valid() const override { return iter_ != centroids_.cend(); }
-    StatusOr<Centroid> GetCentroid() const override {
-      if (iter_ == centroids_.cend()) {
+    bool Valid() const { return iter_ != GetCendIter<Reverse>(centroids_); }
+    StatusOr<Centroid> GetCentroid() const {
+      if (iter_ == GetCendIter<Reverse>(centroids_)) {
         return {::Status::NotOK, "invalid iterator during decoding tdigest 
centroid"};
       }
       return *iter_;
     }
 
    private:
-    std::vector<Centroid>::const_iterator iter_;
+    IterType iter_;
     const std::vector<Centroid>& centroids_;
   };
 
-  class ReverseIterator final : public BaseIterator {
-   public:
-    ReverseIterator(std::vector<Centroid>::const_reverse_iterator&& iter, 
const std::vector<Centroid>& centroids)
-        : iter_(iter), centroids_(centroids) {}
-    std::unique_ptr<BaseIterator> Clone() const override {
-      if (iter_ != centroids_.crend()) {
-        return std::make_unique<ReverseIterator>(
-            std::next(centroids_.crbegin(), 
std::distance(centroids_.crbegin(), iter_)), centroids_);
-      }
-      return std::make_unique<ReverseIterator>(centroids_.crend(), centroids_);
-    }
-    bool Next() override {
-      if (Valid()) {
-        std::advance(iter_, 1);
-      }
-      return iter_ != centroids_.crend();
-    }
-
-    bool Prev() override {
-      if (Valid() && iter_ != centroids_.crbegin()) {
-        std::advance(iter_, -1);
-      }
-      return Valid();
-    }
-    bool Valid() const override { return iter_ != centroids_.crend(); }
-    StatusOr<Centroid> GetCentroid() const override {
-      if (iter_ == centroids_.crend()) {
-        return {::Status::NotOK, "invalid iterator during decoding tdigest 
centroid"};
-      }
-      return *iter_;
-    }
-
-   private:
-    std::vector<Centroid>::const_reverse_iterator iter_;
-    const std::vector<Centroid>& centroids_;
-  };
-
-  std::unique_ptr<BaseIterator> Begin(const bool reverse = false) const {
-    if (reverse) {
-      return std::make_unique<ReverseIterator>(centroids_.crbegin(), 
centroids_);
-    }
-    return std::make_unique<Iterator>(centroids_.cbegin(), centroids_);
+  std::unique_ptr<Iterator> Begin() const {
+    return std::make_unique<Iterator>(GetCbeginIter<Reverse>(centroids_), 
centroids_);
   }
-  std::unique_ptr<BaseIterator> End(const bool reverse = false) const {
+  std::unique_ptr<Iterator> End() const {
     if (centroids_.empty()) {
-      if (reverse) {
-        return std::make_unique<ReverseIterator>(centroids_.crend(), 
centroids_);
-      }
-      return std::make_unique<Iterator>(centroids_.cend(), centroids_);
-    }
-    if (reverse) {
-      return std::make_unique<ReverseIterator>(std::prev(centroids_.crend()), 
centroids_);
+      return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_), 
centroids_);
     }
-    return std::make_unique<Iterator>(std::prev(centroids_.cend()), 
centroids_);
+    return 
std::make_unique<Iterator>(std::prev(GetCendIter<Reverse>(centroids_)), 
centroids_);
   }
   double TotalWeight() const { return 
static_cast<double>(meta_data_.total_weight); }
   double Min() const { return meta_data_.minimum; }
@@ -273,10 +240,9 @@ rocksdb::Status TDigest::mergeNodes(engine::Context& ctx, 
const std::string& ns_
   return rocksdb::Status::OK();
 }
 
-rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
-                              bool reverse, std::vector<int>& result) {
+rocksdb::Status TDigest::prepareRankData(engine::Context& ctx, const Slice& 
digest_name, TDigestMetadata& metadata,
+                                         std::vector<Centroid>& centroids) {
   auto ns_key = AppendNamespacePrefix(digest_name);
-  TDigestMetadata metadata;
   {
     LockGuard guard(storage_->GetLockManager(), ns_key);
 
@@ -285,7 +251,6 @@ rocksdb::Status TDigest::Rank(engine::Context& ctx, const 
Slice& digest_name, co
     }
 
     if (metadata.total_observations == 0) {
-      result.resize(inputs.size(), -2);
       return rocksdb::Status::OK();
     }
 
@@ -293,15 +258,44 @@ rocksdb::Status TDigest::Rank(engine::Context& ctx, const 
Slice& digest_name, co
       return status;
     }
   }
+  return dumpCentroids(ctx, ns_key, metadata, &centroids);
+}
 
+rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
+                              std::vector<int>& result) {
+  TDigestMetadata metadata;
   std::vector<Centroid> centroids;
-  if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); 
!status.ok()) {
+  if (auto status = prepareRankData(ctx, digest_name, metadata, centroids); 
!status.ok()) {
+    return status;
+  }
+
+  if (metadata.total_observations == 0) {
+    result.resize(inputs.size(), -2);
+    return rocksdb::Status::OK();
+  }
+
+  auto dump_centroids = DummyCentroids<false>(metadata, centroids);
+  if (auto status = TDigestRank<false>(dump_centroids, inputs, result); 
!status) {
+    return rocksdb::Status::InvalidArgument(status.Msg());
+  }
+  return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& 
digest_name, const std::vector<double>& inputs,
+                                 std::vector<int>& result) {
+  TDigestMetadata metadata;
+  std::vector<Centroid> centroids;
+  if (auto status = prepareRankData(ctx, digest_name, metadata, centroids); 
!status.ok()) {
     return status;
   }
 
-  auto dump_centroids = DummyCentroids(metadata, centroids);
-  auto status = TDigestRank(dump_centroids, inputs, reverse, result);
-  if (!status) {
+  if (metadata.total_observations == 0) {
+    result.resize(inputs.size(), -2);
+    return rocksdb::Status::OK();
+  }
+
+  auto dump_centroids = DummyCentroids<true>(metadata, centroids);
+  if (auto status = TDigestRank<true>(dump_centroids, inputs, result); 
!status) {
     return rocksdb::Status::InvalidArgument(status.Msg());
   }
   return rocksdb::Status::OK();
@@ -332,7 +326,7 @@ rocksdb::Status TDigest::Quantile(engine::Context& ctx, 
const Slice& digest_name
     return status;
   }
 
-  auto dump_centroids = DummyCentroids(metadata, centroids);
+  auto dump_centroids = DummyCentroids<false>(metadata, centroids);
 
   auto quantile_results = std::vector<double>();
   quantile_results.reserve(qs.size());
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
index 5daaed80c..4844f009f 100644
--- a/src/types/redis_tdigest.h
+++ b/src/types/redis_tdigest.h
@@ -77,8 +77,10 @@ class TDigest : public SubKeyScanner {
 
   rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const 
std::vector<std::string>& source_digests,
                         const TDigestMergeOptions& options);
-  rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const 
std::vector<double>& inputs, bool reverse,
+  rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const 
std::vector<double>& inputs,
                        std::vector<int>& result);
+  rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
+                          std::vector<int>& result);
   rocksdb::Status GetMetaData(engine::Context& context, const Slice& 
digest_name, TDigestMetadata* metadata);
 
  private:
@@ -130,6 +132,7 @@ class TDigest : public SubKeyScanner {
   static std::string internalValueFromCentroid(const Centroid& centroid);
   rocksdb::Status decodeCentroidFromKeyValue(const rocksdb::Slice& key, const 
rocksdb::Slice& value,
                                              Centroid* centroid) const;
+  rocksdb::Status prepareRankData(engine::Context& ctx, const Slice& 
digest_name, TDigestMetadata& metadata,
+                                  std::vector<Centroid>& centroids);
 };
-
 }  // namespace redis
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
index d77b673f7..e4caf914f 100644
--- a/src/types/tdigest.h
+++ b/src/types/tdigest.h
@@ -171,8 +171,8 @@ struct DoubleComparator {
   bool operator()(const double& a, const double& b) const { return 
DoubleCompare(a, b) == -1; }
 };
 
-template <typename TD, bool Reverse>
-inline Status TDigestRankImpl(TD&& td, const std::vector<double>& inputs, 
std::vector<int>& result) {
+template <bool Reverse, typename TD>
+inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, 
std::vector<int>& result) {
   std::map<double, size_t, DoubleComparator> value_to_index;
   for (size_t i = 0; i < inputs.size(); ++i) {
     value_to_index[inputs[i]] = i;
@@ -211,7 +211,7 @@ inline Status TDigestRankImpl(TD&& td, const 
std::vector<double>& inputs, std::v
     }
   }
 
-  auto iter = td.Begin(Reverse);
+  auto iter = td.Begin();
   double cumulative_weight = 0;
   while (iter->Valid() && !is_end()) {
     auto centroid = GET_OR_RET(iter->GetCentroid());
@@ -267,12 +267,3 @@ inline Status TDigestRankImpl(TD&& td, const 
std::vector<double>& inputs, std::v
   }
   return Status::OK();
 }
-
-template <typename TD>
-inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, bool 
reverse, std::vector<int>& result) {
-  if (reverse) {
-    return TDigestRankImpl<TD, true>(std::forward<TD>(td), inputs, result);
-  } else {
-    return TDigestRankImpl<TD, false>(std::forward<TD>(td), inputs, result);
-  }
-}
diff --git a/tests/cppunit/types/tdigest_test.cc 
b/tests/cppunit/types/tdigest_test.cc
index 9f5ad9736..a6ae7b295 100644
--- a/tests/cppunit/types/tdigest_test.cc
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -312,7 +312,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_different_elemen
   std::vector<int> result;
   result.reserve(input.size());
   const std::vector<double> value = {0, 10, 20, 30, 40, 50, 60, 70};
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
   const auto expect_result_revrank = std::vector<double>{6, 5, 4, 3, 2, 1, 0, 
-1};
 
   for (size_t i = 0; i < result.size(); i++) {
@@ -323,7 +323,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_different_elemen
 
   result.clear();
   result.reserve(input.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
   const auto expect_result_rank = std::vector<double>{-1, 0, 1, 2, 3, 4, 5, 6};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -345,7 +345,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_several_identica
   std::vector<int> result;
   const std::vector<double> value = {10, 20};
   result.reserve(value.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
   const auto expect_result_revrank = std::vector<double>{3, 1};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -355,7 +355,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_several_identica
 
   result.clear();
   result.reserve(value.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
   const auto expect_result_rank = std::vector<double>{1, 4};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -368,7 +368,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_several_identica
 
   result.clear();
   result.reserve(value.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
   const auto expect_result_new_revrank = std::vector<double>{4, 1};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -378,7 +378,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_several_identica
 
   result.clear();
   result.reserve(value.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
   const auto expect_result_new_rank = std::vector<double>{2, 5};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -397,7 +397,7 @@ TEST_F(RedisTDigestTest, RevRank_and_Rank_on_empty_tdigest) 
{
   std::vector<int> result;
   result.reserve(2);
   const std::vector<double> value = {10, 20};
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
   const auto expect_result_revrank = std::vector<double>{-2, -2};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -407,7 +407,7 @@ TEST_F(RedisTDigestTest, RevRank_and_Rank_on_empty_tdigest) 
{
 
   result.clear();
   result.reserve(2);
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
   const auto expect_result_rank = std::vector<double>{-2, -2};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -430,7 +430,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_different_or_same_and_unordered_inp
   std::vector<int> result;
   const std::vector<double> value = {50, 36, 4, 99, 8.8};
   result.reserve(value.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
   const auto expect_result_rank = std::vector<double>{13, 11, 1, 16, 4};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
@@ -441,7 +441,7 @@ TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_different_or_same_and_unordered_inp
   const std::vector<double> value_new = {50, 36, 4, 99, 8.8, 12};
   result.clear();
   result.reserve(value_new.size());
-  status = tdigest_->Rank(*ctx_, test_digest_name, value_new, true, result);
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value_new, result);
   const auto expect_result_revrank = std::vector<double>{4, 7, 16, 1, 14, 12};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];

Reply via email to