This is an automated email from the ASF dual-hosted git repository.

edwardxu pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 35e72fc7a feat(tdigest): add the support of TDIGEST.REVRANK command 
(#3130)
35e72fc7a is described below

commit 35e72fc7a3386cfd2413ac25ce0e4d5bbe9fe1da
Author: Hao Dong <[email protected]>
AuthorDate: Mon Nov 3 18:46:09 2025 +0800

    feat(tdigest): add support for the TDIGEST.REVRANK command (#3130)
    
    # ISSUE
    It closes #3063.
    
    ## Proposed Changes
    Add the TDIGEST.REVRANK command implementation
    Add C++ unit tests and Go test cases
    
    ---------
    
    Co-authored-by: Twice <[email protected]>
    Co-authored-by: donghao526 <[email protected]>
    Co-authored-by: Zhixin Wen <[email protected]>
    Co-authored-by: hulk <[email protected]>
    Co-authored-by: RX Xiao <[email protected]>
    Co-authored-by: Jonah Gao <[email protected]>
    Co-authored-by: Roman Donchenko <[email protected]>
    Co-authored-by: Aleks Lozovyuk <[email protected]>
    Co-authored-by: sryan yuan <[email protected]>
    Co-authored-by: Edward Xu <[email protected]>
    Co-authored-by: Copilot <[email protected]>
    Co-authored-by: Edward Xu <[email protected]>
---
 src/commands/cmd_tdigest.cc                    | 49 ++++++++++++++
 src/types/redis_tdigest.cc                     | 85 ++++++++++++++++++------
 src/types/redis_tdigest.h                      |  5 +-
 src/types/tdigest.h                            | 91 ++++++++++++++++++++++++++
 tests/cppunit/types/tdigest_test.cc            | 76 +++++++++++++++++++++
 tests/gocase/unit/type/tdigest/tdigest_test.go | 69 +++++++++++++++++++
 6 files changed, 354 insertions(+), 21 deletions(-)

diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc
index e7815256d..36bf555a2 100644
--- a/src/commands/cmd_tdigest.cc
+++ b/src/commands/cmd_tdigest.cc
@@ -176,6 +176,54 @@ class CommandTDigestAdd : public Commander {
   std::vector<double> values_;
 };
 
+class CommandTDigestRevRank : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    key_name_ = args[1];
+
+    std::set<std::string> unique_inputs_set(args.begin() + 2, args.end());
+    origin_inputs_.assign(args.begin() + 2, args.end());
+
+    unique_inputs_.reserve(unique_inputs_set.size());
+    size_t i = 0;
+    for (const auto &input : unique_inputs_set) {
+      auto value = ParseFloat(input);
+      if (!value) {
+        return {Status::RedisParseErr, errValueIsNotFloat};
+      }
+      unique_inputs_.push_back(*value);
+      unique_inputs_order_[input] = i;
+      ++i;
+    }
+    return Status::OK();
+  }
+  Status Execute(engine::Context &ctx, Server *srv, Connection *conn, 
std::string *output) override {
+    TDigest tdigest(srv->storage, conn->GetNamespace());
+    std::vector<int> result;
+    result.reserve(origin_inputs_.size());
+    if (const auto s = tdigest.RevRank(ctx, key_name_, unique_inputs_, 
result); !s.ok()) {
+      if (s.IsNotFound()) {
+        return {Status::RedisExecErr, errKeyNotFound};
+      }
+      return {Status::RedisExecErr, s.ToString()};
+    }
+
+    std::vector<std::string> rev_ranks;
+    rev_ranks.reserve(origin_inputs_.size());
+    for (const auto &v : origin_inputs_) {
+      rev_ranks.push_back(redis::Integer(result[unique_inputs_order_[v]]));
+    }
+    *output = redis::Array(rev_ranks);
+    return Status::OK();
+  }
+
+ private:
+  std::string key_name_;
+  std::vector<double> unique_inputs_;
+  std::map<std::string, size_t> unique_inputs_order_;
+  std::vector<std::string> origin_inputs_;
+};
+
 class CommandTDigestMinMax : public Commander {
  public:
   explicit CommandTDigestMinMax(bool is_min) : is_min_(is_min) {}
@@ -369,6 +417,7 @@ REDIS_REGISTER_COMMANDS(TDigest, 
MakeCmdAttr<CommandTDigestCreate>("tdigest.crea
                         MakeCmdAttr<CommandTDigestAdd>("tdigest.add", -3, 
"write", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestMax>("tdigest.max", 2, 
"read-only", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestMin>("tdigest.min", 2, 
"read-only", 1, 1, 1),
+                        MakeCmdAttr<CommandTDigestRevRank>("tdigest.revrank", 
-3, "read-only", 1, 1, 1),
                         
MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1, 
1),
                         MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2, 
"write", 1, 1, 1),
                         MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, 
"write", GetMergeKeyRange));
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
index f506ad92e..2d45a58d8 100644
--- a/src/types/redis_tdigest.cc
+++ b/src/types/redis_tdigest.cc
@@ -70,6 +70,8 @@ class DummyCentroids {
       return iter_ != centroids_.cend();
     }
 
+    bool IsBegin() { return iter_ == centroids_.cbegin(); }
+
     // The Prev function can only be called for item is not cend,
     // because we must guarantee the iterator to be inside the valid range 
before iteration.
     bool Prev() {
@@ -186,8 +188,37 @@ rocksdb::Status TDigest::Add(engine::Context& ctx, const 
Slice& digest_name, con
   return storage_->Write(ctx, storage_->DefaultWriteOptions(), 
batch->GetWriteBatch());
 }
 
-rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice& 
digest_name, const std::vector<double>& qs,
-                                  TDigestQuantitleResult* result) {
+rocksdb::Status TDigest::mergeNodes(engine::Context& ctx, const std::string& 
ns_key, TDigestMetadata* metadata) {
+  if (metadata->unmerged_nodes == 0) {
+    return rocksdb::Status::OK();
+  }
+
+  auto batch = storage_->GetWriteBatchBase();
+  WriteBatchLogData log_data(kRedisTDigest);
+  if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) {
+    return status;
+  }
+
+  if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, metadata); 
!status.ok()) {
+    return status;
+  }
+
+  std::string metadata_bytes;
+  metadata->Encode(&metadata_bytes);
+  if (auto status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); 
!status.ok()) {
+    return status;
+  }
+
+  if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(), 
batch->GetWriteBatch()); !status.ok()) {
+    return status;
+  }
+
+  ctx.RefreshLatestSnapshot();
+  return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& 
digest_name, const std::vector<double>& inputs,
+                                 std::vector<int>& result) {  // Writes one reverse rank per entry of `inputs` into `result`, at the same index.
   auto ns_key = AppendNamespacePrefix(digest_name);
   TDigestMetadata metadata;
   {
@@ -198,31 +229,45 @@ rocksdb::Status TDigest::Quantile(engine::Context& ctx, 
const Slice& digest_name
     }
 
     if (metadata.total_observations == 0) {
+      result.resize(inputs.size(), -2);  // Empty digest: every input reports -2. NOTE(review): resize() keeps entries already in `result`; assign() would reset them.
       return rocksdb::Status::OK();
     }
 
-    if (metadata.unmerged_nodes > 0) {
-      auto batch = storage_->GetWriteBatchBase();
-      WriteBatchLogData log_data(kRedisTDigest);
-      if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) {
-        return status;
-      }
+    if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {  // Merge pending (unmerged) nodes so the centroids read below are current.
+      return status;
+    }
+  }
 
-      if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata); 
!status.ok()) {
-        return status;
-      }
+  std::vector<Centroid> centroids;
+  if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); 
!status.ok()) {
+    return status;
+  }
 
