This is an automated email from the ASF dual-hosted git repository.
edwardxu pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 65db00a86 refactor(tdigest): refactor the `TDigest::Rank` `reverse`
flag into a compile-time argument (#3268)
65db00a86 is described below
commit 65db00a86bc9031f4d974eb0b6f1893bd126c524
Author: Hao Dong <[email protected]>
AuthorDate: Sat Dec 13 21:03:13 2025 +0800
refactor(tdigest): refactor the `TDigest::Rank` `reverse` flag into a
compile-time argument (#3268)
Co-authored-by: Twice <[email protected]>
Co-authored-by: Edward Xu <[email protected]>
---
src/commands/cmd_tdigest.cc | 13 ++-
src/types/redis_tdigest.cc | 166 +++++++++++++++++-------------------
src/types/redis_tdigest.h | 7 +-
src/types/tdigest.h | 15 +---
tests/cppunit/types/tdigest_test.cc | 20 ++---
5 files changed, 109 insertions(+), 112 deletions(-)
diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc
index 64dfafcd7..a5479cb12 100644
--- a/src/commands/cmd_tdigest.cc
+++ b/src/commands/cmd_tdigest.cc
@@ -176,7 +176,7 @@ class CommandTDigestAdd : public Commander {
std::vector<double> values_;
};
-template <bool reverse>
+template <bool Reverse>
class TDigestRankCommand : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
@@ -202,7 +202,16 @@ class TDigestRankCommand : public Commander {
TDigest tdigest(srv->storage, conn->GetNamespace());
std::vector<int> result;
result.reserve(origin_inputs_.size());
- if (const auto s = tdigest.Rank(ctx, key_name_, unique_inputs_, reverse,
result); !s.ok()) {
+
+ if (const auto s =
+ [&]() {
+ if constexpr (Reverse) {
+ return tdigest.RevRank(ctx, key_name_, unique_inputs_, result);
+ } else {
+ return tdigest.Rank(ctx, key_name_, unique_inputs_, result);
+ }
+ }();
+ !s.ok()) {
if (s.IsNotFound()) {
return {Status::RedisExecErr, errKeyNotFound};
}
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
index fec7aef1c..d9a7d425a 100644
--- a/src/types/redis_tdigest.cc
+++ b/src/types/redis_tdigest.cc
@@ -47,114 +47,81 @@
namespace redis {
+namespace {
+template <bool Reverse, typename Container>
+inline decltype(auto) GetCbeginIter(const Container& centroids) {
+ if constexpr (Reverse) {
+ return centroids.crbegin();
+ } else {
+ return centroids.cbegin();
+ }
+}
+
+template <bool Reverse, typename Container>
+inline decltype(auto) GetCendIter(const Container& centroids) {
+ if constexpr (Reverse) {
+ return centroids.crend();
+ } else {
+ return centroids.cend();
+ }
+}
+} // namespace
+
// TODO: It should be replaced by a iteration of the rocksdb iterator
+template <bool Reverse>
class DummyCentroids {
public:
- class BaseIterator {
- public:
- virtual ~BaseIterator() = default;
- virtual bool Next() = 0;
- virtual bool Prev() = 0;
- virtual bool Valid() const = 0;
- virtual std::unique_ptr<BaseIterator> Clone() const = 0;
- virtual StatusOr<Centroid> GetCentroid() const = 0;
- };
-
DummyCentroids(const TDigestMetadata& meta_data, const
std::vector<Centroid>& centroids)
: meta_data_(meta_data), centroids_(centroids) {}
- class Iterator : public BaseIterator {
+ class Iterator {
public:
- Iterator(std::vector<Centroid>::const_iterator&& iter, const
std::vector<Centroid>& centroids)
- : iter_(iter), centroids_(centroids) {}
- std::unique_ptr<BaseIterator> Clone() const override {
- if (iter_ != centroids_.cend()) {
- return std::make_unique<Iterator>(std::next(centroids_.cbegin(),
std::distance(centroids_.cbegin(), iter_)),
- centroids_);
+ using IterType = std::conditional_t<Reverse,
std::vector<Centroid>::const_reverse_iterator,
+ std::vector<Centroid>::const_iterator>;
+ Iterator(IterType iter, const std::vector<Centroid>& centroids) :
iter_(iter), centroids_(centroids) {}
+ std::unique_ptr<Iterator> Clone() const {
+ if (iter_ != GetCendIter<Reverse>(centroids_)) {
+ return std::make_unique<Iterator>(
+ std::next(GetCbeginIter<Reverse>(centroids_),
std::distance(GetCbeginIter<Reverse>(centroids_), iter_)),
+ centroids_);
}
- return std::make_unique<Iterator>(centroids_.cend(), centroids_);
+ return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_),
centroids_);
}
- bool Next() override {
+ bool Next() {
if (Valid()) {
std::advance(iter_, 1);
}
- return iter_ != centroids_.cend();
+ return iter_ != GetCendIter<Reverse>(centroids_);
}
// The Prev function can only be called for item is not cend,
// because we must guarantee the iterator to be inside the valid range
before iteration.
- bool Prev() override {
- if (Valid() && iter_ != centroids_.cbegin()) {
+ bool Prev() {
+ if (Valid() && iter_ != GetCbeginIter<Reverse>(centroids_)) {
std::advance(iter_, -1);
}
return Valid();
}
- bool Valid() const override { return iter_ != centroids_.cend(); }
- StatusOr<Centroid> GetCentroid() const override {
- if (iter_ == centroids_.cend()) {
+ bool Valid() const { return iter_ != GetCendIter<Reverse>(centroids_); }
+ StatusOr<Centroid> GetCentroid() const {
+ if (iter_ == GetCendIter<Reverse>(centroids_)) {
return {::Status::NotOK, "invalid iterator during decoding tdigest
centroid"};
}
return *iter_;
}
private:
- std::vector<Centroid>::const_iterator iter_;
+ IterType iter_;
const std::vector<Centroid>& centroids_;
};
- class ReverseIterator final : public BaseIterator {
- public:
- ReverseIterator(std::vector<Centroid>::const_reverse_iterator&& iter,
const std::vector<Centroid>& centroids)
- : iter_(iter), centroids_(centroids) {}
- std::unique_ptr<BaseIterator> Clone() const override {
- if (iter_ != centroids_.crend()) {
- return std::make_unique<ReverseIterator>(
- std::next(centroids_.crbegin(),
std::distance(centroids_.crbegin(), iter_)), centroids_);
- }
- return std::make_unique<ReverseIterator>(centroids_.crend(), centroids_);
- }
- bool Next() override {
- if (Valid()) {
- std::advance(iter_, 1);
- }
- return iter_ != centroids_.crend();
- }
-
- bool Prev() override {
- if (Valid() && iter_ != centroids_.crbegin()) {
- std::advance(iter_, -1);
- }
- return Valid();
- }
- bool Valid() const override { return iter_ != centroids_.crend(); }
- StatusOr<Centroid> GetCentroid() const override {
- if (iter_ == centroids_.crend()) {
- return {::Status::NotOK, "invalid iterator during decoding tdigest
centroid"};
- }
- return *iter_;
- }
-
- private:
- std::vector<Centroid>::const_reverse_iterator iter_;
- const std::vector<Centroid>& centroids_;
- };
-
- std::unique_ptr<BaseIterator> Begin(const bool reverse = false) const {
- if (reverse) {
- return std::make_unique<ReverseIterator>(centroids_.crbegin(),
centroids_);
- }
- return std::make_unique<Iterator>(centroids_.cbegin(), centroids_);
+ std::unique_ptr<Iterator> Begin() const {
+ return std::make_unique<Iterator>(GetCbeginIter<Reverse>(centroids_),
centroids_);
}
- std::unique_ptr<BaseIterator> End(const bool reverse = false) const {
+ std::unique_ptr<Iterator> End() const {
if (centroids_.empty()) {
- if (reverse) {
- return std::make_unique<ReverseIterator>(centroids_.crend(),
centroids_);
- }
- return std::make_unique<Iterator>(centroids_.cend(), centroids_);
- }
- if (reverse) {
- return std::make_unique<ReverseIterator>(std::prev(centroids_.crend()),
centroids_);
+ return std::make_unique<Iterator>(GetCendIter<Reverse>(centroids_),
centroids_);
}
- return std::make_unique<Iterator>(std::prev(centroids_.cend()),
centroids_);
+ return
std::make_unique<Iterator>(std::prev(GetCendIter<Reverse>(centroids_)),
centroids_);
}
double TotalWeight() const { return
static_cast<double>(meta_data_.total_weight); }
double Min() const { return meta_data_.minimum; }
@@ -273,10 +240,9 @@ rocksdb::Status TDigest::mergeNodes(engine::Context& ctx,
const std::string& ns_
return rocksdb::Status::OK();
}
-rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& inputs,
- bool reverse, std::vector<int>& result) {
+rocksdb::Status TDigest::prepareRankData(engine::Context& ctx, const Slice&
digest_name, TDigestMetadata& metadata,
+ std::vector<Centroid>& centroids) {
auto ns_key = AppendNamespacePrefix(digest_name);
- TDigestMetadata metadata;
{
LockGuard guard(storage_->GetLockManager(), ns_key);
@@ -285,7 +251,6 @@ rocksdb::Status TDigest::Rank(engine::Context& ctx, const
Slice& digest_name, co
}
if (metadata.total_observations == 0) {
- result.resize(inputs.size(), -2);
return rocksdb::Status::OK();
}
@@ -293,15 +258,44 @@ rocksdb::Status TDigest::Rank(engine::Context& ctx, const
Slice& digest_name, co
return status;
}
}
+ return dumpCentroids(ctx, ns_key, metadata, &centroids);
+}
+rocksdb::Status TDigest::Rank(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& inputs,
+ std::vector<int>& result) {
+ TDigestMetadata metadata;
std::vector<Centroid> centroids;
- if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids);
!status.ok()) {
+ if (auto status = prepareRankData(ctx, digest_name, metadata, centroids);
!status.ok()) {
+ return status;
+ }
+
+ if (metadata.total_observations == 0) {
+ result.resize(inputs.size(), -2);
+ return rocksdb::Status::OK();
+ }
+
+ auto dump_centroids = DummyCentroids<false>(metadata, centroids);
+ if (auto status = TDigestRank<false>(dump_centroids, inputs, result);
!status) {
+ return rocksdb::Status::InvalidArgument(status.Msg());
+ }
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice&
digest_name, const std::vector<double>& inputs,
+ std::vector<int>& result) {
+ TDigestMetadata metadata;
+ std::vector<Centroid> centroids;
+ if (auto status = prepareRankData(ctx, digest_name, metadata, centroids);
!status.ok()) {
return status;
}
- auto dump_centroids = DummyCentroids(metadata, centroids);
- auto status = TDigestRank(dump_centroids, inputs, reverse, result);
- if (!status) {
+ if (metadata.total_observations == 0) {
+ result.resize(inputs.size(), -2);
+ return rocksdb::Status::OK();
+ }
+
+ auto dump_centroids = DummyCentroids<true>(metadata, centroids);
+ if (auto status = TDigestRank<true>(dump_centroids, inputs, result);
!status) {
return rocksdb::Status::InvalidArgument(status.Msg());
}
return rocksdb::Status::OK();
@@ -332,7 +326,7 @@ rocksdb::Status TDigest::Quantile(engine::Context& ctx,
const Slice& digest_name
return status;
}
- auto dump_centroids = DummyCentroids(metadata, centroids);
+ auto dump_centroids = DummyCentroids<false>(metadata, centroids);
auto quantile_results = std::vector<double>();
quantile_results.reserve(qs.size());
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
index 5daaed80c..4844f009f 100644
--- a/src/types/redis_tdigest.h
+++ b/src/types/redis_tdigest.h
@@ -77,8 +77,10 @@ class TDigest : public SubKeyScanner {
rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const
std::vector<std::string>& source_digests,
const TDigestMergeOptions& options);
- rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const
std::vector<double>& inputs, bool reverse,
+ rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const
std::vector<double>& inputs,
std::vector<int>& result);
+ rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& inputs,
+ std::vector<int>& result);
rocksdb::Status GetMetaData(engine::Context& context, const Slice&
digest_name, TDigestMetadata* metadata);
private:
@@ -130,6 +132,7 @@ class TDigest : public SubKeyScanner {
static std::string internalValueFromCentroid(const Centroid& centroid);
rocksdb::Status decodeCentroidFromKeyValue(const rocksdb::Slice& key, const
rocksdb::Slice& value,
Centroid* centroid) const;
+ rocksdb::Status prepareRankData(engine::Context& ctx, const Slice&
digest_name, TDigestMetadata& metadata,
+ std::vector<Centroid>& centroids);
};
-
} // namespace redis
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
index d77b673f7..e4caf914f 100644
--- a/src/types/tdigest.h
+++ b/src/types/tdigest.h
@@ -171,8 +171,8 @@ struct DoubleComparator {
bool operator()(const double& a, const double& b) const { return
DoubleCompare(a, b) == -1; }
};
-template <typename TD, bool Reverse>
-inline Status TDigestRankImpl(TD&& td, const std::vector<double>& inputs,
std::vector<int>& result) {
+template <bool Reverse, typename TD>
+inline Status TDigestRank(TD&& td, const std::vector<double>& inputs,
std::vector<int>& result) {
std::map<double, size_t, DoubleComparator> value_to_index;
for (size_t i = 0; i < inputs.size(); ++i) {
value_to_index[inputs[i]] = i;
@@ -211,7 +211,7 @@ inline Status TDigestRankImpl(TD&& td, const
std::vector<double>& inputs, std::v
}
}
- auto iter = td.Begin(Reverse);
+ auto iter = td.Begin();
double cumulative_weight = 0;
while (iter->Valid() && !is_end()) {
auto centroid = GET_OR_RET(iter->GetCentroid());
@@ -267,12 +267,3 @@ inline Status TDigestRankImpl(TD&& td, const
std::vector<double>& inputs, std::v
}
return Status::OK();
}
-
-template <typename TD>
-inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, bool
reverse, std::vector<int>& result) {
- if (reverse) {
- return TDigestRankImpl<TD, true>(std::forward<TD>(td), inputs, result);
- } else {
- return TDigestRankImpl<TD, false>(std::forward<TD>(td), inputs, result);
- }
-}
diff --git a/tests/cppunit/types/tdigest_test.cc
b/tests/cppunit/types/tdigest_test.cc
index 9f5ad9736..a6ae7b295 100644
--- a/tests/cppunit/types/tdigest_test.cc
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -312,7 +312,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_different_elemen
std::vector<int> result;
result.reserve(input.size());
const std::vector<double> value = {0, 10, 20, 30, 40, 50, 60, 70};
- status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+ status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
const auto expect_result_revrank = std::vector<double>{6, 5, 4, 3, 2, 1, 0,
-1};
for (size_t i = 0; i < result.size(); i++) {
@@ -323,7 +323,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_different_elemen
result.clear();
result.reserve(input.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+ status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
const auto expect_result_rank = std::vector<double>{-1, 0, 1, 2, 3, 4, 5, 6};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -345,7 +345,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_several_identica
std::vector<int> result;
const std::vector<double> value = {10, 20};
result.reserve(value.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+ status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
const auto expect_result_revrank = std::vector<double>{3, 1};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -355,7 +355,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_several_identica
result.clear();
result.reserve(value.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+ status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
const auto expect_result_rank = std::vector<double>{1, 4};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -368,7 +368,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_several_identica
result.clear();
result.reserve(value.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+ status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
const auto expect_result_new_revrank = std::vector<double>{4, 1};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -378,7 +378,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_the_set_containing_several_identica
result.clear();
result.reserve(value.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+ status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
const auto expect_result_new_rank = std::vector<double>{2, 5};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -397,7 +397,7 @@ TEST_F(RedisTDigestTest, RevRank_and_Rank_on_empty_tdigest)
{
std::vector<int> result;
result.reserve(2);
const std::vector<double> value = {10, 20};
- status = tdigest_->Rank(*ctx_, test_digest_name, value, true, result);
+ status = tdigest_->RevRank(*ctx_, test_digest_name, value, result);
const auto expect_result_revrank = std::vector<double>{-2, -2};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -407,7 +407,7 @@ TEST_F(RedisTDigestTest, RevRank_and_Rank_on_empty_tdigest)
{
result.clear();
result.reserve(2);
- status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+ status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
const auto expect_result_rank = std::vector<double>{-2, -2};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -430,7 +430,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_different_or_same_and_unordered_inp
std::vector<int> result;
const std::vector<double> value = {50, 36, 4, 99, 8.8};
result.reserve(value.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value, false, result);
+ status = tdigest_->Rank(*ctx_, test_digest_name, value, result);
const auto expect_result_rank = std::vector<double>{13, 11, 1, 16, 4};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];
@@ -441,7 +441,7 @@ TEST_F(RedisTDigestTest,
RevRank_and_Rank_on_different_or_same_and_unordered_inp
const std::vector<double> value_new = {50, 36, 4, 99, 8.8, 12};
result.clear();
result.reserve(value_new.size());
- status = tdigest_->Rank(*ctx_, test_digest_name, value_new, true, result);
+ status = tdigest_->RevRank(*ctx_, test_digest_name, value_new, result);
const auto expect_result_revrank = std::vector<double>{4, 7, 16, 1, 14, 12};
for (size_t i = 0; i < result.size(); i++) {
auto got = result[i];