This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 54c5862f feat(tdigest): add tdigest algorithm and storage encoding
implementations (#2741)
54c5862f is described below
commit 54c5862fee740da59ca28648d5bfdc0be094aa52
Author: Edward Xu <[email protected]>
AuthorDate: Wed Feb 12 14:17:15 2025 +0800
feat(tdigest): add tdigest algorithm and storage encoding implementations
(#2741)
Co-authored-by: Twice <[email protected]>
Co-authored-by: Aleks Lozovyuk <[email protected]>
Co-authored-by: Twice <[email protected]>
---
src/storage/redis_metadata.cc | 40 +++-
src/storage/redis_metadata.h | 32 ++-
src/types/redis_tdigest.cc | 442 ++++++++++++++++++++++++++++++++++++
src/types/redis_tdigest.h | 116 ++++++++++
src/types/tdigest.cc | 435 +++++++++++++++++++++++++++++++++++
src/types/tdigest.h | 150 ++++++++++++
tests/cppunit/types/tdigest_test.cc | 247 ++++++++++++++++++++
7 files changed, 1458 insertions(+), 4 deletions(-)
diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc
index 9df10532..c957901a 100644
--- a/src/storage/redis_metadata.cc
+++ b/src/storage/redis_metadata.cc
@@ -331,7 +331,8 @@ bool Metadata::ExpireAt(uint64_t expired_ts) const {
bool Metadata::IsSingleKVType() const { return Type() == kRedisString ||
Type() == kRedisJson; }
bool Metadata::IsEmptyableType() const {
- return IsSingleKVType() || Type() == kRedisStream || Type() ==
kRedisBloomFilter || Type() == kRedisHyperLogLog;
+ return IsSingleKVType() || Type() == kRedisStream || Type() ==
kRedisBloomFilter || Type() == kRedisHyperLogLog ||
+ Type() == kRedisTDigest;
}
bool Metadata::Expired() const { return ExpireAt(util::GetTimeStampMS()); }
@@ -497,3 +498,40 @@ rocksdb::Status HyperLogLogMetadata::Decode(Slice *input) {
return rocksdb::Status::OK();
}
+
+void TDigestMetadata::Encode(std::string *dst) const {
+ Metadata::Encode(dst);
+ PutFixed32(dst, compression);
+ PutFixed32(dst, capacity);
+ PutFixed64(dst, unmerged_nodes);
+ PutFixed64(dst, merged_nodes);
+ PutFixed64(dst, total_weight);
+ PutFixed64(dst, merged_weight);
+ PutDouble(dst, minimum);
+ PutDouble(dst, maximum);
+ PutFixed64(dst, total_observations);
+ PutFixed64(dst, merge_times);
+}
+
+rocksdb::Status TDigestMetadata::Decode(Slice *input) {
+ if (auto s = Metadata::Decode(input); !s.ok()) {
+ return s;
+ }
+
+ if (input->size() < (sizeof(uint32_t) * 2 + sizeof(uint64_t) * 6 +
sizeof(double) * 2)) {
+ return rocksdb::Status::InvalidArgument(kErrMetadataTooShort);
+ }
+
+ GetFixed32(input, &compression);
+ GetFixed32(input, &capacity);
+ GetFixed64(input, &unmerged_nodes);
+ GetFixed64(input, &merged_nodes);
+ GetFixed64(input, &total_weight);
+ GetFixed64(input, &merged_weight);
+ GetDouble(input, &minimum);
+ GetDouble(input, &maximum);
+ GetFixed64(input, &total_observations);
+ GetFixed64(input, &merge_times);
+
+ return rocksdb::Status::OK();
+}
diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h
index 69a0db1c..dd956e0e 100644
--- a/src/storage/redis_metadata.h
+++ b/src/storage/redis_metadata.h
@@ -26,6 +26,7 @@
#include <atomic>
#include <bitset>
#include <initializer_list>
+#include <limits>
#include <string>
#include <vector>
@@ -51,6 +52,7 @@ enum RedisType : uint8_t {
kRedisBloomFilter = 9,
kRedisJson = 10,
kRedisHyperLogLog = 11,
+ kRedisTDigest = 12,
};
struct RedisTypes {
@@ -92,9 +94,9 @@ enum RedisCommand {
kRedisCmdLMove,
};
-const std::vector<std::string> RedisTypeNames = {"none", "string",
"hash", "list",
- "set", "zset",
"bitmap", "sortedint",
- "stream", "MBbloom--",
"ReJSON-RL", "hyperloglog"};
+const std::vector<std::string> RedisTypeNames = {"none", "string",
"hash", "list", "set",
+ "zset", "bitmap",
"sortedint", "stream", "MBbloom--",
+ "ReJSON-RL", "hyperloglog",
"TDIS-TYPE"};
constexpr const char *kErrMsgWrongType = "WRONGTYPE Operation against a key
holding the wrong kind of value";
constexpr const char *kErrMsgKeyExpired = "the key was expired";
@@ -337,3 +339,27 @@ class HyperLogLogMetadata : public Metadata {
EncodeType encode_type = EncodeType::DENSE;
};
+
+class TDigestMetadata : public Metadata {
+ public:
+ uint32_t compression;
+ uint32_t capacity;
+ uint64_t unmerged_nodes = 0;
+ uint64_t merged_nodes = 0;
+ uint64_t total_weight = 0;
+ uint64_t merged_weight = 0;
+ double minimum = std::numeric_limits<double>::max();
+ double maximum = std::numeric_limits<double>::lowest();
+ uint64_t total_observations = 0; // reserved for TDIGEST.INFO command
+ uint64_t merge_times = 0; // reserved for TDIGEST.INFO command
+
+ explicit TDigestMetadata(uint32_t compression, uint32_t capacity, bool
generate_version = true)
+ : Metadata(kRedisTDigest, generate_version), compression(compression),
capacity(capacity) {}
+ explicit TDigestMetadata(bool generate_version = true) : TDigestMetadata(0,
0, generate_version) {}
+ void Encode(std::string *dst) const override;
+ rocksdb::Status Decode(Slice *input) override;
+
+ uint64_t TotalNodes() const { return merged_nodes + unmerged_nodes; }
+
+ double Delta() const { return 1. / static_cast<double>(compression); }
+};
diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc
new file mode 100644
index 00000000..e29b35f0
--- /dev/null
+++ b/src/types/redis_tdigest.cc
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "redis_tdigest.h"
+
+#include <fmt/format.h>
+#include <rocksdb/db.h>
+#include <rocksdb/iterator.h>
+#include <rocksdb/options.h>
+#include <rocksdb/slice.h>
+#include <rocksdb/status.h>
+
+#include <algorithm>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <range/v3/algorithm/minmax.hpp>
+#include <range/v3/range/conversion.hpp>
+#include <range/v3/view/join.hpp>
+#include <range/v3/view/transform.hpp>
+#include <vector>
+
+#include "db_util.h"
+#include "encoding.h"
+#include "status.h"
+#include "storage/redis_db.h"
+#include "storage/redis_metadata.h"
+#include "types/tdigest.h"
+
+namespace redis {
+
+// TODO: It should be replaced by an iteration of the rocksdb iterator
+class DummyCentroids {
+ public:
+ DummyCentroids(const TDigestMetadata& meta_data, const
std::vector<Centroid>& centroids)
+ : meta_data_(meta_data), centroids_(centroids) {}
+ class Iterator {
+ public:
+ Iterator(std::vector<Centroid>::const_iterator&& iter, const
std::vector<Centroid>& centroids)
+ : iter_(iter), centroids_(centroids) {}
+ std::unique_ptr<Iterator> Clone() const {
+ if (iter_ != centroids_.cend()) {
+ return std::make_unique<Iterator>(std::next(centroids_.cbegin(),
std::distance(centroids_.cbegin(), iter_)),
+ centroids_);
+ }
+ return std::make_unique<Iterator>(centroids_.cend(), centroids_);
+ }
+ bool Next() {
+ if (Valid()) {
+ std::advance(iter_, 1);
+ }
+ return iter_ != centroids_.cend();
+ }
+
+  // The Prev function can only be called when the item is not cend,
+  // because we must guarantee the iterator is inside the valid range
before iteration.
+ bool Prev() {
+ if (Valid() && iter_ != centroids_.cbegin()) {
+ std::advance(iter_, -1);
+ }
+ return Valid();
+ }
+ bool Valid() const { return iter_ != centroids_.cend(); }
+ StatusOr<Centroid> GetCentroid() const {
+ if (iter_ == centroids_.cend()) {
+ return {::Status::NotOK, "invalid iterator during decoding tdigest
centroid"};
+ }
+ return *iter_;
+ }
+
+ private:
+ std::vector<Centroid>::const_iterator iter_;
+ const std::vector<Centroid>& centroids_;
+ };
+
+ std::unique_ptr<Iterator> Begin() { return
std::make_unique<Iterator>(centroids_.cbegin(), centroids_); }
+ std::unique_ptr<Iterator> End() {
+ if (centroids_.empty()) {
+ return std::make_unique<Iterator>(centroids_.cend(), centroids_);
+ }
+ return std::make_unique<Iterator>(std::prev(centroids_.cend()),
centroids_);
+ }
+ double TotalWeight() const { return
static_cast<double>(meta_data_.total_weight); }
+ double Min() const { return meta_data_.minimum; }
+ double Max() const { return meta_data_.maximum; }
+ uint64_t Size() const { return meta_data_.merged_nodes; }
+
+ private:
+ const TDigestMetadata& meta_data_;
+ const std::vector<Centroid>& centroids_;
+};
+
+uint32_t constexpr kMaxElements = 1 * 1024; // 1k doubles
+uint32_t constexpr kMaxCompression = 1000; // limit the compression to 1k
+
+rocksdb::Status TDigest::Create(engine::Context& ctx, const Slice&
digest_name, const TDigestCreateOptions& options,
+ bool* exists) {
+ if (options.compression > kMaxCompression) {
+ return rocksdb::Status::InvalidArgument(fmt::format("compression should be
less than {}", kMaxCompression));
+ }
+
+ auto ns_key = AppendNamespacePrefix(digest_name);
+ auto capacity = options.compression * 6 + 10;
+ capacity = ((capacity < kMaxElements) ? capacity : kMaxElements);
+ TDigestMetadata metadata(options.compression, capacity);
+
+ LockGuard guard(storage_->GetLockManager(), ns_key);
+ auto status = GetMetaData(ctx, ns_key, &metadata);
+ *exists = status.ok();
+ if (*exists) {
+ return rocksdb::Status::InvalidArgument("tdigest already exists");
+ }
+
+ if (!status.IsNotFound()) {
+ return status;
+ }
+
+ auto batch = storage_->GetWriteBatchBase();
+ WriteBatchLogData log_data(kRedisTDigest);
+ if (status = batch->PutLogData(log_data.Encode()); !status.ok()) {
+ return status;
+ }
+
+ std::string metadata_bytes;
+ metadata.Encode(&metadata_bytes);
+ if (status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes);
!status.ok()) {
+ return status;
+ }
+
+ return storage_->Write(ctx, storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
+}
+
+rocksdb::Status TDigest::Add(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& inputs) {
+ auto ns_key = AppendNamespacePrefix(digest_name);
+ LockGuard guard(storage_->GetLockManager(), ns_key);
+
+ TDigestMetadata metadata;
+ if (auto status = GetMetaData(ctx, ns_key, &metadata); !status.ok()) {
+ return status;
+ }
+
+ auto batch = storage_->GetWriteBatchBase();
+ WriteBatchLogData log_data(kRedisTDigest);
+ if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) {
+ return status;
+ }
+
+ metadata.total_observations += inputs.size();
+ metadata.total_weight += inputs.size();
+
+ if (metadata.unmerged_nodes + inputs.size() <= metadata.capacity) {
+ if (auto status = appendBuffer(ctx, batch, ns_key, inputs, &metadata);
!status.ok()) {
+ return status;
+ }
+ metadata.unmerged_nodes += inputs.size();
+ } else {
+ if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata,
&inputs); !status.ok()) {
+ return status;
+ }
+ }
+
+ std::string metadata_bytes;
+ metadata.Encode(&metadata_bytes);
+ if (auto status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes);
!status.ok()) {
+ return status;
+ }
+
+ return storage_->Write(ctx, storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
+}
+
+rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice&
digest_name, const std::vector<double>& qs,
+ TDigestQuantitleResult* result) {
+ auto ns_key = AppendNamespacePrefix(digest_name);
+ TDigestMetadata metadata;
+ {
+ LockGuard guard(storage_->GetLockManager(), ns_key);
+
+ if (auto status = GetMetaData(ctx, ns_key, &metadata); !status.ok()) {
+ return status;
+ }
+
+ if (metadata.unmerged_nodes > 0) {
+ auto batch = storage_->GetWriteBatchBase();
+ WriteBatchLogData log_data(kRedisTDigest);
+ if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) {
+ return status;
+ }
+
+ if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata);
!status.ok()) {
+ return status;
+ }
+
+ std::string metadata_bytes;
+ metadata.Encode(&metadata_bytes);
+ if (auto status = batch->Put(metadata_cf_handle_, ns_key,
metadata_bytes); !status.ok()) {
+ return status;
+ }
+
+ if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(),
batch->GetWriteBatch()); !status.ok()) {
+ return status;
+ }
+
+ ctx.RefreshLatestSnapshot();
+ }
+ }
+
+ std::vector<Centroid> centroids;
+ if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids);
!status.ok()) {
+ return status;
+ }
+
+ auto dump_centroids = DummyCentroids(metadata, centroids);
+
+ for (auto q : qs) {
+ auto status_or_value = TDigestQuantile(dump_centroids, q);
+ if (!status_or_value) {
+ return rocksdb::Status::InvalidArgument(status_or_value.Msg());
+ }
+ result->quantiles.push_back(*status_or_value);
+ }
+
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::GetMetaData(engine::Context& context, const Slice&
ns_key, TDigestMetadata* metadata) {
+ return Database::GetMetadata(context, {kRedisTDigest}, ns_key, metadata);
+}
+
+rocksdb::Status TDigest::mergeCurrentBuffer(engine::Context& ctx, const
std::string& ns_key,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+ TDigestMetadata* metadata, const
std::vector<double>* additional_buffer) {
+ std::vector<Centroid> centroids;
+ std::vector<double> buffer;
+ centroids.reserve(metadata->merged_nodes);
+ buffer.reserve(metadata->unmerged_nodes + (additional_buffer == nullptr ? 0
: additional_buffer->size()));
+ if (auto status = dumpCentroidsAndBuffer(ctx, ns_key, *metadata, ¢roids,
&buffer, &batch); !status.ok()) {
+ return status;
+ }
+
+ if (additional_buffer != nullptr) {
+ std::copy(additional_buffer->cbegin(), additional_buffer->cend(),
std::back_inserter(buffer));
+ }
+
+ auto merged_centroids = TDigestMerge(buffer, {
+ .centroids = centroids,
+ .delta =
metadata->compression,
+ .min = metadata->minimum,
+ .max = metadata->maximum,
+ .total_weight =
static_cast<double>(metadata->merged_weight),
+ });
+
+ if (!merged_centroids.IsOK()) {
+ return rocksdb::Status::InvalidArgument(merged_centroids.Msg());
+ }
+
+ if (auto status = applyNewCentroids(batch, ns_key, *metadata,
merged_centroids->centroids); !status.ok()) {
+ return status;
+ }
+
+ metadata->merge_times++;
+ metadata->merged_nodes = merged_centroids->centroids.size();
+ metadata->unmerged_nodes = 0;
+ metadata->minimum = merged_centroids->min;
+ metadata->maximum = merged_centroids->max;
+ metadata->merged_weight =
static_cast<uint64_t>(merged_centroids->total_weight);
+
+ return rocksdb::Status::OK();
+}
+
+std::string TDigest::internalBufferKey(const std::string& ns_key, const
TDigestMetadata& metadata) const {
+ std::string sub_key;
+ PutFixed8(&sub_key, static_cast<uint8_t>(SegmentType::kBuffer));
+ return InternalKey(ns_key, sub_key, metadata.version,
storage_->IsSlotIdEncoded()).Encode();
+}
+
+std::string TDigest::internalKeyFromCentroid(const std::string& ns_key, const
TDigestMetadata& metadata,
+ const Centroid& centroid) const {
+ std::string sub_key;
+ PutFixed8(&sub_key, static_cast<uint8_t>(SegmentType::kCentroids));
+ PutDouble(&sub_key, centroid.mean); // It uses EncodeDoubleToUInt64 and
keeps original order of double
+ return InternalKey(ns_key, sub_key, metadata.version,
storage_->IsSlotIdEncoded()).Encode();
+}
+
+std::string TDigest::internalValueFromCentroid(const Centroid& centroid) {
+ std::string value;
+ PutDouble(&value, centroid.weight);
+ return value;
+}
+
+rocksdb::Status TDigest::decodeCentroidFromKeyValue(const rocksdb::Slice& key,
const rocksdb::Slice& value,
+ Centroid* centroid) const {
+ InternalKey ikey(key, storage_->IsSlotIdEncoded());
+ auto subkey = ikey.GetSubKey();
+ auto type_flg = static_cast<uint8_t>(SegmentType::kGuardFlag);
+ if (!GetFixed8(&subkey, &type_flg)) {
+ LOG(ERROR) << "corrupted tdigest centroid key, extract type failed";
+ return rocksdb::Status::Corruption("corrupted tdigest centroid key");
+ }
+ if (static_cast<SegmentType>(type_flg) != SegmentType::kCentroids) {
+ LOG(ERROR) << "corrupted tdigest centroid key type: " << type_flg << ",
expect to be "
+ << static_cast<uint8_t>(SegmentType::kCentroids);
+ return rocksdb::Status::Corruption("corrupted tdigest centroid key type");
+ }
+ if (!GetDouble(&subkey, ¢roid->mean)) {
+ LOG(ERROR) << "corrupted tdigest centroid key, extract mean failed";
+ return rocksdb::Status::Corruption("corrupted tdigest centroid key");
+ }
+
+ if (rocksdb::Slice value_slice = value; // GetDouble needs a mutable
pointer of slice
+ !GetDouble(&value_slice, ¢roid->weight)) {
+ LOG(ERROR) << "corrupted tdigest centroid value, extract weight failed";
+ return rocksdb::Status::Corruption("corrupted tdigest centroid value");
+ }
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::appendBuffer(engine::Context& ctx,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+ const std::string& ns_key, const
std::vector<double>& inputs,
+ TDigestMetadata* metadata) {
+ // must guard by lock
+ auto buffer_key = internalBufferKey(ns_key, *metadata);
+ std::string buffer_value;
+ if (auto status = storage_->Get(ctx, ctx.GetReadOptions(), cf_handle_,
buffer_key, &buffer_value);
+ !status.ok() && !status.IsNotFound()) {
+ return status;
+ }
+
+ for (auto item : inputs) {
+ PutDouble(&buffer_value, item);
+ }
+
+ if (auto status = batch->Put(cf_handle_, buffer_key, buffer_value);
!status.ok()) {
+ return status;
+ }
+
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status TDigest::dumpCentroidsAndBuffer(engine::Context& ctx, const
std::string& ns_key,
+ const TDigestMetadata&
metadata, std::vector<Centroid>* centroids,
+ std::vector<double>* buffer,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>* clean_after_dump_batch) {
+ if (buffer != nullptr) {
+ buffer->clear();
+ buffer->reserve(metadata.unmerged_nodes);
+ auto buffer_key = internalBufferKey(ns_key, metadata);
+ std::string buffer_value;
+ auto status = storage_->Get(ctx, ctx.GetReadOptions(), cf_handle_,
buffer_key, &buffer_value);
+ if (!status.ok() && !status.IsNotFound()) {
+ return status;
+ }
+
+ if (status.ok()) {
+ rocksdb::Slice buffer_slice = buffer_value;
+ for (uint64_t i = 0; i < metadata.unmerged_nodes; ++i) {
+ double tmp_value = std::numeric_limits<double>::quiet_NaN();
+ if (!GetDouble(&buffer_slice, &tmp_value)) {
+ LOG(ERROR) << "metadata has " << metadata.unmerged_nodes << "
records, but get " << i << " failed";
+ return rocksdb::Status::Corruption("corrupted tdigest buffer value");
+ }
+ buffer->emplace_back(tmp_value);
+ }
+ }
+
+ if (clean_after_dump_batch != nullptr) {
+ if (status = (*clean_after_dump_batch)->Delete(cf_handle_, buffer_key);
!status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ centroids->clear();
+ centroids->reserve(metadata.merged_nodes);
+
+ auto start_key = internalSegmentGuardPrefixKey(metadata, ns_key,
SegmentType::kCentroids);
+ auto guard_key = internalSegmentGuardPrefixKey(metadata, ns_key,
SegmentType::kGuardFlag);
+
+ rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
+ rocksdb::Slice upper_bound(guard_key);
+ read_options.iterate_upper_bound = &upper_bound;
+ rocksdb::Slice lower_bound(start_key);
+ read_options.iterate_lower_bound = &lower_bound;
+
+ auto iter = util::UniqueIterator(ctx, read_options, cf_handle_);
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ Centroid centroid;
+ if (auto status = decodeCentroidFromKeyValue(iter->key(), iter->value(),
¢roid); !status.ok()) {
+ return status;
+ }
+ centroids->emplace_back(centroid);
+ if (clean_after_dump_batch != nullptr) {
+ if (auto status = (*clean_after_dump_batch)->Delete(cf_handle_,
iter->key()); !status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ if (centroids->size() != metadata.merged_nodes) {
+ LOG(ERROR) << "metadata has " << metadata.merged_nodes << " merged nodes,
but got " << centroids->size();
+ return rocksdb::Status::Corruption("centroids count mismatch with
metadata");
+ }
+ return rocksdb::Status::OK();
+}
+
+rocksdb::Status
TDigest::applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+ const std::string& ns_key, const
TDigestMetadata& metadata,
+ const std::vector<Centroid>&
centroids) {
+ for (const auto& c : centroids) {
+ auto centroid_key = internalKeyFromCentroid(ns_key, metadata, c);
+ auto centroid_payload = internalValueFromCentroid(c);
+ if (auto status = batch->Put(cf_handle_, centroid_key, centroid_payload);
!status.ok()) {
+ return status;
+ }
+ }
+
+ return rocksdb::Status::OK();
+}
+
+std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata&
metadata, const std::string& ns_key,
+ SegmentType seg) const {
+ std::string prefix_key;
+ PutFixed8(&prefix_key, static_cast<uint8_t>(seg));
+ return InternalKey(ns_key, prefix_key, metadata.version,
storage_->IsSlotIdEncoded()).Encode();
+}
+} // namespace redis
diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h
new file mode 100644
index 00000000..c1621338
--- /dev/null
+++ b/src/types/redis_tdigest.h
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <rocksdb/db.h>
+#include <rocksdb/slice.h>
+#include <rocksdb/status.h>
+
+#include <vector>
+
+#include "storage/redis_db.h"
+#include "storage/redis_metadata.h"
+#include "storage/storage.h"
+#include "tdigest.h"
+
+namespace redis {
+struct CentroidWithKey {
+ Centroid centroid;
+ rocksdb::Slice key;
+};
+
+struct TDigestCreateOptions {
+ uint32_t compression;
+};
+
+struct TDigestQuantitleResult {
+ std::vector<double> quantiles;
+};
+
+class TDigest : public SubKeyScanner {
+ public:
+ using Slice = rocksdb::Slice;
+ explicit TDigest(engine::Storage* storage, const std::string& ns)
+ : SubKeyScanner(storage, ns),
cf_handle_(storage->GetCFHandle(ColumnFamilyID::PrimarySubkey)) {}
+ /**
+ * @brief Create a t-digest structure.
+ *
+ * @param ctx The context of the operation.
+ * @param digest_name The name of the t-digest.
+ * @param options The options of the t-digest.
+ * @param exists The output parameter to indicate whether the t-digest
already exists.
+ * @return rocksdb::Status
+ */
+ rocksdb::Status Create(engine::Context& ctx, const Slice& digest_name, const
TDigestCreateOptions& options,
+ bool* exists);
+ rocksdb::Status Add(engine::Context& ctx, const Slice& digest_name, const
std::vector<double>& inputs);
+ rocksdb::Status Quantile(engine::Context& ctx, const Slice& digest_name,
const std::vector<double>& qs,
+ TDigestQuantitleResult* result);
+
+ rocksdb::Status GetMetaData(engine::Context& context, const Slice&
digest_name, TDigestMetadata* metadata);
+
+ private:
+ enum class SegmentType : uint8_t { kBuffer = 0, kCentroids = 1, kGuardFlag =
0xFF };
+
+ rocksdb::ColumnFamilyHandle* cf_handle_;
+
+ rocksdb::Status appendBuffer(engine::Context& ctx,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+ const std::string& ns_key, const
std::vector<double>& inputs, TDigestMetadata* metadata);
+
+ rocksdb::Status dumpCentroids(engine::Context& ctx, const std::string&
ns_key, const TDigestMetadata& metadata,
+ std::vector<Centroid>* centroids) {
+ return dumpCentroidsAndBuffer(ctx, ns_key, metadata, centroids, nullptr,
nullptr);
+ }
+
+ /**
+ * @brief Dumps the centroids and buffer of the t-digest.
+ *
+ * This function reads the centroids and buffer from persistent storage and
removes them from the storage.
+ * @param ctx The context of the operation.
+ * @param ns_key The namespace key of the t-digest.
+ * @param metadata The metadata of the t-digest.
+ * @param centroids The output vector to store the centroids.
+ * @param buffer The output vector to store the buffer. If it is nullptr,
the buffer will not be read.
+   * @param clean_after_dump_batch The write batch to store the clean
operations. If it is nullptr, the clean operations will not be recorded.
+ * @return rocksdb::Status
+ */
+ rocksdb::Status dumpCentroidsAndBuffer(engine::Context& ctx, const
std::string& ns_key,
+ const TDigestMetadata& metadata,
std::vector<Centroid>* centroids,
+ std::vector<double>* buffer,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>* clean_after_dump_batch);
+ rocksdb::Status
applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, const
std::string& ns_key,
+ const TDigestMetadata& metadata, const
std::vector<Centroid>& centroids);
+
+ std::string internalSegmentGuardPrefixKey(const TDigestMetadata& metadata,
const std::string& ns_key,
+ SegmentType seg) const;
+
+ rocksdb::Status mergeCurrentBuffer(engine::Context& ctx, const std::string&
ns_key,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, TDigestMetadata* metadata,
+ const std::vector<double>*
additional_buffer = nullptr);
+ std::string internalBufferKey(const std::string& ns_key, const
TDigestMetadata& metadata) const;
+ std::string internalKeyFromCentroid(const std::string& ns_key, const
TDigestMetadata& metadata,
+ const Centroid& centroid) const;
+ static std::string internalValueFromCentroid(const Centroid& centroid);
+ rocksdb::Status decodeCentroidFromKeyValue(const rocksdb::Slice& key, const
rocksdb::Slice& value,
+ Centroid* centroid) const;
+};
+
+} // namespace redis
diff --git a/src/types/tdigest.cc b/src/types/tdigest.cc
new file mode 100644
index 00000000..43a5a798
--- /dev/null
+++ b/src/types/tdigest.cc
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+/*
+This implementation follows apache arrow.
+refer to
https://github.com/apache/arrow/blob/27bbd593625122a4a25d9471c8aaf5df54a6dcf9/cpp/src/arrow/util/tdigest.cc
+*/
+
+#include "tdigest.h"
+
+#include <fmt/format.h>
+#include <glog/logging.h>
+
+#include <algorithm>
+#include <iterator>
+#include <queue>
+
+#include "common/status.h"
+
+namespace {
+// scale function K1
+struct ScalerK1 {
+ explicit ScalerK1(uint32_t delta) : delta_norm(delta / (2.0 * M_PI)) {}
+
+ double K(double q) const { return delta_norm * std::asin(2 * q - 1); }
+ double Q(double k) const { return (std::sin(k / delta_norm) + 1) / 2; }
+
+ const double delta_norm;
+};
+} // namespace
+
+template <typename T = ScalerK1>
+class TDigestMerger : private T {
+ public:
+ using Status = rocksdb::Status;
+ explicit TDigestMerger(uint32_t delta) : T(delta) { Reset(0, nullptr); }
+
+ void Reset(double total_weight, std::vector<Centroid>* tdigest) {
+ total_weight_ = total_weight;
+ tdigest_ = tdigest;
+ if (tdigest_) {
+ tdigest_->resize(0);
+ }
+ weight_so_far_ = 0;
+ weight_limit_ = -1; // trigger first centroid merge
+ }
+
+ // merge one centroid from a sorted centroid stream
+ void Add(const Centroid& centroid) {
+ auto& td = *tdigest_;
+ const double weight = weight_so_far_ + centroid.weight;
+ if (weight <= weight_limit_) {
+ td.back().Merge(centroid);
+ } else {
+ const double quantile = weight_so_far_ / total_weight_;
+ // weight limit should be strictly increasing, until the last centroid
+ if (const double next_weight_limit = total_weight_ *
this->Q(this->K(quantile) + 1);
+ next_weight_limit <= weight_limit_) {
+ weight_limit_ = total_weight_;
+ } else {
+ weight_limit_ = next_weight_limit;
+ }
+ td.push_back(centroid); // should never exceed capacity and trigger
reallocation
+ }
+ weight_so_far_ = weight;
+ }
+
+ // validate k-size of a tdigest
+ Status Validate(const std::vector<Centroid>& tdigest, double total_weight)
const {
+ double q_prev = 0;
+ double k_prev = this->K(0);
+ for (const auto& i : tdigest) {
+ const double q = q_prev + i.weight / total_weight;
+ const double k = this->K(q);
+ if (i.weight != 1 && (k - k_prev) > 1.001) {
+ return Status::Corruption(fmt::format("oversized centroid: {}", k -
k_prev));
+ }
+ k_prev = k;
+ q_prev = q;
+ }
+ return Status::OK();
+ }
+
+ private:
+ double total_weight_; // total weight of this tdigest
+ double weight_so_far_; // accumulated weight till current bin
+ double weight_limit_; // max accumulated weight to move to next bin
+ std::vector<Centroid>* tdigest_;
+};
+
// Core t-digest state machine. Holds two pre-allocated centroid buffers
// ("ping-pong"): every merge streams the active buffer (plus incoming data)
// through TDigestMerger into the inactive one, then flips `current_`.
// Both buffers are reserved once in the constructor so merging never
// reallocates.
class TDigestImpl {
 public:
  using Status = rocksdb::Status;

  // `delta` is the compression parameter; values below 10 are clamped to 10.
  explicit TDigestImpl(uint32_t delta) : delta_(delta > 10 ? delta : 10), merger_(delta_) {
    tdigests_[0].reserve(delta_);
    tdigests_[1].reserve(delta_);
    Reset();
  }

  // Reset to an empty digest. min_/max_ use infinity sentinels so the first
  // observed sample always replaces them.
  void Reset() {
    tdigests_[0].resize(0);
    tdigests_[1].resize(0);
    current_ = 0;
    total_weight_ = 0;
    min_ = std::numeric_limits<double>::infinity();
    max_ = -std::numeric_limits<double>::infinity();
    merger_.Reset(0, nullptr);
  }

  // Reset from decoded state; `centroids` is expected to be sorted by mean.
  void Reset(const std::vector<Centroid>& centroids, double min, double max, double total_weight) {
    tdigests_[0] = centroids;
    tdigests_[1].resize(0);
    current_ = 0;
    total_weight_ = total_weight;
    min_ = min;
    max_ = max;
    merger_.Reset(0, nullptr);
  }

  // Sanity-check invariants: non-NaN sorted means, weights >= 1, consistent
  // total weight, un-grown buffers, and the k-size bound (via the merger).
  Status Validate() const {
    // check weight, centroid order
    double total_weight = 0;
    double prev_mean = std::numeric_limits<double>::lowest();
    for (const auto& centroid : tdigests_[current_]) {
      if (std::isnan(centroid.mean) || std::isnan(centroid.weight)) {
        return Status::Corruption("NAN found in tdigest");
      }
      if (centroid.mean < prev_mean) {
        return Status::Corruption("centroid mean decreases");
      }
      if (centroid.weight < 1) {
        return Status::Corruption("invalid centroid weight");
      }
      prev_mean = centroid.mean;
      total_weight += centroid.weight;
    }
    if (total_weight != total_weight_) {
      return Status::Corruption("tdigest total weight mismatch");
    }
    // check if buffer expanded
    if (tdigests_[0].capacity() > delta_ || tdigests_[1].capacity() > delta_) {
      return Status::Corruption("oversized tdigest buffer");
    }
    // check k-size
    return merger_.Validate(tdigests_[current_], total_weight_);
  }

  // Returns a copy of the active centroid buffer.
  std::vector<Centroid> Centroids() const { return tdigests_[current_]; }

  double Min() const { return min_; }

  double Max() const { return max_; }

  uint32_t Delta() const { return delta_; }

  // Merge other tdigests into this one: a k-way merge over the sorted
  // centroid streams (min-heap keyed by mean), re-compressed through merger_
  // into the inactive buffer.
  void Merge(const std::vector<const TDigestImpl*>& tdigest_impls) {
    // current and end iterator
    using CentroidIter = std::vector<Centroid>::const_iterator;
    using CentroidIterPair = std::pair<CentroidIter, CentroidIter>;
    // use a min-heap to find next minimal centroid from all tdigests
    auto centroid_gt = [](const CentroidIterPair& lhs, const CentroidIterPair& rhs) {
      return lhs.first->mean > rhs.first->mean;
    };
    using CentroidQueue = std::priority_queue<CentroidIterPair, std::vector<CentroidIterPair>, decltype(centroid_gt)>;

    // trivial dynamic memory allocated at runtime
    std::vector<CentroidIterPair> queue_buffer;
    queue_buffer.reserve(tdigest_impls.size() + 1);
    CentroidQueue queue(std::move(centroid_gt), std::move(queue_buffer));

    if (const auto& this_tdigest = tdigests_[current_]; !this_tdigest.empty()) {
      queue.emplace(this_tdigest.cbegin(), this_tdigest.cend());
    }
    for (const TDigestImpl* td : tdigest_impls) {
      const auto& other_tdigest = td->tdigests_[td->current_];
      if (other_tdigest.size() > 0) {
        queue.emplace(other_tdigest.cbegin(), other_tdigest.cend());
        // fold the other digest's aggregates into ours up front
        total_weight_ += td->total_weight_;
        min_ = std::min(min_, td->min_);
        max_ = std::max(max_, td->max_);
      }
    }

    merger_.Reset(total_weight_, &tdigests_[1 - current_]);
    CentroidIter current_iter;
    CentroidIter end_iter;
    // do k-way merge till one buffer left
    while (queue.size() > 1) {
      std::tie(current_iter, end_iter) = queue.top();
      merger_.Add(*current_iter);
      queue.pop();
      if (++current_iter != end_iter) {
        queue.emplace(current_iter, end_iter);
      }
    }
    // merge last buffer
    if (!queue.empty()) {
      std::tie(current_iter, end_iter) = queue.top();
      while (current_iter != end_iter) {
        merger_.Add(*current_iter++);
      }
    }
    merger_.Reset(0, nullptr);

    // flip the ping-pong buffers: the freshly merged one becomes active
    current_ = 1 - current_;
  }

  // Merge raw samples (each a unit-weight centroid) with the current tdigest.
  // The input vector is taken by value because it is sorted in place.
  void MergeInput(std::vector<double> input) {
    if (tdigests_[current_].empty() && !input.empty()) {
      min_ = input.front();
      max_ = input.front();
    }
    total_weight_ += static_cast<double>(input.size());

    std::sort(input.begin(), input.end());
    if (input.empty()) {
      return;
    }
    min_ = std::min(min_, input.front());
    max_ = std::max(max_, input.back());

    // pick next minimal centroid from input and tdigest, feed to merger
    merger_.Reset(total_weight_, &tdigests_[1 - current_]);
    const auto& td = tdigests_[current_];
    uint32_t tdigest_index = 0;
    uint32_t input_index = 0;
    while (tdigest_index < td.size() && input_index < input.size()) {
      if (td[tdigest_index].mean < input[input_index]) {
        merger_.Add(td[tdigest_index]);
        ++tdigest_index;
      } else {
        merger_.Add(Centroid{input[input_index], 1});
        ++input_index;
      }
    }
    // drain whichever stream is left
    while (tdigest_index < td.size()) {
      merger_.Add(td[tdigest_index]);
      ++tdigest_index;
    }
    while (input_index < input.size()) {
      merger_.Add(Centroid{input[input_index], 1});
      ++input_index;
    }
    merger_.Reset(0, nullptr);
    current_ = 1 - current_;
  }

  // Estimate the q-quantile: locate the centroid covering rank
  // index = q * total_weight_ and linearly interpolate between neighboring
  // centroid means (or min_/max_ at the tails). Returns NAN for invalid q or
  // an empty digest.
  double Quantile(double q) const {
    const auto& td = tdigests_[current_];

    if (q < 0 || q > 1 || td.empty()) {
      return NAN;
    }

    const double index = q * total_weight_;
    // tail cases: clamp to the observed extremes
    if (index <= 1) {
      return min_;
    } else if (index >= total_weight_ - 1) {
      return max_;
    }

    // find the centroid that contains the index
    uint32_t ci = 0;
    double weight_sum = 0;
    for (; ci < td.size(); ++ci) {
      weight_sum += td[ci].weight;
      if (index <= weight_sum) {
        break;
      }
    }
    DCHECK_LT(ci, td.size());

    // deviation of index from the centroid center
    double diff = index + td[ci].weight / 2 - weight_sum;

    // the index happens to fall in a unit-weight centroid: its mean is exact
    if (td[ci].weight == 1 && std::abs(diff) < 0.5) {
      return td[ci].mean;
    }

    // find adjacent centroids for interpolation
    uint32_t ci_left = ci;
    uint32_t ci_right = ci;
    if (diff > 0) {
      if (ci_right == td.size() - 1) {
        // index larger than center of last bin: interpolate towards max_
        DCHECK_EQ(weight_sum, total_weight_);
        const Centroid* c = &td[ci_right];
        DCHECK_GE(c->weight, 2);
        return Lerp(c->mean, max_, diff / (c->weight / 2));
      }
      ++ci_right;
    } else {
      if (ci_left == 0) {
        // index smaller than center of first bin: interpolate from min_
        const Centroid* c = &td[0];
        DCHECK_GE(c->weight, 2);
        return Lerp(min_, c->mean, index / (c->weight / 2));
      }
      --ci_left;
      // re-base diff on the midpoint distance between the two neighbors
      diff += td[ci_left].weight / 2 + td[ci_right].weight / 2;
    }

    // interpolate from adjacent centroids
    diff /= (td[ci_left].weight / 2 + td[ci_right].weight / 2);
    return Lerp(td[ci_left].mean, td[ci_right].mean, diff);
  }

  // Weighted mean of all centroids; NAN when the digest is empty.
  double Mean() const {
    double sum = 0;
    for (const auto& centroid : tdigests_[current_]) {
      sum += centroid.mean * centroid.weight;
    }
    return total_weight_ == 0 ? NAN : sum / total_weight_;
  }

  double TotalWeight() const { return total_weight_; }

 private:
  // must be declared before merger_, see constructor initialization list
  const uint32_t delta_;

  TDigestMerger<> merger_;
  double total_weight_;  // sum of all centroid weights
  double min_;           // smallest sample ever added
  double max_;           // largest sample ever added

  // ping-pong buffer holds two tdigests, size = 2 * delta * sizeof(Centroid)
  std::vector<Centroid> tdigests_[2];
  // index of active tdigest buffer, 0 or 1
  int current_;
};
+
// Thin facade over TDigestImpl exposing the operations the storage layer
// needs: merge digests, add samples, and dump/restore centroid snapshots.
class TDigest {
 public:
  explicit TDigest(uint64_t delta);

  // Non-copyable; movable via the defaulted move constructor.
  // NOTE(review): move-assignment is implicitly unavailable because
  // TDigestImpl holds a const member (delta_) — confirm this is intended.
  TDigest(const TDigest&) = delete;
  TDigest& operator=(const TDigest&) = delete;
  TDigest(TDigest&& rhs) = default;
  ~TDigest() = default;

  // Merge the centroids of all `others` into this digest.
  void Merge(const std::vector<TDigest>& others);
  // Add raw samples; each sample becomes a unit-weight centroid.
  void Add(const std::vector<double>& items);
  // Replace state from a previously dumped centroid snapshot.
  void Reset(const CentroidsWithDelta& centroid_list);
  // Clear back to an empty digest.
  void Reset();
  // Snapshot centroids plus delta/min/max/total-weight for storage encoding.
  CentroidsWithDelta DumpCentroids() const;

 private:
  TDigestImpl impl_;
};
+
+TDigest::TDigest(uint64_t delta) : impl_(TDigestImpl(delta)) { Reset({}); }
+
+void TDigest::Merge(const std::vector<TDigest>& others) {
+ if (others.empty()) {
+ return;
+ }
+
+ std::vector<const TDigestImpl*> impls;
+ impls.reserve(others.size());
+
+ std::transform(others.cbegin(), others.cend(), std::back_inserter(impls),
[](const TDigest& i) { return &i.impl_; });
+
+ impl_.Merge(impls);
+}
+
// Replace the digest's entire state with a previously dumped snapshot
// (centroids, min, max and total weight).
void TDigest::Reset(const CentroidsWithDelta& centroids_list) {
  impl_.Reset(centroids_list.centroids, centroids_list.min, centroids_list.max, centroids_list.total_weight);
}
+
+void TDigest::Reset() { impl_.Reset(); }
+
+CentroidsWithDelta TDigest::DumpCentroids() const {
+ auto centroids = impl_.Centroids();
+ return {
+ .centroids = std::move(centroids),
+ .delta = impl_.Delta(),
+ .min = impl_.Min(),
+ .max = impl_.Max(),
+ .total_weight = impl_.TotalWeight(),
+ };
+}
+
// Add raw samples; forwarded to TDigestImpl::MergeInput, which takes a copy
// it can sort in place.
void TDigest::Add(const std::vector<double>& items) { impl_.MergeInput(items); }
+
+StatusOr<CentroidsWithDelta> TDigestMerge(const
std::vector<CentroidsWithDelta>& centroids_list) {
+ if (centroids_list.empty()) {
+ return Status{Status::InvalidArgument, "centroids_list is empty"};
+ }
+ if (centroids_list.size() == 1) {
+ return centroids_list.front();
+ }
+
+ TDigest digest{centroids_list.front().delta};
+ digest.Reset(centroids_list.front());
+
+ std::vector<TDigest> others;
+ others.reserve(centroids_list.size() - 1);
+
+ for (size_t i = 1; i < centroids_list.size(); ++i) {
+ TDigest d{centroids_list[i].delta};
+ digest.Reset(centroids_list[i]);
+ others.emplace_back(std::move(d));
+ }
+
+ digest.Merge(others);
+
+ return digest.DumpCentroids();
+}
// Merge a buffer of raw samples into an existing dumped tdigest and return
// the re-compressed snapshot.
StatusOr<CentroidsWithDelta> TDigestMerge(const std::vector<double>& buffer, const CentroidsWithDelta& centroid_list) {
  TDigest digest{centroid_list.delta};
  digest.Reset(centroid_list);
  digest.Add(buffer);
  return digest.DumpCentroids();
}
diff --git a/src/types/tdigest.h b/src/types/tdigest.h
new file mode 100644
index 00000000..6e8b0382
--- /dev/null
+++ b/src/types/tdigest.h
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <fmt/format.h>
+
+#include <vector>
+
+#include "common/status.h"
+
// A centroid is a weighted point: the mean of `weight` merged samples.
struct Centroid {
  double mean;
  double weight = 1.0;

  // Merge another centroid in: weights add, and the mean moves toward
  // centroid.mean in proportion to its share of the combined weight.
  void Merge(const Centroid& centroid) {
    weight += centroid.weight;
    mean += (centroid.mean - mean) * centroid.weight / weight;
  }

  // Human-readable form for logs and error messages.
  std::string ToString() const { return fmt::format("centroid<mean: {}, weight: {}>", mean, weight); }

  explicit Centroid() = default;
  explicit Centroid(double mean, double weight) : mean(mean), weight(weight) {}
};
+
+struct CentroidsWithDelta {
+ std::vector<Centroid> centroids;
+ uint64_t delta;
+ double min;
+ double max;
+ double total_weight;
+};
+
+StatusOr<CentroidsWithDelta> TDigestMerge(const
std::vector<CentroidsWithDelta>& centroids_list);
+StatusOr<CentroidsWithDelta> TDigestMerge(const std::vector<double>& buffer,
const CentroidsWithDelta& centroid_list);
+
+/**
+
+TD should look like the following:
+class TDSample {
+ public:
+ struct Iterator {
+ Iterator* Clone() const;
+ bool Next();
+ bool Valid() const;
+ StatusOr<Centroid> GetCentroid() const;
+ };
+ Iterator* Begin();
+ Iterator* End();
+ double TotalWeight();
+ double Min() const;
+ double Max() const;
+};
+
+**/
+
// a numerically stable lerp is unbelievably complex
// but we are *approximating* the quantile, so let's keep it simple
// reference:
// https://github.com/apache/arrow/blob/27bbd593625122a4a25d9471c8aaf5df54a6dcf9/cpp/src/arrow/util/tdigest.cc#L38
// Linear interpolation between a and b at parameter t; constexpr so it can
// also be evaluated at compile time.
static inline constexpr double Lerp(double a, double b, double t) { return a + t * (b - a); }
+
// Generic quantile estimation over any tdigest-like sample type TD (see the
// interface sketch above). Mirrors TDigestImpl::Quantile, but walks the
// centroids through TD's iterator API instead of a vector.
template <typename TD>
inline StatusOr<double> TDigestQuantile(TD&& td, double q) {
  if (q < 0 || q > 1 || td.Size() == 0) {
    return Status{Status::InvalidArgument, "invalid quantile or empty tdigest"};
  }

  const double index = q * td.TotalWeight();
  // tail cases: clamp to the observed extremes
  if (index <= 1) {
    return td.Min();
  } else if (index >= td.TotalWeight() - 1) {
    return td.Max();
  }

  // find the centroid that contains the index
  double weight_sum = 0;
  auto iter = td.Begin();
  for (; iter->Valid(); iter->Next()) {
    weight_sum += GET_OR_RET(iter->GetCentroid()).weight;
    if (index <= weight_sum) {
      break;
    }
  }

  // since index is in (1, total_weight - 1), iter should be valid
  if (!iter->Valid()) {
    return Status{Status::InvalidArgument, "invalid iterator during decoding tdigest centroid"};
  }

  auto centroid = GET_OR_RET(iter->GetCentroid());

  // deviation of index from the centroid center
  double diff = index + centroid.weight / 2 - weight_sum;

  // the index happens to fall in a unit-weight centroid: its mean is exact
  if (centroid.weight == 1 && std::abs(diff) < 0.5) {
    return centroid.mean;
  }

  // find adjacent centroids for interpolation
  // NOTE(review): the comparisons `ci_right == td.End()` and
  // `ci_left == td.Begin()` compare Iterator* handles returned by
  // Clone()/End()/Begin(); this is only correct if those calls yield
  // comparable handles for the same position — verify against the concrete
  // TD iterator implementation.
  auto ci_left = iter->Clone();
  auto ci_right = iter->Clone();
  if (diff > 0) {
    if (ci_right == td.End()) {
      // index larger than center of last bin: interpolate towards Max()
      auto c = GET_OR_RET(ci_left->GetCentroid());
      DCHECK_GE(c.weight, 2);
      return Lerp(c.mean, td.Max(), diff / (c.weight / 2));
    }
    ci_right->Next();
  } else {
    if (ci_left == td.Begin()) {
      // index smaller than center of first bin: interpolate from Min()
      auto c = GET_OR_RET(ci_left->GetCentroid());
      DCHECK_GE(c.weight, 2);
      return Lerp(td.Min(), c.mean, index / (c.weight / 2));
    }
    ci_left->Prev();
    // re-base diff on the midpoint distance between the two neighbors
    auto lc = GET_OR_RET(ci_left->GetCentroid());
    auto rc = GET_OR_RET(ci_right->GetCentroid());
    diff += lc.weight / 2 + rc.weight / 2;
  }

  auto lc = GET_OR_RET(ci_left->GetCentroid());
  auto rc = GET_OR_RET(ci_right->GetCentroid());

  // interpolate from adjacent centroids
  diff /= (lc.weight / 2 + rc.weight / 2);
  return Lerp(lc.mean, rc.mean, diff);
}
diff --git a/tests/cppunit/types/tdigest_test.cc
b/tests/cppunit/types/tdigest_test.cc
new file mode 100644
index 00000000..849c27f6
--- /dev/null
+++ b/tests/cppunit/types/tdigest_test.cc
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "types/tdigest.h"
+
+#include <fmt/format.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <random>
+#include <range/v3/algorithm/shuffle.hpp>
+#include <range/v3/range.hpp>
+#include <range/v3/view/chunk.hpp>
+#include <range/v3/view/iota.hpp>
+#include <range/v3/view/join.hpp>
+#include <range/v3/view/transform.hpp>
+#include <string>
+#include <vector>
+
+#include "storage/redis_metadata.h"
+#include "test_base.h"
+#include "time_util.h"
+#include "types/redis_tdigest.h"
+
+namespace {
+constexpr std::random_device::result_type kSeed = 14863; // fixed seed for
reproducibility
+
// Reference quantile computation: for each q, linearly interpolate between
// the two adjacent sorted samples around rank q * n (clamped to the extremes
// at the tails). Used as ground truth for the tdigest estimates.
std::vector<double> QuantileOf(const std::vector<double> &samples, const std::vector<double> &qs) {
  std::vector<double> ordered = samples;
  std::sort(ordered.begin(), ordered.end());

  std::vector<double> out;
  out.reserve(qs.size());
  for (const double q : qs) {
    const double rank = q * static_cast<double>(ordered.size());
    if (rank <= 1) {
      out.push_back(ordered.front());
    } else if (rank >= static_cast<double>(ordered.size() - 1)) {
      out.push_back(ordered.back());
    } else {
      const int lo = static_cast<int>(rank);
      const double frac = rank - lo;
      out.push_back(ordered[lo] + (ordered[lo + 1] - ordered[lo]) * frac);
    }
  }
  return out;
}
+
// Reference quantile bounds: for each q, the pair of adjacent sorted samples
// bracketing rank q * n (a degenerate interval at the extremes).
std::vector<std::pair<double, double>> QuantileIntervalOf(const std::vector<double> &samples,
                                                          const std::vector<double> &qs) {
  std::vector<double> ordered = samples;
  std::sort(ordered.begin(), ordered.end());

  std::vector<std::pair<double, double>> intervals;
  intervals.reserve(qs.size());
  for (const double q : qs) {
    const double rank = q * static_cast<double>(ordered.size());
    if (rank <= 1) {
      // at or below the first sample: collapse to the minimum
      intervals.emplace_back(ordered.front(), ordered.front());
    } else if (rank >= static_cast<double>(ordered.size() - 1)) {
      // at or beyond the last gap: collapse to the maximum
      intervals.emplace_back(ordered.back(), ordered.back());
    } else {
      const int lo = static_cast<int>(rank);
      intervals.emplace_back(ordered[lo], ordered[lo + 1]);
    }
  }
  return intervals;
}
+
// Produce `count` evenly spaced values: from + i * (to - from) / count for
// i in [0, count), i.e. the half-open range [from, to).
std::vector<double> GenerateSamples(int count, double from, double to) {
  std::vector<double> out;
  out.reserve(count);
  for (int idx = 0; idx < count; idx++) {
    out.push_back(from + static_cast<double>(idx) * (to - from) / static_cast<double>(count));
  }
  return out;
}
+
// Quantile grid i/count for i in [1, count]; optionally prepend q=0 and/or
// append q=1 to exercise the tails.
std::vector<double> GenerateQuantiles(int count, bool with_head = false, bool with_tail = false) {
  std::vector<double> grid;
  grid.reserve(count);
  for (int step = 1; step <= count; step++) {
    grid.push_back(static_cast<double>(step) / static_cast<double>(count));
  }
  if (with_head) {
    grid.insert(grid.begin(), 0);
  }
  if (with_tail) {
    grid.push_back(1);
  }
  return grid;
}
+
+} // namespace
+
// Test fixture: wires a redis::TDigest command handler to the TestBase
// storage under namespace "tdigest_ns".
class RedisTDigestTest : public TestBase {
 protected:
  RedisTDigestTest() : name_("tdigest_test") {
    tdigest_ = std::make_unique<redis::TDigest>(storage_.get(), "tdigest_ns");
  }

  // NOTE(review): name_ is not referenced by the tests in this file.
  std::string name_;
  std::unique_ptr<redis::TDigest> tdigest_;
};
+
// Merging (mean=2, weight=3) with (mean=3, weight=4) must add the weights (7)
// and shift the mean to the weighted average 2 + (3 - 2) * 4 / 7 ~= 2.57.
TEST_F(RedisTDigestTest, CentroidTest) {
  Centroid c1{
      2.,
      3.,
  };
  Centroid c2{
      3.,
      4.,
  };

  c1.Merge(c2);

  EXPECT_NEAR(c1.weight, 7., 0.01);
  EXPECT_NEAR(c1.mean, 2.57, 0.01);
}
+
// Create semantics: the first creation succeeds; re-creating the same key
// flags `exists` and fails with InvalidArgument; the stored metadata must
// carry the requested compression (100).
TEST_F(RedisTDigestTest, Create) {
  std::string test_digest_name = "test_digest_create" + std::to_string(util::GetTimeStampMS());
  bool exists = false;
  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_FALSE(exists);
  ASSERT_TRUE(status.ok());

  // second creation of the same key must report it already exists
  status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_TRUE(exists);
  ASSERT_TRUE(status.IsInvalidArgument());

  // read back the metadata and verify the persisted compression parameter
  auto ns_key = tdigest_->AppendNamespacePrefix(test_digest_name);
  TDigestMetadata metadata;
  auto get_status = tdigest_->GetMetaData(*ctx_, ns_key, &metadata);
  ASSERT_TRUE(get_status.ok()) << get_status.ToString();
  ASSERT_EQ(metadata.compression, 100) << metadata.compression;
}
+
// Quantiles over the integers 1..100: p50 ~= 50.5 and p90 ~= 90.5 by linear
// interpolation; p99 (rank 99 = total_weight - 1) hits the tail case and
// returns the maximum, 100.
TEST_F(RedisTDigestTest, Quantile) {
  std::string test_digest_name = "test_digest_quantile" + std::to_string(util::GetTimeStampMS());

  bool exists = false;
  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_FALSE(exists);
  ASSERT_TRUE(status.ok());
  // samples are the integers 1..100 (iota upper bound is exclusive)
  std::vector<double> samples = ranges::views::iota(1, 101) | ranges::views::transform([](int i) { return i; }) |
                                ranges::to<std::vector<double>>();

  status = tdigest_->Add(*ctx_, test_digest_name, samples);
  ASSERT_TRUE(status.ok()) << status.ToString();

  std::vector<double> qs = {0.5, 0.9, 0.99};
  redis::TDigestQuantitleResult result;
  status = tdigest_->Quantile(*ctx_, test_digest_name, qs, &result);
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_EQ(result.quantiles.size(), qs.size());
  EXPECT_NEAR(result.quantiles[0], 50.5, 0.01);
  EXPECT_NEAR(result.quantiles[1], 90.5, 0.01);
  EXPECT_NEAR(result.quantiles[2], 100, 0.01);
}
+
// 10000 evenly spaced samples in [-100, 100), checked at 144 quantiles: each
// tdigest estimate must stay within one inter-sample step of the exact
// rank-based quantile.
TEST_F(RedisTDigestTest, PlentyQuantile_10000_144) {
  std::string test_digest_name = "test_digest_quantile" + std::to_string(util::GetTimeStampMS());
  bool exists = false;
  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_FALSE(exists);
  ASSERT_TRUE(status.ok());

  int sample_count = 10000;
  int quantile_count = 144;
  double from = -100;
  double to = 100;
  // allowed error: one inter-sample step
  auto error_double = (to - from) / sample_count;
  // NOTE(review): the -100/100 literals duplicate `from`/`to`; consider
  // passing the variables so the bounds live in one place.
  auto samples = GenerateSamples(sample_count, -100, 100);
  status = tdigest_->Add(*ctx_, test_digest_name, samples);
  ASSERT_TRUE(status.ok()) << status.ToString();

  auto qs = GenerateQuantiles(quantile_count);
  // exact quantiles as ground truth
  auto result = QuantileOf(samples, qs);

  redis::TDigestQuantitleResult tdigest_result;
  status = tdigest_->Quantile(*ctx_, test_digest_name, qs, &tdigest_result);
  ASSERT_TRUE(status.ok()) << status.ToString();

  for (int i = 0; i < quantile_count; i++) {
    EXPECT_NEAR(tdigest_result.quantiles[i], result[i], error_double) << "quantile is: " << qs[i];
  }
}
+
// Incremental adds: shuffle the samples, feed them in several chunks, and
// require each estimated quantile to land inside the bracketing interval of
// the exact quantile.
TEST_F(RedisTDigestTest, Add_2_times) {
  std::string test_digest_name = "test_digest_quantile" + std::to_string(util::GetTimeStampMS());

  bool exists = false;
  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_FALSE(exists);
  ASSERT_TRUE(status.ok());

  int sample_count = 17;
  int quantile_count = 7;
  auto samples = GenerateSamples(sample_count, -100, 100);
  auto qs = GenerateQuantiles(quantile_count);
  // expected intervals are computed over the sample multiset, so shuffling
  // afterwards does not invalidate them
  auto expect_result = QuantileIntervalOf(samples, qs);
  std::shuffle(samples.begin(), samples.end(), std::mt19937(kSeed));

  // chunk size 17 / 4 = 4 actually yields 5 chunks (4+4+4+4+1);
  // every chunk is added, so all 17 samples reach the digest
  int group_count = 4;
  auto samples_sub_group =
      samples | ranges::views::chunk(sample_count / group_count) | ranges::to<std::vector<std::vector<double>>>();

  for (const auto &s : samples_sub_group) {
    status = tdigest_->Add(*ctx_, test_digest_name, s);
    ASSERT_TRUE(status.ok()) << status.ToString();
  }

  redis::TDigestQuantitleResult tdigest_result;
  status = tdigest_->Quantile(*ctx_, test_digest_name, qs, &tdigest_result);
  ASSERT_TRUE(status.ok()) << status.ToString();

  for (int i = 0; i < quantile_count; i++) {
    auto &[expect_down, expect_upper] = expect_result[i];
    auto got = tdigest_result.quantiles[i];
    EXPECT_GE(got, expect_down) << fmt::format("quantile is {}, should in interval [{}, {}]", qs[i], expect_down,
                                               expect_upper);
    EXPECT_LE(got, expect_upper) << fmt::format("quantile is {}, should in interval [{}, {}]", qs[i], expect_down,
                                                expect_upper);
  }
}