-      std::string metadata_bytes;
-      metadata.Encode(&metadata_bytes);
-      if (auto status = batch->Put(metadata_cf_handle_, ns_key, 
metadata_bytes); !status.ok()) {
-        return status;
-      }
+  auto dump_centroids = DummyCentroids(metadata, centroids);  // Adapter providing the Max()/End()/TotalWeight() interface TDigestRevRank uses.
+  auto status = TDigestRevRank(dump_centroids, inputs, result);
+  if (!status) {
+    return rocksdb::Status::InvalidArgument(status.Msg());  // Algorithm failures surface to callers as InvalidArgument.
+  }
+  return rocksdb::Status::OK();
+}
 
-      if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(), 
batch->GetWriteBatch()); !status.ok()) {
-        return status;
-      }
+rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice& 
digest_name, const std::vector<double>& qs,
+                                  TDigestQuantitleResult* result) {
+  auto ns_key = AppendNamespacePrefix(digest_name);
+  TDigestMetadata metadata;
+  {
+    LockGuard guard(storage_->GetLockManager(), ns_key);
 
-      ctx.RefreshLatestSnapshot();
+    if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); 
!status.ok()) {
+      return status;
+    }
+
+    if (metadata.total_observations == 0) {
+      return rocksdb::Status::OK();
+    }
+
+    if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
+      return status;
     }
   }
 
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
index 2026cf94d..02ecc24c0 100644
--- a/src/types/redis_tdigest.h
+++ b/src/types/redis_tdigest.h
@@ -77,7 +77,8 @@ class TDigest : public SubKeyScanner {
 
   rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const 
std::vector<std::string>& source_digests,
                         const TDigestMergeOptions& options);
