This is an automated email from the ASF dual-hosted git repository.

edwardxu pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 34a36a1fa feat(tdigest): add the support of TDIGEST.RANK command (#3249)
34a36a1fa is described below

commit 34a36a1fa24928459f76e28559a6d970fbc55074
Author: Hao Dong <[email protected]>
AuthorDate: Sat Nov 15 20:49:16 2025 +0800

    feat(tdigest): add the support of TDIGEST.RANK command (#3249)
    
    Co-authored-by: tonidong <[email protected]>
---
 src/commands/cmd_tdigest.cc                    |  10 +-
 src/types/redis_tdigest.cc                     |  82 +++++++++++++---
 src/types/redis_tdigest.h                      |   4 +-
 src/types/tdigest.h                            |  93 ++++++++++++------
 tests/cppunit/types/tdigest_test.cc            | 111 +++++++++++++++++----
 tests/gocase/unit/type/tdigest/tdigest_test.go | 129 +++++++++++++++++++++++++
 6 files changed, 365 insertions(+), 64 deletions(-)

diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc
index 36bf555a2..64dfafcd7 100644
--- a/src/commands/cmd_tdigest.cc
+++ b/src/commands/cmd_tdigest.cc
@@ -176,7 +176,8 @@ class CommandTDigestAdd : public Commander {
   std::vector<double> values_;
 };
 
-class CommandTDigestRevRank : public Commander {
+template <bool reverse>
+class TDigestRankCommand : public Commander {
  public:
   Status Parse(const std::vector<std::string> &args) override {
     key_name_ = args[1];
@@ -201,7 +202,7 @@ class CommandTDigestRevRank : public Commander {
     TDigest tdigest(srv->storage, conn->GetNamespace());
     std::vector<int> result;
     result.reserve(origin_inputs_.size());
-    if (const auto s = tdigest.RevRank(ctx, key_name_, unique_inputs_, 
result); !s.ok()) {
+    if (const auto s = tdigest.Rank(ctx, key_name_, unique_inputs_, reverse, 
result); !s.ok()) {
       if (s.IsNotFound()) {
         return {Status::RedisExecErr, errKeyNotFound};
       }
@@ -224,6 +225,10 @@ class CommandTDigestRevRank : public Commander {
   std::vector<std::string> origin_inputs_;
 };
 
+class CommandTDigestRevRank : public TDigestRankCommand<true> {};
+
+class CommandTDigestRank : public TDigestRankCommand<false> {};
+
 class CommandTDigestMinMax : public Commander {
  public:
   explicit CommandTDigestMinMax(bool is_min) : is_min_(is_min) {}
@@ -418,6 +423,7 @@ REDIS_REGISTER_COMMANDS(TDigest, 
MakeCmdAttr<CommandTDigestCreate>("tdigest.crea
                         MakeCmdAttr<CommandTDigestMax>("tdigest.max", 2, 
"read-only", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestMin>("tdigest.min", 2, 
"read-only", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestRevRank>("tdigest.revrank", 
-3, "read-only", 1, 1, 1),
+                        MakeCmdAttr<CommandTDigestRank>("tdigest.rank", -3, 
"read-only", 1, 1, 1),
                         
MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1, 
1),
                         MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2, 
"write", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, 
"write", GetMergeKeyRange));
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
index 2d45a58d8..fec7aef1c 100644
--- a/src/types/redis_tdigest.cc
+++ b/src/types/redis_tdigest.cc
@@ -50,38 +50,46 @@ namespace redis {
 // TODO: It should be replaced by a iteration of the rocksdb iterator
 class DummyCentroids {
  public:
+  class BaseIterator {
+   public:
+    virtual ~BaseIterator() = default;
+    virtual bool Next() = 0;
+    virtual bool Prev() = 0;
+    virtual bool Valid() const = 0;
+    virtual std::unique_ptr<BaseIterator> Clone() const = 0;
+    virtual StatusOr<Centroid> GetCentroid() const = 0;
+  };
+
   DummyCentroids(const TDigestMetadata& meta_data, const 
std::vector<Centroid>& centroids)
       : meta_data_(meta_data), centroids_(centroids) {}
-  class Iterator {
+  class Iterator : public BaseIterator {
    public:
     Iterator(std::vector<Centroid>::const_iterator&& iter, const 
std::vector<Centroid>& centroids)
         : iter_(iter), centroids_(centroids) {}
-    std::unique_ptr<Iterator> Clone() const {
+    std::unique_ptr<BaseIterator> Clone() const override {
       if (iter_ != centroids_.cend()) {
         return std::make_unique<Iterator>(std::next(centroids_.cbegin(), 
std::distance(centroids_.cbegin(), iter_)),
                                           centroids_);
       }
       return std::make_unique<Iterator>(centroids_.cend(), centroids_);
     }
-    bool Next() {
+    bool Next() override {
       if (Valid()) {
         std::advance(iter_, 1);
       }
       return iter_ != centroids_.cend();
     }
 
-    bool IsBegin() { return iter_ == centroids_.cbegin(); }
-
     // The Prev function can only be called for item is not cend,
     // because we must guarantee the iterator to be inside the valid range 
before iteration.
-    bool Prev() {
+    bool Prev() override {
       if (Valid() && iter_ != centroids_.cbegin()) {
         std::advance(iter_, -1);
       }
       return Valid();
     }
-    bool Valid() const { return iter_ != centroids_.cend(); }
-    StatusOr<Centroid> GetCentroid() const {
+    bool Valid() const override { return iter_ != centroids_.cend(); }
+    StatusOr<Centroid> GetCentroid() const override {
       if (iter_ == centroids_.cend()) {
         return {::Status::NotOK, "invalid iterator during decoding tdigest 
centroid"};
       }
@@ -93,11 +101,59 @@ class DummyCentroids {
     const std::vector<Centroid>& centroids_;
   };
 
-  std::unique_ptr<Iterator> Begin() { return 
std::make_unique<Iterator>(centroids_.cbegin(), centroids_); }
-  std::unique_ptr<Iterator> End() {
+  class ReverseIterator final : public BaseIterator {
+   public:
+    ReverseIterator(std::vector<Centroid>::const_reverse_iterator&& iter, 
const std::vector<Centroid>& centroids)
+        : iter_(iter), centroids_(centroids) {}
+    std::unique_ptr<BaseIterator> Clone() const override {
+      if (iter_ != centroids_.crend()) {
+        return std::make_unique<ReverseIterator>(
+            std::next(centroids_.crbegin(), 
std::distance(centroids_.crbegin(), iter_)), centroids_);
+      }
+      return std::make_unique<ReverseIterator>(centroids_.crend(), centroids_);
+    }
+    bool Next() override {
+      if (Valid()) {
+        std::advance(iter_, 1);
+      }
+      return iter_ != centroids_.crend();
+    }
+
+    bool Prev() override {
+      if (Valid() && iter_ != centroids_.crbegin()) {
+        std::advance(iter_, -1);
+      }
+      return Valid();
+    }
+    bool Valid() const override { return iter_ != centroids_.crend(); }
+    StatusOr<Centroid> GetCentroid() const override {
+      if (iter_ == centroids_.crend()) {
+        return {::Status::NotOK, "invalid iterator during decoding tdigest 
centroid"};
+      }
+      return *iter_;
+    }
+
+   private:
+    std::vector<Centroid>::const_reverse_iterator iter_;
+    const std::vector<Centroid>& centroids_;
+  };
+
+  std::unique_ptr<BaseIterator> Begin(const bool reverse = false) const {
+    if (reverse) {
+      return std::make_unique<ReverseIterator>(centroids_.crbegin(), 
centroids_);
+    }
+    return std::make_unique<Iterator>(centroids_.cbegin(), centroids_);
+  }
+  std::unique_ptr<BaseIterator> End(const bool reverse = false) const {
     if (centroids_.empty()) {
+      if (reverse) {
+        return std::make_unique<ReverseIterator>(centroids_.crend(), 
centroids_);
+      }
       return std::make_unique<Iterator>(centroids_.cend(), centroids_);
     }
+    if (reverse) {
+      return std::make_unique<ReverseIterator>(std::prev(centroids_.crend()), 
centroids_);
+    }
     return std::make_unique<Iterator>(std::prev(centroids_.cend()), 
centroids_);
   }
   double TotalWeight() const { return 
static_cast<double>(meta_data_.total_weight); }
@@ -217,8 +273,8 @@ rocksdb::Status TDigest::mergeNodes(engine::Context& ctx, 
const std::string& ns_
   return rocksdb::Status::OK();
 }
 
-rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& 
digest_name, const std::vector<double>& inputs,
-                                 std::vector<int>& result) {
+rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
+                              bool reverse, std::vector<int>& result) {
   auto ns_key = AppendNamespacePrefix(digest_name);
   TDigestMetadata metadata;
   {
@@ -244,7 +300,7 @@ rocksdb::Status TDigest::RevRank(engine::Context& ctx, 
const Slice& digest_name,
   }
 
   auto dump_centroids = DummyCentroids(metadata, centroids);
-  auto status = TDigestRevRank(dump_centroids, inputs, result);
+  auto status = TDigestRank(dump_centroids, inputs, reverse, result);
   if (!status) {
     return rocksdb::Status::InvalidArgument(status.Msg());
   }
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
index 02ecc24c0..5daaed80c 100644
--- a/src/types/redis_tdigest.h
+++ b/src/types/redis_tdigest.h
@@ -77,8 +77,8 @@ class TDigest : public SubKeyScanner {
 
   rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const 
std::vector<std::string>& source_digests,
                         const TDigestMergeOptions& options);
-  rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
-                          std::vector<int>& result);
+  rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const 
std::vector<double>& inputs, bool reverse,
+                       std::vector<int>& result);
   rocksdb::Status GetMetaData(engine::Context& context, const Slice& 
digest_name, TDigestMetadata* metadata);
 
  private:
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
index 34843d6da..d77b673f7 100644
--- a/src/types/tdigest.h
+++ b/src/types/tdigest.h
@@ -24,6 +24,7 @@
 
 #include <map>
 #include <numeric>
+#include <variant>
 #include <vector>
 
 #include "common/status.h"
@@ -170,26 +171,49 @@ struct DoubleComparator {
   bool operator()(const double& a, const double& b) const { return 
DoubleCompare(a, b) == -1; }
 };
 
-template <typename TD>
-inline Status TDigestRevRank(TD&& td, const std::vector<double>& inputs, 
std::vector<int>& result) {
-  std::map<double, size_t, DoubleComparator> value_to_indices;
+template <typename TD, bool Reverse>
+inline Status TDigestRankImpl(TD&& td, const std::vector<double>& inputs, 
std::vector<int>& result) {
+  std::map<double, size_t, DoubleComparator> value_to_index;
   for (size_t i = 0; i < inputs.size(); ++i) {
-    value_to_indices[inputs[i]] = i;
+    value_to_index[inputs[i]] = i;
   }
 
   result.clear();
   result.resize(inputs.size(), -2);
-  auto it = value_to_indices.rbegin();
 
-  // handle inputs larger than maximum
-  while (it != value_to_indices.rend() && it->first > td.Max()) {
-    result[it->second] = -1;
-    ++it;
+  using MapType = decltype(value_to_index);
+  using IterType = std::conditional_t<Reverse, typename 
MapType::reverse_iterator, typename MapType::iterator>;
+  IterType it;
+  if constexpr (Reverse) {
+    it = value_to_index.rbegin();
+  } else {
+    it = value_to_index.begin();
   }
 
-  auto iter = td.End();
+  auto is_end = [&it, &value_to_index]() -> bool {
+    if constexpr (Reverse) {
+      return it == value_to_index.rend();
+    } else {
+      return it == value_to_index.end();
+    }
+  };
+
+  // handle inputs larger than maximum in reverse order or smaller than 
minimum in forward order
+  if constexpr (Reverse) {
+    while (!is_end() && it->first > td.Max()) {
+      result[it->second] = -1;
+      ++it;
+    }
+  } else {
+    while (!is_end() && it->first < td.Min()) {
+      result[it->second] = -1;
+      ++it;
+    }
+  }
+
+  auto iter = td.Begin(Reverse);
   double cumulative_weight = 0;
-  while (iter->Valid() && it != value_to_indices.rend()) {
+  while (iter->Valid() && !is_end()) {
     auto centroid = GET_OR_RET(iter->GetCentroid());
     auto input_value = it->first;
     if (DoubleEqual(centroid.mean, input_value)) {
@@ -197,47 +221,58 @@ inline Status TDigestRevRank(TD&& td, const 
std::vector<double>& inputs, std::ve
       auto current_mean_cumulative_weight = cumulative_weight + 
centroid.weight / 2;
       cumulative_weight += centroid.weight;
 
-      // handle all the previous centroids which has the same mean
-      while (!iter->IsBegin() && iter->Prev()) {
+      // handle all next centroids which has the same mean
+      while (iter->Next()) {
         auto next_centroid = GET_OR_RET(iter->GetCentroid());
         if (!DoubleEqual(current_mean, next_centroid.mean)) {
           // move back to the last equal centroid, because we will process it 
in the next loop
-          iter->Next();
+          iter->Prev();
           break;
         }
         current_mean_cumulative_weight += next_centroid.weight / 2;
         cumulative_weight += next_centroid.weight;
       }
 
-      // handle the prev inputs which have the same value
       result[it->second] = static_cast<int>(current_mean_cumulative_weight);
       ++it;
-      if (iter->IsBegin()) {
-        break;
-      }
-      iter->Prev();
-    } else if (DoubleCompare(centroid.mean, input_value) > 0) {
-      cumulative_weight += centroid.weight;
-      if (iter->IsBegin()) {
-        break;
+      iter->Next();
+    } else if constexpr (Reverse) {
+      if (DoubleCompare(centroid.mean, input_value) > 0) {
+        cumulative_weight += centroid.weight;
+        iter->Next();
+      } else {
+        result[it->second] = static_cast<int>(cumulative_weight);
+        ++it;
       }
-      iter->Prev();
     } else {
-      result[it->second] = static_cast<int>(cumulative_weight);
-      ++it;
+      if (DoubleCompare(centroid.mean, input_value) < 0) {
+        cumulative_weight += centroid.weight;
+        iter->Next();
+      } else {
+        result[it->second] = static_cast<int>(cumulative_weight);
+        ++it;
+      }
     }
   }
 
-  // handle inputs less than minimum
-  while (it != value_to_indices.rend()) {
+  while (!is_end()) {
     result[it->second] = static_cast<int>(td.TotalWeight());
     ++it;
   }
 
   for (auto r : result) {
     if (r <= -2) {
-      return Status{Status::InvalidArgument, "invalid result when computing 
revrank"};
+      return Status{Status::InvalidArgument, "invalid result when computing 
rank or revrank"};
     }
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, bool 
reverse, std::vector<int>& result) {
+  if (reverse) {
+    return TDigestRankImpl<TD, true>(std::forward<TD>(td), inputs, result);
+  } else {
+    return TDigestRankImpl<TD, false>(std::forward<TD>(td), inputs, result);
+  }
+}
diff --git a/tests/cppunit/types/tdigest_test.cc 
b/tests/cppunit/types/tdigest_test.cc
index ae4bf29e9..9f5ad9736 100644
--- a/tests/cppunit/types/tdigest_test.cc
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -299,7 +299,7 @@ TEST_F(RedisTDigestTest, 
Quantile_returns_nan_on_empty_tdigest) {
   ASSERT_FALSE(result.quantiles) << "should not have quantiles with empty 
tdigest";
 }
 
-TEST_F(RedisTDigestTest, RevRank_on_the_set_containing_different_elements) {
+TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_different_elements) {
   std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
   bool exists = false;
   auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
@@ -312,18 +312,28 @@ TEST_F(RedisTDigestTest, 
RevRank_on_the_set_containing_different_elements) {
   std::vector<int> result;
   result.reserve(input.size());
   const std::vector<double> value = {0, 10, 20, 30, 40, 50, 60, 70};
-  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
-  const auto expect_result = std::vector<double>{6, 5, 4, 3, 2, 1, 0, -1};
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  const auto expect_result_revrank = std::vector<double>{6, 5, 4, 3, 2, 1, 0, 
-1};
 
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
-    EXPECT_EQ(got, expect_result[i]);
+    EXPECT_EQ(got, expect_result_revrank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  result.reserve(input.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  const auto expect_result_rank = std::vector<double>{-1, 0, 1, 2, 3, 4, 5, 6};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_rank[i]);
   }
   ASSERT_TRUE(status.ok()) << status.ToString();
 }
 
-TEST_F(RedisTDigestTest, 
RevRank_on_the_set_containing_several_identical_elements) {
-  std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
+TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_the_set_containing_several_identical_elements) {
+  std::string test_digest_name = "test_digest_revrank_and_rank" + 
std::to_string(util::GetTimeStampMS());
   bool exists = false;
   auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
   ASSERT_FALSE(exists);
@@ -333,13 +343,23 @@ TEST_F(RedisTDigestTest, 
RevRank_on_the_set_containing_several_identical_element
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   std::vector<int> result;
-  result.reserve(input.size());
   const std::vector<double> value = {10, 20};
-  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
-  const auto expect_result = std::vector<double>{3, 1};
+  result.reserve(value.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  const auto expect_result_revrank = std::vector<double>{3, 1};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
-    EXPECT_EQ(got, expect_result[i]);
+    EXPECT_EQ(got, expect_result_revrank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  result.reserve(value.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  const auto expect_result_rank = std::vector<double>{1, 4};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_rank[i]);
   }
   ASSERT_TRUE(status.ok()) << status.ToString();
 
@@ -347,17 +367,28 @@ TEST_F(RedisTDigestTest, 
RevRank_on_the_set_containing_several_identical_element
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   result.clear();
-  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
-  const auto expect_result_new = std::vector<double>{4, 1};
+  result.reserve(value.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  const auto expect_result_new_revrank = std::vector<double>{4, 1};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_new_revrank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  result.reserve(value.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  const auto expect_result_new_rank = std::vector<double>{2, 5};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
-    EXPECT_EQ(got, expect_result_new[i]);
+    EXPECT_EQ(got, expect_result_new_rank[i]);
   }
   ASSERT_TRUE(status.ok()) << status.ToString();
 }
 
-TEST_F(RedisTDigestTest, RevRank_on_empty_tdigest) {
-  std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
+TEST_F(RedisTDigestTest, RevRank_and_Rank_on_empty_tdigest) {
+  std::string test_digest_name = "test_digest_revrank_and_rank" + 
std::to_string(util::GetTimeStampMS());
   bool exists = false;
   auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
   ASSERT_FALSE(exists);
@@ -366,11 +397,55 @@ TEST_F(RedisTDigestTest, RevRank_on_empty_tdigest) {
   std::vector<int> result;
   result.reserve(2);
   const std::vector<double> value = {10, 20};
-  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
-  const auto expect_result = std::vector<double>{-2, -2};
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+  const auto expect_result_revrank = std::vector<double>{-2, -2};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_revrank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  result.reserve(2);
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  const auto expect_result_rank = std::vector<double>{-2, -2};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_rank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+}
+
+TEST_F(RedisTDigestTest, 
RevRank_and_Rank_on_different_or_same_and_unordered_inputs_tdigest) {
+  std::string test_digest_name = "test_digest_revrank_and_rank" + 
std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+
+  std::vector<double> input{12, 100, 50, 36, 75, 81, 35.5, 46, 36, 8.8, 15, 4, 
32.5, 12, 8.8, 7, 99, 0};
+  status = tdigest_->Add(*ctx_, test_digest_name, input);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  std::vector<int> result;
+  const std::vector<double> value = {50, 36, 4, 99, 8.8};
+  result.reserve(value.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+  const auto expect_result_rank = std::vector<double>{13, 11, 1, 16, 4};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_rank[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  const std::vector<double> value_new = {50, 36, 4, 99, 8.8, 12};
+  result.clear();
+  result.reserve(value_new.size());
+  status = tdigest_->Rank(*ctx_, test_digest_name, value_new, true, result);
+  const auto expect_result_revrank = std::vector<double>{4, 7, 16, 1, 14, 12};
   for (size_t i = 0; i < result.size(); i++) {
     auto got = result[i];
-    EXPECT_EQ(got, expect_result[i]);
+    EXPECT_EQ(got, expect_result_revrank[i]);
   }
   ASSERT_TRUE(status.ok()) << status.ToString();
 }
\ No newline at end of file
diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go 
b/tests/gocase/unit/type/tdigest/tdigest_test.go
index 129be0554..335ee0eff 100644
--- a/tests/gocase/unit/type/tdigest/tdigest_test.go
+++ b/tests/gocase/unit/type/tdigest/tdigest_test.go
@@ -587,4 +587,133 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, rank, expected[i])
                }
        })
+
+       t.Run("tdigest.rank with different arguments", func(t *testing.T) {
+               keyPrefix := "tdigest_rank_"
+
+               // Test invalid arguments
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.RANK").Err(), 
errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.RANK", 
keyPrefix+"nonexistent").Err(), errMsgWrongNumberArg)
+
+               // Test Non-existent key
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.RANK", 
keyPrefix+"nonexistent", "10").Err(), errMsgKeyNotExist)
+
+               // Test with empty tdigest
+               key := keyPrefix + "test1"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               rsp := rdb.Do(ctx, "TDIGEST.RANK", key, "10", "20")
+               require.NoError(t, rsp.Err())
+               vals, err := rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 2)
+               expected := []int64{-2, -2}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               // Test with set containing several identical elements
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10", "10", 
"10", "20", "20").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.RANK", key, "10", "20")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 2)
+               expected = []int64{1, 4}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.RANK", key, "10", "20")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 2)
+               expected = []int64{2, 5}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               // Test with set containing different elements
+               key2 := keyPrefix + "test2"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "10", "20", 
"30", "40", "50", "60").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.RANK", key2, "0", "10", "20", "30", 
"40", "50", "60", "70")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 8)
+               expected = []int64{-1, 0, 1, 2, 3, 4, 5, 6}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               // Test with set containing unordered elements which are 
different or same
+               key3 := keyPrefix + "test3"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key3, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key3, "12", 
"100", "50", "36", "75", "81", "35.5", "46", "36", "8.8", "15", "4", "32.5", 
"12", "8.8", "7", "99", "1").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.RANK", key3, "50", "36", "4", "99", 
"8.8", "0.1", "200")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 7)
+               expected = []int64{13, 11, 1, 16, 4, -1, 18}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key3, "50", "36", "4", 
"99", "8.8", "8.8", "12", "99", "200")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 9)
+               expected = []int64{4, 7, 16, 1, 14, 14, 12, 1, -1}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+       })
+
+       t.Run("tdigest.rank and revrank with unordered elements containing 
duplicate values", func(t *testing.T) {
+               key := "tdigest_rank_unordered_dup_"
+
+               // Create digest and add unordered elements with duplicates
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "12", "100", 
"50", "36", "75", "81", "35.5", "46", "36", "8.8", "15", "4", "32.5", "12", 
"8.8", "7", "99", "1").Err())
+
+               rsp := rdb.Do(ctx, "TDIGEST.RANK", key, "50", "36", "4", "99", 
"8.8", "0.1", "200")
+               require.NoError(t, rsp.Err())
+               vals, err := rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 7)
+               expected := []int64{13, 11, 1, 16, 4, -1, 18}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, expected[i], rank, "RANK 
mismatch at index %d", i)
+               }
+
+               rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "50", "36", "4", 
"99", "8.8", "8.8", "12", "99", "200")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 9)
+               expected = []int64{4, 7, 16, 1, 14, 14, 12, 1, -1}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
+               }
+       })
 }

Reply via email to