This is an automated email from the ASF dual-hosted git repository.
edwardxu pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 35e72fc7a feat(tdigest): add the support of TDIGEST.REVRANK command
(#3130)
35e72fc7a is described below
commit 35e72fc7a3386cfd2413ac25ce0e4d5bbe9fe1da
Author: Hao Dong <[email protected]>
AuthorDate: Mon Nov 3 18:46:09 2025 +0800
feat(tdigest): add the support of TDIGEST.REVRANK command (#3130)
# ISSUE
It closes #3063.
## Proposed Changes
Add TDIGEST.REVRANK command implementation
Add C++ unit tests and Go (gocase) tests
---------
Co-authored-by: Twice <[email protected]>
Co-authored-by: donghao526 <[email protected]>
Co-authored-by: Zhixin Wen <[email protected]>
Co-authored-by: hulk <[email protected]>
Co-authored-by: RX Xiao <[email protected]>
Co-authored-by: Jonah Gao <[email protected]>
Co-authored-by: Roman Donchenko <[email protected]>
Co-authored-by: Aleks Lozovyuk <[email protected]>
Co-authored-by: sryan yuan <[email protected]>
Co-authored-by: Edward Xu <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Edward Xu <[email protected]>
---
src/commands/cmd_tdigest.cc | 49 ++++++++++++++
src/types/redis_tdigest.cc | 85 ++++++++++++++++++------
src/types/redis_tdigest.h | 5 +-
src/types/tdigest.h | 91 ++++++++++++++++++++++++++
tests/cppunit/types/tdigest_test.cc | 76 +++++++++++++++++++++
tests/gocase/unit/type/tdigest/tdigest_test.go | 69 +++++++++++++++++++
6 files changed, 354 insertions(+), 21 deletions(-)
diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc
index e7815256d..36bf555a2 100644
--- a/src/commands/cmd_tdigest.cc
+++ b/src/commands/cmd_tdigest.cc
@@ -176,6 +176,54 @@ class CommandTDigestAdd : public Commander {
std::vector<double> values_;
};
+// TDIGEST.REVRANK key value [value ...]
+// Replies with one reverse rank per value argument, in argument order.
+class CommandTDigestRevRank : public Commander {
+ public:
+  // Each distinct argument string is parsed once; repeated arguments reuse that parse
+  // through unique_inputs_order_.
+  Status Parse(const std::vector<std::string> &args) override {
+    key_name_ = args[1];
+
+    const std::set<std::string> distinct_inputs(args.begin() + 2, args.end());
+    origin_inputs_.assign(args.begin() + 2, args.end());
+
+    unique_inputs_.reserve(distinct_inputs.size());
+    size_t slot = 0;
+    for (const auto &text : distinct_inputs) {
+      auto parsed = ParseFloat(text);
+      if (!parsed) return {Status::RedisParseErr, errValueIsNotFloat};
+
+      unique_inputs_.push_back(*parsed);
+      unique_inputs_order_[text] = slot++;
+    }
+    return Status::OK();
+  }
+
+  Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
+    TDigest tdigest(srv->storage, conn->GetNamespace());
+
+    // Ranks come back positionally aligned with unique_inputs_.
+    std::vector<int> unique_ranks;
+    unique_ranks.reserve(origin_inputs_.size());
+    const auto s = tdigest.RevRank(ctx, key_name_, unique_inputs_, unique_ranks);
+    if (!s.ok()) {
+      if (s.IsNotFound()) return {Status::RedisExecErr, errKeyNotFound};
+      return {Status::RedisExecErr, s.ToString()};
+    }
+
+    // Fan the ranks back out to the caller's argument order, duplicates included.
+    std::vector<std::string> rev_ranks;
+    rev_ranks.reserve(origin_inputs_.size());
+    for (const auto &text : origin_inputs_) {
+      const size_t slot = unique_inputs_order_[text];
+      rev_ranks.push_back(redis::Integer(unique_ranks[slot]));
+    }
+    *output = redis::Array(rev_ranks);
+    return Status::OK();
+  }
+
+ private:
+  std::string key_name_;
+  std::vector<double> unique_inputs_;                  // parsed value per distinct argument string
+  std::map<std::string, size_t> unique_inputs_order_;  // argument string -> index into unique_inputs_
+  std::vector<std::string> origin_inputs_;             // value arguments exactly as given, in order
+};
+
class CommandTDigestMinMax : public Commander {
public:
explicit CommandTDigestMinMax(bool is_min) : is_min_(is_min) {}
@@ -369,6 +417,7 @@ REDIS_REGISTER_COMMANDS(TDigest,
MakeCmdAttr<CommandTDigestCreate>("tdigest.crea
MakeCmdAttr<CommandTDigestAdd>("tdigest.add", -3,
"write", 1, 1, 1),
MakeCmdAttr<CommandTDigestMax>("tdigest.max", 2,
"read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestMin>("tdigest.min", 2,
"read-only", 1, 1, 1),
+ MakeCmdAttr<CommandTDigestRevRank>("tdigest.revrank",
-3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1,
1),
MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2,
"write", 1, 1, 1),
MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4,
"write", GetMergeKeyRange));
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
index f506ad92e..2d45a58d8 100644
--- a/src/types/redis_tdigest.cc
+++ b/src/types/redis_tdigest.cc
@@ -70,6 +70,8 @@ class DummyCentroids {
return iter_ != centroids_.cend();
}
+  // True when the cursor is positioned on the first element of centroids_.
+  bool IsBegin() { return centroids_.cbegin() == iter_; }
+
// The Prev function can only be called for item is not cend,
// because we must guarantee the iterator to be inside the valid range
before iteration.
bool Prev() {
@@ -186,8 +188,37 @@ rocksdb::Status TDigest::Add(engine::Context& ctx, const
Slice& digest_name, con
return storage_->Write(ctx, storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
}
-rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice&
digest_name, const std::vector<double>& qs,
- TDigestQuantitleResult* result) {
+// Merges the buffered (unmerged) nodes of the digest at ns_key into its stored centroids
+// and persists the updated metadata in a single write batch. Returns OK without writing
+// anything when metadata->unmerged_nodes is zero. On success `*metadata` reflects the
+// merged state.
+rocksdb::Status TDigest::mergeNodes(engine::Context& ctx, const std::string& ns_key, TDigestMetadata* metadata) {
+  if (metadata->unmerged_nodes == 0) return rocksdb::Status::OK();
+
+  auto batch = storage_->GetWriteBatchBase();
+  WriteBatchLogData log_data(kRedisTDigest);
+  auto s = batch->PutLogData(log_data.Encode());
+  if (!s.ok()) return s;
+
+  s = mergeCurrentBuffer(ctx, ns_key, batch, metadata);
+  if (!s.ok()) return s;
+
+  std::string metadata_bytes;
+  metadata->Encode(&metadata_bytes);
+  s = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes);
+  if (!s.ok()) return s;
+
+  s = storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+  if (!s.ok()) return s;
+
+  // Refresh the context's snapshot after the write (presumably so later reads through
+  // ctx observe the merged centroids -- confirm against engine::Context).
+  ctx.RefreshLatestSnapshot();
+  return rocksdb::Status::OK();
+}
+
+// Reverse rank of each value in `inputs` for the digest `digest_name`: result[i] receives
+// the rank of inputs[i] as computed by TDigestRevRank, or -2 for every input when the
+// digest holds no observations. Failures reported by TDigestRevRank are returned as
+// InvalidArgument.
+rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs,
+                                 std::vector<int>& result) {
auto ns_key = AppendNamespacePrefix(digest_name);
TDigestMetadata metadata;
{
@@ -198,31 +229,45 @@ rocksdb::Status TDigest::Quantile(engine::Context& ctx,
const Slice& digest_name
}
if (metadata.total_observations == 0) {
+    // assign() rather than resize(): a reused, non-empty `result` must not keep stale ranks.
+    result.assign(inputs.size(), -2);
return rocksdb::Status::OK();
}
- if (metadata.unmerged_nodes > 0) {
- auto batch = storage_->GetWriteBatchBase();
- WriteBatchLogData log_data(kRedisTDigest);
- if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) {
- return status;
- }
+    // Fold any buffered (unmerged) nodes into the stored centroids before reading them.
+    if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
+      return status;
+    }
+  }
- if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata);
!status.ok()) {
- return status;
- }
+  std::vector<Centroid> centroids;
+  if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); !status.ok()) {
+    return status;
+  }
- std::string metadata_bytes;
- metadata.Encode(&metadata_bytes);
- if (auto status = batch->Put(metadata_cf_handle_, ns_key,
metadata_bytes); !status.ok()) {
- return status;
- }
+  auto dummy_centroids = DummyCentroids(metadata, centroids);
+  auto status = TDigestRevRank(dummy_centroids, inputs, result);
+  if (!status) {
+    return rocksdb::Status::InvalidArgument(status.Msg());
+  }
+  return rocksdb::Status::OK();
+}
- if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(),
batch->GetWriteBatch()); !status.ok()) {
- return status;
- }
+rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice&
digest_name, const std::vector<double>& qs,
+ TDigestQuantitleResult* result) {
+ auto ns_key = AppendNamespacePrefix(digest_name);
+ TDigestMetadata metadata;
+ {
+ LockGuard guard(storage_->GetLockManager(), ns_key);
- ctx.RefreshLatestSnapshot();
+ if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata);
!status.ok()) {
+ return status;
+ }
+
+ if (metadata.total_observations == 0) {
+ return rocksdb::Status::OK();
+ }
+
+ if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
+ return status;
}
}
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
index 2026cf94d..02ecc24c0 100644
--- a/src/types/redis_tdigest.h
+++ b/src/types/redis_tdigest.h
@@ -77,7 +77,8 @@ class TDigest : public SubKeyScanner {
rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const
std::vector<std::string>& source_digests,
const TDigestMergeOptions& options);
-
+ // Reverse rank of each value in `inputs`; result[i] receives the rank of inputs[i]
+ // (-2 for every input when the digest holds no observations).
+ rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& inputs,
+ std::vector<int>& result);
rocksdb::Status GetMetaData(engine::Context& context, const Slice&
digest_name, TDigestMetadata* metadata);
private:
@@ -117,6 +118,8 @@ class TDigest : public SubKeyScanner {
std::string internalSegmentGuardPrefixKey(const TDigestMetadata& metadata,
const std::string& ns_key,
SegmentType seg) const;
+ // Merges buffered (unmerged) nodes into the stored centroids and persists the updated
+ // metadata; returns OK without writing when nothing is buffered.
+ rocksdb::Status mergeNodes(engine::Context& ctx, const std::string& ns_key,
TDigestMetadata* metadata);
+
rocksdb::Status mergeCurrentBuffer(engine::Context& ctx, const std::string&
ns_key,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, TDigestMetadata* metadata,
const std::vector<double>*
additional_buffer = nullptr,
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
index 1f416b48e..34843d6da 100644
--- a/src/types/tdigest.h
+++ b/src/types/tdigest.h
@@ -22,6 +22,8 @@
#include <fmt/format.h>
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <numeric>
#include <vector>
#include "common/status.h"
@@ -150,3 +152,92 @@ inline StatusOr<double> TDigestQuantile(TD&& td, double q)
{
diff /= (lc.weight / 2 + rc.weight / 2);
return Lerp(lc.mean, rc.mean, diff);
}
+
+// Three-way comparison of doubles with tolerance. Returns 0 when |a - b| <= abs_eps or
+// |a - b| <= max(|a|, |b|) * rel_eps; otherwise -1 if a < b and 1 if a > b.
+// NOTE: if either operand is NaN, every test fails and the result is 1.
+inline int DoubleCompare(double a, double b, double rel_eps = 1e-12, double abs_eps = 1e-9) {
+  double diff = a - b;
+  double adiff = std::abs(diff);
+  if (adiff <= abs_eps) return 0;
+  double maxab = std::max(std::abs(a), std::abs(b));
+  if (adiff <= maxab * rel_eps) return 0;
+  return (diff < 0) ? -1 : 1;
+}
+
+// True when DoubleCompare() considers a and b equal.
+inline bool DoubleEqual(double a, double b, double rel_eps = 1e-12, double abs_eps = 1e-9) {
+  return DoubleCompare(a, b, rel_eps, abs_eps) == 0;
+}
+
+// Less-than built on DoubleCompare().
+// NOTE(review): tolerance-based equality is not transitive, so this is not a strict weak
+// ordering for arbitrary inputs; take care before using it as an ordered-container key.
+struct DoubleComparator {
+  bool operator()(const double& a, const double& b) const { return DoubleCompare(a, b) == -1; }
+};
+
+// Computes the reverse rank of every value in `inputs` against `td`, storing the rank of
+// inputs[i] in result[i]:
+//   * -1 when the value is greater than td.Max();
+//   * td.TotalWeight() when the value lies below every centroid;
+//   * otherwise the weight of all centroids with a larger mean plus half the weight of the
+//     centroids whose mean DoubleEqual()s the value, truncated to int.
+// Inputs that are equal to each other always receive the same rank, each in its own slot.
+// Returns InvalidArgument if any slot is left uncomputed (-2).
+// NOTE(review): ranks are stored as int; a digest whose total weight exceeds INT_MAX
+// would overflow the casts below -- confirm whether that can happen.
+template <typename TD>
+inline Status TDigestRevRank(TD&& td, const std::vector<double>& inputs, std::vector<int>& result) {
+  // Visit inputs from the largest value to the smallest. Sorting indices (instead of keying
+  // a map with a tolerance comparator) keeps one result slot per input even when several
+  // inputs are equal, and gives std::sort a genuine strict weak ordering. NaN sorts last.
+  std::vector<size_t> order(inputs.size());
+  std::iota(order.begin(), order.end(), size_t{0});
+  std::sort(order.begin(), order.end(), [&inputs](size_t lhs, size_t rhs) {
+    if (std::isnan(inputs[rhs])) return !std::isnan(inputs[lhs]);
+    return inputs[lhs] > inputs[rhs];
+  });
+
+  result.clear();
+  result.resize(inputs.size(), -2);
+  auto it = order.cbegin();
+  const auto order_end = order.cend();
+
+  // Inputs larger than the maximum observation have no reverse rank.
+  while (it != order_end && inputs[*it] > td.Max()) {
+    result[*it] = -1;
+    ++it;
+  }
+
+  // Walk the centroids from td.End() backwards; cumulative_weight is the total weight of
+  // the centroids already passed.
+  auto iter = td.End();
+  double cumulative_weight = 0;
+  while (iter->Valid() && it != order_end) {
+    auto centroid = GET_OR_RET(iter->GetCentroid());
+    const double input_value = inputs[*it];
+    if (DoubleEqual(centroid.mean, input_value)) {
+      const double current_mean = centroid.mean;
+      double current_mean_cumulative_weight = cumulative_weight + centroid.weight / 2;
+      cumulative_weight += centroid.weight;
+
+      // Absorb all preceding centroids that share the same mean.
+      while (!iter->IsBegin() && iter->Prev()) {
+        auto prev_centroid = GET_OR_RET(iter->GetCentroid());
+        if (!DoubleEqual(current_mean, prev_centroid.mean)) {
+          // Step back onto the last equal centroid; the Prev() below moves past it.
+          iter->Next();
+          break;
+        }
+        current_mean_cumulative_weight += prev_centroid.weight / 2;
+        cumulative_weight += prev_centroid.weight;
+      }
+
+      // Every pending input equal to this mean shares the same reverse rank. The first
+      // one always matches, so `it` is guaranteed to advance.
+      const int rank = static_cast<int>(current_mean_cumulative_weight);
+      while (it != order_end && DoubleEqual(current_mean, inputs[*it])) {
+        result[*it] = rank;
+        ++it;
+      }
+      if (iter->IsBegin()) {
+        break;
+      }
+      iter->Prev();
+    } else if (DoubleCompare(centroid.mean, input_value) > 0) {
+      // This centroid's mean lies above the input, so all of its weight ranks ahead of it.
+      cumulative_weight += centroid.weight;
+      if (iter->IsBegin()) {
+        break;
+      }
+      iter->Prev();
+    } else {
+      // The input falls between this centroid and the ones already passed.
+      result[*it] = static_cast<int>(cumulative_weight);
+      ++it;
+    }
+  }
+
+  // Remaining inputs lie below every centroid: all observations rank ahead of them.
+  for (; it != order_end; ++it) {
+    result[*it] = static_cast<int>(td.TotalWeight());
+  }
+
+  for (auto r : result) {
+    if (r <= -2) {
+      return Status{Status::InvalidArgument, "invalid result when computing revrank"};
+    }
+  }
+  return Status::OK();
+}
diff --git a/tests/cppunit/types/tdigest_test.cc
b/tests/cppunit/types/tdigest_test.cc
index 91d1f311f..ae4bf29e9 100644
--- a/tests/cppunit/types/tdigest_test.cc
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -298,3 +298,79 @@ TEST_F(RedisTDigestTest,
Quantile_returns_nan_on_empty_tdigest) {
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_FALSE(result.quantiles) << "should not have quantiles with empty
tdigest";
}
+
+// Distinct observations 10..60: each observed value ranks behind the larger ones plus half
+// of itself (truncated); values outside [min, max] get TotalWeight and -1 respectively.
+TEST_F(RedisTDigestTest, RevRank_on_the_set_containing_different_elements) {
+  std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+  std::vector<double> input{10, 20, 30, 40, 50, 60};
+  status = tdigest_->Add(*ctx_, test_digest_name, input);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  std::vector<int> result;
+  const std::vector<double> value = {0, 10, 20, 30, 40, 50, 60, 70};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  // Check the status before the results, and compare whole vectors so a short or empty
+  // result fails instead of silently passing.
+  ASSERT_TRUE(status.ok()) << status.ToString();
+  EXPECT_EQ(result, (std::vector<int>{6, 5, 4, 3, 2, 1, 0, -1}));
+}
+
+// Repeated observations: an observed value ranks behind all larger observations plus half
+// of its own occurrences, truncated (e.g. 10 in {10,10,10,20,20}: 2 + 3/2 = 3.5 -> 3).
+TEST_F(RedisTDigestTest, RevRank_on_the_set_containing_several_identical_elements) {
+  std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+  std::vector<double> input{10, 10, 10, 20, 20};
+  status = tdigest_->Add(*ctx_, test_digest_name, input);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  std::vector<int> result;
+  const std::vector<double> value = {10, 20};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+  EXPECT_EQ(result, (std::vector<int>{3, 1}));
+
+  // One more 10 goes into the unmerged buffer; RevRank must merge it before ranking.
+  status = tdigest_->Add(*ctx_, test_digest_name, std::vector<double>{10});
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  result.clear();
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+  EXPECT_EQ(result, (std::vector<int>{4, 1}));
+}
+
+// A created-but-empty digest reports -2 for every queried value.
+TEST_F(RedisTDigestTest, RevRank_on_empty_tdigest) {
+  std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS());
+  bool exists = false;
+  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
+  ASSERT_FALSE(exists);
+  ASSERT_TRUE(status.ok());
+
+  std::vector<int> result;
+  const std::vector<double> value = {10, 20};
+  status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+  EXPECT_EQ(result, (std::vector<int>{-2, -2}));
+}
\ No newline at end of file
diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go
b/tests/gocase/unit/type/tdigest/tdigest_test.go
index 3c2565ac0..129be0554 100644
--- a/tests/gocase/unit/type/tdigest/tdigest_test.go
+++ b/tests/gocase/unit/type/tdigest/tdigest_test.go
@@ -518,4 +518,73 @@ func tdigestTests(t *testing.T, configs
util.KvrocksServerConfigs) {
validation(newDestKey1)
validation(newDestKey2)
})
+
+ t.Run("tdigest.revrank with different arguments", func(t *testing.T) {
+ keyPrefix := "tdigest_revrank_"
+
+ // Test invalid arguments
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK").Err(),
errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK",
keyPrefix+"nonexistent").Err(), errMsgWrongNumberArg)
+
+ // Test Non-existent key
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK",
keyPrefix+"nonexistent", "10").Err(), errMsgKeyNotExist)
+
+ // Test with empty tdigest
+ key := keyPrefix + "test1"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ rsp := rdb.Do(ctx, "TDIGEST.REVRANK", key, "10")
+ require.NoError(t, rsp.Err())
+ vals, err := rsp.Slice()
+ require.NoError(t, err)
+ require.Len(t, vals, 1)
+ expected := []int64{-2}
+ for i, v := range vals {
+ rank, ok := v.(int64)
+ require.True(t, ok, "expected int64 but got %T at index
%d", v, i)
+ require.EqualValues(t, rank, expected[i])
+ }
+
+ // Test with set containing several identical elements
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10", "10",
"10", "20", "20").Err())
+ rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20")
+ require.NoError(t, rsp.Err())
+ vals, err = rsp.Slice()
+ require.NoError(t, err)
+ require.Len(t, vals, 2)
+ expected = []int64{3, 1}
+ for i, v := range vals {
+ rank, ok := v.(int64)
+ require.True(t, ok, "expected int64 but got %T at index
%d", v, i)
+ require.EqualValues(t, rank, expected[i])
+ }
+
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10").Err())
+ rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20")
+ require.NoError(t, rsp.Err())
+ vals, err = rsp.Slice()
+ require.NoError(t, err)
+ require.Len(t, vals, 2)
+ expected = []int64{4, 1}
+ for i, v := range vals {
+ rank, ok := v.(int64)
+ require.True(t, ok, "expected int64 but got %T at index
%d", v, i)
+ require.EqualValues(t, rank, expected[i])
+ }
+
+ // Test with set containing different elements
+ key2 := keyPrefix + "test2"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "10", "20",
"30", "40", "50", "60").Err())
+ rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key2, "0", "10", "20",
"30", "40", "50", "60", "70")
+ require.NoError(t, rsp.Err())
+ vals, err = rsp.Slice()
+ require.NoError(t, err)
+ require.Len(t, vals, 8)
+ expected = []int64{6, 5, 4, 3, 2, 1, 0, -1}
+ for i, v := range vals {
+ rank, ok := v.(int64)
+ require.True(t, ok, "expected int64 but got %T at index
%d", v, i)
+ require.EqualValues(t, rank, expected[i])
+ }
+ })
}