-
+  rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name, 
const std::vector<double>& inputs,
+                          std::vector<int>& result);
   rocksdb::Status GetMetaData(engine::Context& context, const Slice& 
digest_name, TDigestMetadata* metadata);
 
  private:
@@ -117,6 +118,8 @@ class TDigest : public SubKeyScanner {
   std::string internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, 
const std::string& ns_key,
                                             SegmentType seg) const;
 
+  rocksdb::Status mergeNodes(engine::Context& ctx, const std::string& ns_key, 
TDigestMetadata* metadata);
+
   rocksdb::Status mergeCurrentBuffer(engine::Context& ctx, const std::string& 
ns_key,
                                      
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, TDigestMetadata* metadata,
                                      const std::vector<double>* 
additional_buffer = nullptr,
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
index 1f416b48e..34843d6da 100644
--- a/src/types/tdigest.h
+++ b/src/types/tdigest.h
@@ -22,6 +22,8 @@
 
 #include <fmt/format.h>
 
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <numeric>
 #include <vector>
 
 #include "common/status.h"
@@ -150,3 +152,92 @@ inline StatusOr<double> TDigestQuantile(TD&& td, double q) 
{
   diff /= (lc.weight / 2 + rc.weight / 2);
   return Lerp(lc.mean, rc.mean, diff);
 }
+
// Three-way comparison of doubles with tolerance.
// Returns 0 in any of these cases:
//   - a and b are exactly equal (including two infinities of the same sign);
//   - |a - b| <= abs_eps;
//   - |a - b| <= rel_eps * max(|a|, |b|).
// Otherwise returns -1 if a < b and 1 if a > b.
// Infinite operands are handled before the tolerance tests:
//   - inf - inf is NaN, which would otherwise report inf > inf;
//   - an infinite magnitude makes the relative tolerance infinite, which would report inf == 1e308.
// NaN operands are unordered and yield 1.
inline int DoubleCompare(double a, double b, double rel_eps = 1e-12, double abs_eps = 1e-9) {
  if (a == b) return 0;
  if (std::isinf(a) || std::isinf(b)) return a < b ? -1 : 1;
  const double diff = a - b;
  const double adiff = std::abs(diff);
  if (adiff <= abs_eps) return 0;
  const double maxab = std::max(std::abs(a), std::abs(b));
  if (adiff <= maxab * rel_eps) return 0;
  return (diff < 0) ? -1 : 1;
}
+
+inline bool DoubleEqual(double a, double b, double rel_eps = 1e-12, double 
abs_eps = 1e-9) {
+  return DoubleCompare(a, b, rel_eps, abs_eps) == 0;
+}
+
+struct DoubleComparator {
+  bool operator()(const double& a, const double& b) const { return 
DoubleCompare(a, b) == -1; }
+};
+
+template <typename TD>
+inline Status TDigestRevRank(TD&& td, const std::vector<double>& inputs, 
std::vector<int>& result) {
+  std::map<double, size_t, DoubleComparator> value_to_indices;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    value_to_indices[inputs[i]] = i;
+  }
+
+  result.clear();
+  result.resize(inputs.size(), -2);
+  auto it = value_to_indices.rbegin();
+
+  // handle inputs larger than maximum
+  while (it != value_to_indices.rend() && it->first > td.Max()) {
+    result[it->second] = -1;
+    ++it;
+  }
+
+  auto iter = td.End();
+  double cumulative_weight = 0;
+  while (iter->Valid() && it != value_to_indices.rend()) {
+    auto centroid = GET_OR_RET(iter->GetCentroid());
+    auto input_value = it->first;
+    if (DoubleEqual(centroid.mean, input_value)) {
+      auto current_mean = centroid.mean;
+      auto current_mean_cumulative_weight = cumulative_weight + 
centroid.weight / 2;
+      cumulative_weight += centroid.weight;
+
+      // handle all the previous centroids which has the same mean
+      while (!iter->IsBegin() && iter->Prev()) {
+        auto next_centroid = GET_OR_RET(iter->GetCentroid());
+        if (!DoubleEqual(current_mean, next_centroid.mean)) {
+          // move back to the last equal centroid, because we will process it 
in the next loop
+          iter->Next();
+          break;
+        }
+        current_mean_cumulative_weight += next_centroid.weight / 2;
+        cumulative_weight += next_centroid.weight;
+      }
+
+      // handle the prev inputs which have the same value
+      result[it->second] = static_cast<int>(current_mean_cumulative_weight);
+      ++it;
+      if (iter->IsBegin()) {
+        break;
+      }
+      iter->Prev();
+    } else if (DoubleCompare(centroid.mean, input_value) > 0) {
+      cumulative_weight += centroid.weight;
+      if (iter->IsBegin()) {
+        break;
+      }
+      iter->Prev();
+    } else {
+      result[it->second] = static_cast<int>(cumulative_weight);
+      ++it;
+    }
+  }
+
+  // handle inputs less than minimum
+  while (it != value_to_indices.rend()) {
+    result[it->second] = static_cast<int>(td.TotalWeight());
+    ++it;
+  }
+
+  for (auto r : result) {
+    if (r <= -2) {
+      return Status{Status::InvalidArgument, "invalid result when computing 
revrank"};
+    }
+  }
+  return Status::OK();
+}
diff --git a/tests/cppunit/types/tdigest_test.cc 
b/tests/cppunit/types/tdigest_test.cc
index 91d1f311f..ae4bf29e9 100644
--- a/tests/cppunit/types/tdigest_test.cc
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -298,3 +298,79 @@ TEST_F(RedisTDigestTest, 
Quantile_returns_nan_on_empty_tdigest) {
   ASSERT_TRUE(status.ok()) << status.ToString();
   ASSERT_FALSE(result.quantiles) << "should not have quantiles with empty 
tdigest";
 }
+
+TEST_F(RedisTDigestTest, RevRank_on_the_set_containing_different_elements) {
+  std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+  std::vector<double> input{10, 20, 30, 40, 50, 60};
+  status = tdigest_->Add(*ctx_, test_digest_name, input);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  std::vector<int> result;
+  result.reserve(input.size());
+  const std::vector<double> value = {0, 10, 20, 30, 40, 50, 60, 70};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  const auto expect_result = std::vector<double>{6, 5, 4, 3, 2, 1, 0, -1};
+
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+}
+
+TEST_F(RedisTDigestTest, 
RevRank_on_the_set_containing_several_identical_elements) {
+  std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+  std::vector<double> input{10, 10, 10, 20, 20};
+  status = tdigest_->Add(*ctx_, test_digest_name, input);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  std::vector<int> result;
+  result.reserve(input.size());
+  const std::vector<double> value = {10, 20};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  const auto expect_result = std::vector<double>{3, 1};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  status = tdigest_->Add(*ctx_, test_digest_name, std::vector<double>{10});
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  const auto expect_result_new = std::vector<double>{4, 1};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result_new[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+}
+
+TEST_F(RedisTDigestTest, RevRank_on_empty_tdigest) {
+  std::string test_digest_name = "test_digest_revrank" + 
std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+
+  std::vector<int> result;
+  result.reserve(2);
+  const std::vector<double> value = {10, 20};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  const auto expect_result = std::vector<double>{-2, -2};
+  for (size_t i = 0; i < result.size(); i++) {
+    auto got = result[i];
+    EXPECT_EQ(got, expect_result[i]);
+  }
+  ASSERT_TRUE(status.ok()) << status.ToString();
+}
\ No newline at end of file
diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go 
b/tests/gocase/unit/type/tdigest/tdigest_test.go
index 3c2565ac0..129be0554 100644
--- a/tests/gocase/unit/type/tdigest/tdigest_test.go
+++ b/tests/gocase/unit/type/tdigest/tdigest_test.go
@@ -518,4 +518,73 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                validation(newDestKey1)
                validation(newDestKey2)
        })
+
+       t.Run("tdigest.revrank with different arguments", func(t *testing.T) {
+               keyPrefix := "tdigest_revrank_"
+
+               // Test invalid arguments
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK").Err(), 
errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK", 
keyPrefix+"nonexistent").Err(), errMsgWrongNumberArg)
+
+               // Test Non-existent key
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK", 
keyPrefix+"nonexistent", "10").Err(), errMsgKeyNotExist)
+
+               // Test with empty tdigest
+               key := keyPrefix + "test1"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               rsp := rdb.Do(ctx, "TDIGEST.REVRANK", key, "10")
+               require.NoError(t, rsp.Err())
+               vals, err := rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 1)
+               expected := []int64{-2}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               // Test with set containing several identical elements
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10", "10", 
"10", "20", "20").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 2)
+               expected = []int64{3, 1}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 2)
+               expected = []int64{4, 1}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+
+               // Test with set containing different elements
+               key2 := keyPrefix + "test2"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "10", "20", 
"30", "40", "50", "60").Err())
+               rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key2, "0", "10", "20", 
"30", "40", "50", "60", "70")
+               require.NoError(t, rsp.Err())
+               vals, err = rsp.Slice()
+               require.NoError(t, err)
+               require.Len(t, vals, 8)
+               expected = []int64{6, 5, 4, 3, 2, 1, 0, -1}
+               for i, v := range vals {
+                       rank, ok := v.(int64)
+                       require.True(t, ok, "expected int64 but got %T at index 
%d", v, i)
+                       require.EqualValues(t, rank, expected[i])
+               }
+       })
 }

Reply via email to