This is an automated email from the ASF dual-hosted git repository.
leerho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datasketches-rust.git
The following commit(s) were added to refs/heads/main by this push:
new a21d7f6 feat: implement T-Digest (#23)
a21d7f6 is described below
commit a21d7f671d2956823c687ded55b5ad00ed3b890b
Author: tison <[email protected]>
AuthorDate: Sat Dec 20 03:32:26 2025 +0800
feat: implement T-Digest (#23)
* feat: implement T-Digest
Signed-off-by: tison <[email protected]>
* impl merge and compress
Signed-off-by: tison <[email protected]>
* impl get_rank
Signed-off-by: tison <[email protected]>
* impl merge and add tests
Signed-off-by: tison <[email protected]>
* demo iter
Signed-off-by: tison <[email protected]>
* impl ser
Signed-off-by: tison <[email protected]>
* impl de
Signed-off-by: tison <[email protected]>
* fine tune deserialize tags
Signed-off-by: tison <[email protected]>
* define code in one place
Signed-off-by: tison <[email protected]>
* centralize compare logics
Signed-off-by: tison <[email protected]>
* finish serde
Signed-off-by: tison <[email protected]>
* enable freeze TDigestMut
Signed-off-by: tison <[email protected]>
* add serde compat test files
Signed-off-by: tison <[email protected]>
* support deserialize_compat
Signed-off-by: tison <[email protected]>
* impl cdf and pmf
Signed-off-by: tison <[email protected]>
* fine tune docs
Signed-off-by: tison <[email protected]>
* naming and let to do the reserve
Signed-off-by: tison <[email protected]>
* further tidy
Signed-off-by: tison <[email protected]>
* best effort avoid NaN
Signed-off-by: tison <[email protected]>
* fixup! best effort avoid NaN
Signed-off-by: tison <[email protected]>
* concrete tag
Signed-off-by: tison <[email protected]>
* filter invalid inputs
Signed-off-by: tison <[email protected]>
* weight nonzero and should not overflow
Signed-off-by: tison <[email protected]>
* other_mean - self_mean may produce inf
Signed-off-by: tison <[email protected]>
* no need for checking in sk files now
Signed-off-by: tison <[email protected]>
* reuse test data loading logics
Signed-off-by: tison <[email protected]>
---------
Signed-off-by: tison <[email protected]>
---
Cargo.lock | 131 +++
Cargo.toml | 4 +
src/lib.rs | 1 +
src/tdigest/mod.rs | 55 +
src/{lib.rs => tdigest/serialization.rs} | 28 +-
src/tdigest/sketch.rs | 1170 +++++++++++++++++++++
tests/common.rs | 52 +
tests/hll_serialization_test.rs | 42 +-
tests/tdigest_serialization_test.rs | 189 ++++
tests/tdigest_test.rs | 229 ++++
tests/test_data/tdigest_ref_k100_n10000_double.sk | Bin 0 -> 976 bytes
tests/test_data/tdigest_ref_k100_n10000_float.sk | Bin 0 -> 502 bytes
12 files changed, 1852 insertions(+), 49 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 1a5fdbd..a2f2c69 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,15 +2,146 @@
# It is not intended for manual editing.
version = 4
+[[package]]
+name = "aho-corasick"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "autocfg"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
+
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
[[package]]
name = "datasketches"
version = "0.1.0"
dependencies = [
+ "byteorder",
+ "googletest",
"mur3",
]
+[[package]]
+name = "googletest"
+version = "0.14.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06597b7d02ee58b9a37f522785ac15b9e18c6b178747c4439a6c03fbb35ea753"
+dependencies = [
+ "googletest_macro",
+ "num-traits",
+ "regex",
+ "rustversion",
+]
+
+[[package]]
+name = "googletest_macro"
+version = "0.14.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c31d9f07c9c19b855faebf71637be3b43f8e13a518aece5d61a3beee7710b4ef"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "memchr"
+version = "2.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
+
[[package]]
name = "mur3"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97af489e1e21b68de4c390ecca6703318bc1aa16e9733bcb62c089b73c6fbb1b"
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.103"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "regex"
+version = "1.12.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
+
+[[package]]
+name = "rustversion"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
+
+[[package]]
+name = "syn"
+version = "2.0.111"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
diff --git a/Cargo.toml b/Cargo.toml
index ad4eeb6..90961f0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,8 +35,12 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
+byteorder = { version = "1.5.0" }
mur3 = { version = "0.1.0" }
+[dev-dependencies]
+googletest = { version = "0.14.2" }
+
[lints.rust]
unknown_lints = "deny"
unsafe_code = "deny"
diff --git a/src/lib.rs b/src/lib.rs
index 7e8afd2..d49faa0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -32,3 +32,4 @@ compile_error!("datasketches does not support big-endian targets");
pub mod error;
pub mod hll;
+pub mod tdigest;
diff --git a/src/tdigest/mod.rs b/src/tdigest/mod.rs
new file mode 100644
index 0000000..ad9ca42
--- /dev/null
+++ b/src/tdigest/mod.rs
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! T-Digest implementation for estimating quantiles and ranks.
+//!
+//! The implementation in this library is based on the MergingDigest described in
+//! [Computing Extremely Accurate Quantiles Using t-Digests][paper] by Ted Dunning and Otmar Ertl.
+//!
+//! The implementation in this library has a few differences from the reference implementation
+//! associated with that paper:
+//!
+//! * Merge does not modify the input
+//! * Deserialization is similar to other sketches in this library, although reading the reference
+//! implementation format is supported
+//!
+//! Unlike all other algorithms in the library, t-digest is empirical and has no mathematical
+//! basis for estimating its error, and its results are dependent on the input data. However,
+//! for many common data distributions, it can produce excellent results. t-digest also operates
+//! only on numeric data and, unlike the quantiles family algorithms in the library which return
+//! quantile approximations from the input domain, t-digest interpolates values and will hold and
+//! return data points not seen in the input.
+//!
+//! The closest alternative to t-digest in this library is the REQ sketch. It prioritizes one chosen
+//! side of the rank domain: either low rank accuracy or high rank accuracy. t-digest (in this
+//! implementation) prioritizes both ends of the rank domain and has lower accuracy towards the
+//! middle of the rank domain (median).
+//!
+//! Measurements show that t-digest is slightly biased (tends to underestimate low ranks and
+//! overestimate high ranks), while still doing very well close to the extremes. The effect seems
+//! to be more pronounced with more input values.
+//!
+//! For more information on the performance characteristics, see the
+//! [Datasketches page on t-digest](https://datasketches.apache.org/docs/tdigest/tdigest.html).
+//!
+//! [paper]: https://arxiv.org/abs/1902.04023
+
+mod serialization;
+
+mod sketch;
+pub use self::sketch::TDigest;
+pub use self::sketch::TDigestMut;
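
For illustration, a minimal usage sketch of the API exported here (a hypothetical
snippet mirroring the tests added later in this commit, not part of the patch):

    use datasketches::tdigest::TDigestMut;

    // feed a stream of values into a mutable digest (default k = 200)
    let mut td = TDigestMut::default();
    for i in 0..10_000 {
        td.update(i as f64);
    }

    // query approximate rank and quantile
    let median = td.quantile(0.5).unwrap(); // close to 5_000.0
    let rank = td.rank(2_500.0).unwrap();   // close to 0.25

    // freeze into an immutable, query-only digest
    let frozen = td.freeze();
    assert_eq!(frozen.total_weight(), 10_000);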
diff --git a/src/lib.rs b/src/tdigest/serialization.rs
similarity index 54%
copy from src/lib.rs
copy to src/tdigest/serialization.rs
index 7e8afd2..e5b9788 100644
--- a/src/lib.rs
+++ b/src/tdigest/serialization.rs
@@ -15,20 +15,14 @@
// specific language governing permissions and limitations
// under the License.
-//! # Apache® DataSketches™ Core Rust Library Component
-//!
-//! The Sketching Core Library provides a range of stochastic streaming algorithms and closely
-//! related Rust technologies that are particularly useful when integrating this technology into
-//! systems that must deal with massive data.
-//!
-//! This library is divided into modules that constitute distinct groups of functionality.
-
-#![cfg_attr(docsrs, feature(doc_cfg))]
-#![deny(missing_docs)]
-
-// See https://github.com/apache/datasketches-rust/issues/28 for more information.
-#[cfg(target_endian = "big")]
-compile_error!("datasketches does not support big-endian targets");
-
-pub mod error;
-pub mod hll;
+pub(super) const PREAMBLE_LONGS_EMPTY_OR_SINGLE: u8 = 1;
+pub(super) const PREAMBLE_LONGS_MULTIPLE: u8 = 2;
+pub(super) const SERIAL_VERSION: u8 = 1;
+pub(super) const TDIGEST_FAMILY_ID: u8 = 20;
+pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 0;
+pub(super) const FLAGS_IS_SINGLE_VALUE: u8 = 1 << 1;
+pub(super) const FLAGS_REVERSE_MERGE: u8 = 1 << 2;
+/// the reference implementation format uses double (f64) precision
+pub(super) const COMPAT_DOUBLE: u32 = 1;
+/// the reference implementation format uses float (f32) precision
+pub(super) const COMPAT_FLOAT: u32 = 2;
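
For orientation, these constants describe the preamble written by serialize()
in sketch.rs below. A sketch of the 8-byte header of an empty digest with the
default k = 200, read off that implementation (illustrative, not normative):

    // byte 0: preamble longs -> 1 (PREAMBLE_LONGS_EMPTY_OR_SINGLE)
    // byte 1: serial version -> 1 (SERIAL_VERSION)
    // byte 2: family id      -> 20 (TDIGEST_FAMILY_ID)
    // bytes 3-4: k as u16 LE -> [200, 0]
    // byte 5: flags          -> 1 (FLAGS_IS_EMPTY)
    // bytes 6-7: unused      -> [0, 0]
    let header: [u8; 8] = [1, 1, 20, 200, 0, 1, 0, 0];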
diff --git a/src/tdigest/sketch.rs b/src/tdigest/sketch.rs
new file mode 100644
index 0000000..7f125d9
--- /dev/null
+++ b/src/tdigest/sketch.rs
@@ -0,0 +1,1170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::error::SerdeError;
+use crate::tdigest::serialization::*;
+use byteorder::{BE, LE, ReadBytesExt};
+use std::cmp::Ordering;
+use std::convert::identity;
+use std::io::Cursor;
+use std::num::NonZeroU64;
+
+/// The default value of K if one is not specified.
+const DEFAULT_K: u16 = 200;
+/// Multiplier for buffer size relative to centroids capacity.
+const BUFFER_MULTIPLIER: usize = 4;
+/// Default weight for single values.
+const DEFAULT_WEIGHT: NonZeroU64 = NonZeroU64::new(1).unwrap();
+
+/// T-Digest sketch for estimating quantiles and ranks.
+///
+/// See the [tdigest module level documentation](crate::tdigest) for more.
+#[derive(Debug, Clone)]
+pub struct TDigestMut {
+ k: u16,
+
+ reverse_merge: bool,
+ min: f64,
+ max: f64,
+
+ centroids: Vec<Centroid>,
+ centroids_weight: u64,
+ centroids_capacity: usize,
+ buffer: Vec<f64>,
+}
+
+impl Default for TDigestMut {
+ fn default() -> Self {
+ TDigestMut::new(DEFAULT_K)
+ }
+}
+
+impl TDigestMut {
+ /// Creates a tdigest instance with the given value of k.
+ ///
+ /// # Panics
+ ///
+ /// If k is less than 10
+ pub fn new(k: u16) -> Self {
+ Self::make(
+ k,
+ false,
+ f64::INFINITY,
+ f64::NEG_INFINITY,
+ vec![],
+ 0,
+ vec![],
+ )
+ }
+
+ // for deserialization
+ fn make(
+ k: u16,
+ reverse_merge: bool,
+ min: f64,
+ max: f64,
+ mut centroids: Vec<Centroid>,
+ centroids_weight: u64,
+ mut buffer: Vec<f64>,
+ ) -> Self {
+ assert!(k >= 10, "k must be at least 10");
+
+ let fudge = if k < 30 { 30 } else { 10 };
+ let centroids_capacity = (k as usize * 2) + fudge;
+
+ centroids.reserve(centroids_capacity);
+ buffer.reserve(centroids_capacity * BUFFER_MULTIPLIER);
+
+ TDigestMut {
+ k,
+ reverse_merge,
+ min,
+ max,
+ centroids,
+ centroids_weight,
+ centroids_capacity,
+ buffer,
+ }
+ }
+
+ /// Update this TDigest with the given value.
+ ///
+ /// [f64::NAN], [f64::INFINITY], and [f64::NEG_INFINITY] values are ignored.
+ pub fn update(&mut self, value: f64) {
+ if value.is_nan() || value.is_infinite() {
+ return;
+ }
+
+ if self.buffer.len() == self.centroids_capacity * BUFFER_MULTIPLIER {
+ self.compress();
+ }
+
+ self.buffer.push(value);
+ self.min = self.min.min(value);
+ self.max = self.max.max(value);
+ }
+
+ /// Returns parameter k (compression) that was used to configure this TDigest.
+ pub fn k(&self) -> u16 {
+ self.k
+ }
+
+ /// Returns true if TDigest has not seen any data.
+ pub fn is_empty(&self) -> bool {
+ self.centroids.is_empty() && self.buffer.is_empty()
+ }
+
+ /// Returns minimum value seen by TDigest; `None` if TDigest is empty.
+ pub fn min_value(&self) -> Option<f64> {
+ if self.is_empty() {
+ None
+ } else {
+ Some(self.min)
+ }
+ }
+
+ /// Returns maximum value seen by TDigest; `None` if TDigest is empty.
+ pub fn max_value(&self) -> Option<f64> {
+ if self.is_empty() {
+ None
+ } else {
+ Some(self.max)
+ }
+ }
+
+ /// Returns total weight.
+ pub fn total_weight(&self) -> u64 {
+ self.centroids_weight + self.buffer.len() as u64
+ }
+
+ /// Merge the given TDigest into this one.
+ pub fn merge(&mut self, other: &TDigestMut) {
+ if other.is_empty() {
+ return;
+ }
+
+ let mut tmp = Vec::with_capacity(
+ self.centroids.len() + self.buffer.len() + other.centroids.len() + other.buffer.len(),
+ );
+ for &v in &self.buffer {
+ tmp.push(Centroid {
+ mean: v,
+ weight: DEFAULT_WEIGHT,
+ });
+ }
+ for &v in &other.buffer {
+ tmp.push(Centroid {
+ mean: v,
+ weight: DEFAULT_WEIGHT,
+ });
+ }
+ for &c in &other.centroids {
+ tmp.push(c);
+ }
+ self.do_merge(tmp, self.buffer.len() as u64 + other.total_weight())
+ }
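
A minimal sketch of the merge semantics above (hypothetical usage; the argument
digest is borrowed immutably and left untouched):

    let mut a = TDigestMut::new(100);
    let mut b = TDigestMut::new(100);
    for i in 0..500 {
        a.update(i as f64);
    }
    for i in 500..1_000 {
        b.update(i as f64);
    }
    a.merge(&b); // b keeps its own state
    assert_eq!(a.total_weight(), 1_000);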
+
+ /// Freezes this TDigest into an immutable one.
+ pub fn freeze(mut self) -> TDigest {
+ self.compress();
+ TDigest {
+ k: self.k,
+ reverse_merge: self.reverse_merge,
+ min: self.min,
+ max: self.max,
+ centroids: self.centroids,
+ centroids_weight: self.centroids_weight,
+ }
+ }
+
+ fn view(&mut self) -> TDigestView<'_> {
+ self.compress(); // side effect
+ TDigestView {
+ min: self.min,
+ max: self.max,
+ centroids: &self.centroids,
+ centroids_weight: self.centroids_weight,
+ }
+ }
+
+ /// Returns an approximation to the Cumulative Distribution Function (CDF), which is the
+ /// cumulative analog of the PMF, of the input stream given a set of split points.
+ ///
+ /// # Arguments
+ ///
+ /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+ /// input domain into _m+1_ consecutive disjoint intervals.
+ ///
+ /// # Returns
+ ///
+ /// An array of m+1 doubles, which are a consecutive approximation to the CDF of the input
+ /// stream given the split points. The value at array position j of the returned CDF array
+ /// is the sum of the returned values in positions 0 through j of the returned PMF array.
+ /// This can be viewed as an array of ranks of the given split points plus one more value that
+ /// is always 1.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+ pub fn cdf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
+ check_split_points(split_points);
+
+ if self.is_empty() {
+ return None;
+ }
+
+ self.view().cdf(split_points)
+ }
+
+ /// Returns an approximation to the Probability Mass Function (PMF) of the input stream
+ /// given a set of split points.
+ ///
+ /// # Arguments
+ ///
+ /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+ /// input domain into _m+1_ consecutive disjoint intervals (bins).
+ ///
+ /// # Returns
+ ///
+ /// An array of m+1 doubles, each of which is an approximation to the fraction of the input
+ /// stream values (the mass) that fall into one of those intervals.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+ pub fn pmf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
+ check_split_points(split_points);
+
+ if self.is_empty() {
+ return None;
+ }
+
+ self.view().pmf(split_points)
+ }
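
As the docs above state, pmf and cdf are linked: pmf[j] is the difference of
consecutive cdf entries. A small sketch (values mirror test_many_values in
tests/tdigest_test.rs below):

    let mut td = TDigestMut::default();
    for i in 0..10_000 {
        td.update(i as f64);
    }
    let split_points = [5_000.0];
    let pmf = td.pmf(&split_points).unwrap(); // approximately [0.5, 0.5]
    let cdf = td.cdf(&split_points).unwrap(); // approximately [0.5, 1.0]
    assert_eq!(cdf[1], 1.0); // the last CDF entry is always 1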
+
+ /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If the value is `NaN`.
+ pub fn rank(&mut self, value: f64) -> Option<f64> {
+ assert!(!value.is_nan(), "value must not be NaN");
+
+ if self.is_empty() {
+ return None;
+ }
+ if value < self.min {
+ return Some(0.0);
+ }
+ if value > self.max {
+ return Some(1.0);
+ }
+ // one centroid and value == min == max
+ if self.centroids.len() + self.buffer.len() == 1 {
+ return Some(0.5);
+ }
+
+ self.view().rank(value)
+ }
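
The early returns above pin down the edge cases; a sketch matching
test_one_value in tests/tdigest_test.rs below:

    let mut td = TDigestMut::new(100);
    td.update(1.0);
    assert_eq!(td.rank(0.99), Some(0.0)); // below the minimum
    assert_eq!(td.rank(1.0), Some(0.5));  // single value: min == max
    assert_eq!(td.rank(1.01), Some(1.0)); // above the maximum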
+
+ /// Compute approximate quantile value corresponding to the given normalized rank.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If rank is not in [0.0, 1.0].
+ pub fn quantile(&mut self, rank: f64) -> Option<f64> {
+ assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+
+ if self.is_empty() {
+ return None;
+ }
+
+ self.view().quantile(rank)
+ }
+
+ /// Serializes this TDigest to bytes.
+ pub fn serialize(&mut self) -> Vec<u8> {
+ self.compress();
+
+ let mut total_size = 0;
+ if self.is_empty() || self.is_single_value() {
+ // 1 byte preamble
+ // + 1 byte serial version
+ // + 1 byte family
+ // + 2 bytes k
+ // + 1 byte flags
+ // + 2 bytes unused
+ total_size += size_of::<u64>();
+ } else {
+ // all of the above
+ // + 4 bytes num centroids
+ // + 4 bytes num buffered
+ total_size += size_of::<u64>() * 2;
+ }
+ if self.is_empty() {
+ // nothing more
+ } else if self.is_single_value() {
+ // + 8 bytes single value
+ total_size += size_of::<f64>();
+ } else {
+ // + 8 bytes min
+ // + 8 bytes max
+ total_size += size_of::<f64>() * 2;
+ // + (8+8) bytes per centroid
+ total_size += self.centroids.len() * (size_of::<f64>() + size_of::<u64>());
+ }
+
+ let mut bytes = Vec::with_capacity(total_size);
+ bytes.push(match self.total_weight() {
+ 0 => PREAMBLE_LONGS_EMPTY_OR_SINGLE,
+ 1 => PREAMBLE_LONGS_EMPTY_OR_SINGLE,
+ _ => PREAMBLE_LONGS_MULTIPLE,
+ });
+ bytes.push(SERIAL_VERSION);
+ bytes.push(TDIGEST_FAMILY_ID);
+ bytes.extend_from_slice(&self.k.to_le_bytes());
+ bytes.push({
+ let mut flags = 0;
+ if self.is_empty() {
+ flags |= FLAGS_IS_EMPTY;
+ }
+ if self.is_single_value() {
+ flags |= FLAGS_IS_SINGLE_VALUE;
+ }
+ if self.reverse_merge {
+ flags |= FLAGS_REVERSE_MERGE;
+ }
+ flags
+ });
+ bytes.extend_from_slice(&0u16.to_le_bytes()); // unused
+ if self.is_empty() {
+ return bytes;
+ }
+ if self.is_single_value() {
+ bytes.extend_from_slice(&self.min.to_le_bytes());
+ return bytes;
+ }
+ bytes.extend_from_slice(&(self.centroids.len() as u32).to_le_bytes());
+ bytes.extend_from_slice(&0u32.to_le_bytes()); // unused
+ bytes.extend_from_slice(&self.min.to_le_bytes());
+ bytes.extend_from_slice(&self.max.to_le_bytes());
+ for centroid in &self.centroids {
+ bytes.extend_from_slice(¢roid.mean.to_le_bytes());
+ bytes.extend_from_slice(¢roid.weight.get().to_le_bytes());
+ }
+ bytes
+ }
+
+ /// Deserializes a TDigest from bytes.
+ ///
+ /// Supports reading compact format with (float, int) centroids as opposed to (double, long) to
+ /// represent (mean, weight). [^1]
+ ///
+ /// Supports reading the format of the reference implementation (auto-detected) [^2].
+ ///
+ /// [^1]: This is to support reading the `tdigest<float>` format from the C++ implementation.
+ /// [^2]: <https://github.com/tdunning/t-digest>
+ pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, SerdeError> {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+ move |_| SerdeError::InsufficientData(tag.to_string())
+ }
+
+ let mut cursor = Cursor::new(bytes);
+
+ let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?;
+ let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
+ let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
+ if family_id != TDIGEST_FAMILY_ID {
+ if preamble_longs == 0 && serial_version == 0 && family_id == 0 {
+ return Self::deserialize_compat(bytes);
+ }
+ return Err(SerdeError::InvalidFamily(format!(
+ "expected {} (TDigest), got {}",
+ TDIGEST_FAMILY_ID, family_id
+ )));
+ }
+ if serial_version != SERIAL_VERSION {
+ return Err(SerdeError::UnsupportedVersion(format!(
+ "expected {}, got {}",
+ SERIAL_VERSION, serial_version
+ )));
+ }
+ let k = cursor.read_u16::<LE>().map_err(make_error("k"))?;
+ if k < 10 {
+ return Err(SerdeError::InvalidParameter(format!(
+ "k must be at least 10, got {k}"
+ )));
+ }
+ let flags = cursor.read_u8().map_err(make_error("flags"))?;
+ let is_empty = (flags & FLAGS_IS_EMPTY) != 0;
+ let is_single_value = (flags & FLAGS_IS_SINGLE_VALUE) != 0;
+ let expected_preamble_longs = if is_empty || is_single_value {
+ PREAMBLE_LONGS_EMPTY_OR_SINGLE
+ } else {
+ PREAMBLE_LONGS_MULTIPLE
+ };
+ if preamble_longs != expected_preamble_longs {
+ return Err(SerdeError::MalformedData(format!(
+ "expected preamble_longs to be {}, got {}",
+ expected_preamble_longs, preamble_longs
+ )));
+ }
+ cursor.read_u16::<LE>().map_err(make_error("<unused>"))?; // unused
+ if is_empty {
+ return Ok(TDigestMut::new(k));
+ }
+
+ let reverse_merge = (flags & FLAGS_REVERSE_MERGE) != 0;
+ if is_single_value {
+ let value = if is_f32 {
+ cursor
+ .read_f32::<LE>()
+ .map_err(make_error("single_value"))? as f64
+ } else {
+ cursor
+ .read_f64::<LE>()
+ .map_err(make_error("single_value"))?
+ };
+ check_non_nan(value, "single_value")?;
+ check_non_infinite(value, "single_value")?;
+ return Ok(TDigestMut::make(
+ k,
+ reverse_merge,
+ value,
+ value,
+ vec![Centroid {
+ mean: value,
+ weight: DEFAULT_WEIGHT,
+ }],
+ 1,
+ vec![],
+ ));
+ }
+ let num_centroids = cursor
+ .read_u32::<LE>()
+ .map_err(make_error("num_centroids"))? as usize;
+ let num_buffered = cursor
+ .read_u32::<LE>()
+ .map_err(make_error("num_buffered"))? as usize;
+ let (min, max) = if is_f32 {
+ (
+ cursor.read_f32::<LE>().map_err(make_error("min"))? as f64,
+ cursor.read_f32::<LE>().map_err(make_error("max"))? as f64,
+ )
+ } else {
+ (
+ cursor.read_f64::<LE>().map_err(make_error("min"))?,
+ cursor.read_f64::<LE>().map_err(make_error("max"))?,
+ )
+ };
+ check_non_nan(min, "min")?;
+ check_non_nan(max, "max")?;
+ let mut centroids = Vec::with_capacity(num_centroids);
+ let mut centroids_weight = 0u64;
+ for _ in 0..num_centroids {
+ let (mean, weight) = if is_f32 {
+ (
+ cursor.read_f32::<LE>().map_err(make_error("mean"))? as
f64,
+ cursor.read_u32::<LE>().map_err(make_error("weight"))? as
u64,
+ )
+ } else {
+ (
+ cursor.read_f64::<LE>().map_err(make_error("mean"))?,
+ cursor.read_u64::<LE>().map_err(make_error("weight"))?,
+ )
+ };
+ check_non_nan(mean, "centroid mean")?;
+ check_non_infinite(mean, "centroid")?;
+ let weight = check_nonzero(weight, "centroid weight")?;
+ centroids_weight += weight.get();
+ centroids.push(Centroid { mean, weight });
+ }
+ let mut buffer = Vec::with_capacity(num_buffered);
+ for _ in 0..num_buffered {
+ let value = if is_f32 {
+ cursor
+ .read_f32::<LE>()
+ .map_err(make_error("buffered_value"))? as f64
+ } else {
+ cursor
+ .read_f64::<LE>()
+ .map_err(make_error("buffered_value"))?
+ };
+ check_non_nan(value, "buffered_value mean")?;
+ check_non_infinite(value, "buffered_value mean")?;
+ buffer.push(value);
+ }
+ Ok(TDigestMut::make(
+ k,
+ reverse_merge,
+ min,
+ max,
+ centroids,
+ centroids_weight,
+ buffer,
+ ))
+ }
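
A round-trip sketch of the serialize/deserialize pair above (hypothetical
usage; is_f32 = false selects the (double, long) centroid encoding):

    let mut td = TDigestMut::new(100);
    for i in 0..1_000 {
        td.update(i as f64);
    }
    let bytes = td.serialize();
    let restored = TDigestMut::deserialize(&bytes, false).unwrap();
    assert_eq!(restored.k(), 100);
    assert_eq!(restored.total_weight(), 1_000);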
+
+ // compatibility with the format of the reference implementation
+ // default byte order of ByteBuffer is used there, which is big endian
+ fn deserialize_compat(bytes: &[u8]) -> Result<Self, SerdeError> {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+ move |_| SerdeError::InsufficientData(format!("{tag} in compat format"))
+ }
+
+ let mut cursor = Cursor::new(bytes);
+
+ let ty = cursor.read_u32::<BE>().map_err(make_error("type"))?;
+ match ty {
+ COMPAT_DOUBLE => {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+ move |_| SerdeError::InsufficientData(format!("{tag} in compat double format"))
+ }
+ // compatibility with asBytes()
+ let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
+ let max = cursor.read_f64::<BE>().map_err(make_error("max"))?;
+ check_non_nan(min, "min in compat double format")?;
+ check_non_nan(max, "max in compat double format")?;
+ let k = cursor.read_f64::<BE>().map_err(make_error("k"))? as u16;
+ if k < 10 {
+ return Err(SerdeError::InvalidParameter(format!(
+ "k must be at least 10, got {k} in compat double format"
+ )));
+ }
+ let num_centroids = cursor
+ .read_u32::<BE>()
+ .map_err(make_error("num_centroids"))?
+ as usize;
+ let mut total_weight = 0u64;
+ let mut centroids = Vec::with_capacity(num_centroids);
+ for _ in 0..num_centroids {
+ let weight = cursor.read_f64::<BE>().map_err(make_error("weight"))? as u64;
+ let mean = cursor.read_f64::<BE>().map_err(make_error("mean"))?;
+ let weight = check_nonzero(weight, "centroid weight in compat double format")?;
+ check_non_nan(mean, "centroid mean in compat double format")?;
+ check_non_infinite(mean, "centroid mean in compat double format")?;
+ total_weight += weight.get();
+ centroids.push(Centroid { mean, weight });
+ }
+ Ok(TDigestMut::make(
+ k,
+ false,
+ min,
+ max,
+ centroids,
+ total_weight,
+ vec![],
+ ))
+ }
+ COMPAT_FLOAT => {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+ move |_| SerdeError::InsufficientData(format!("{tag} in compat float format"))
+ }
+ // COMPAT_FLOAT: compatibility with asSmallBytes()
+ // reference implementation uses doubles for min and max
+ let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
+ let max = cursor.read_f64::<BE>().map_err(make_error("max"))?;
+ check_non_nan(min, "min in compat float format")?;
+ check_non_nan(max, "max in compat float format")?;
+ let k = cursor.read_f32::<BE>().map_err(make_error("k"))? as u16;
+ if k < 10 {
+ return Err(SerdeError::InvalidParameter(format!(
+ "k must be at least 10, got {k} in compat float format"
+ )));
+ }
+ // reference implementation stores capacities of the array of centroids and the
+ // buffer as shorts; they can be derived from k in the constructor
+ cursor.read_u32::<BE>().map_err(make_error("<unused>"))?;
+ let num_centroids = cursor
+ .read_u16::<BE>()
+ .map_err(make_error("num_centroids"))?
+ as usize;
+ let mut total_weight = 0u64;
+ let mut centroids = Vec::with_capacity(num_centroids);
+ for _ in 0..num_centroids {
+ let weight = cursor.read_f32::<BE>().map_err(make_error("weight"))? as u64;
+ let mean = cursor.read_f32::<BE>().map_err(make_error("mean"))? as f64;
+ let weight = check_nonzero(weight, "centroid weight in compat float format")?;
+ check_non_nan(mean, "centroid mean in compat float format")?;
+ check_non_infinite(mean, "centroid mean in compat float format")?;
+ total_weight += weight.get();
+ centroids.push(Centroid { mean, weight });
+ }
+ Ok(TDigestMut::make(
+ k,
+ false,
+ min,
+ max,
+ centroids,
+ total_weight,
+ vec![],
+ ))
+ }
+ ty => Err(SerdeError::InvalidParameter(format!(
+ "unknown TDigest compat type {ty}",
+ ))),
+ }
+ }
+
+ fn is_single_value(&self) -> bool {
+ self.total_weight() == 1
+ }
+
+ /// Process buffered values and merge centroids if needed.
+ fn compress(&mut self) {
+ if self.buffer.is_empty() {
+ return;
+ }
+ let mut tmp = Vec::with_capacity(self.buffer.len() + self.centroids.len());
+ for &v in &self.buffer {
+ tmp.push(Centroid {
+ mean: v,
+ weight: DEFAULT_WEIGHT,
+ });
+ }
+ self.do_merge(tmp, self.buffer.len() as u64)
+ }
+
+ /// Merges the given buffer of centroids into this TDigest.
+ ///
+ /// # Contract
+ ///
+ /// * `buffer` must have at least one centroid.
+ /// * `buffer` is generated from `self.buffer`, and thus:
+ /// * No `NAN` values are present in `buffer`.
+ /// * We should clear `self.buffer` after merging.
+ fn do_merge(&mut self, mut buffer: Vec<Centroid>, weight: u64) {
+ buffer.extend(std::mem::take(&mut self.centroids));
+ buffer.sort_by(centroid_cmp);
+ if self.reverse_merge {
+ buffer.reverse();
+ }
+ self.centroids_weight += weight;
+
+ let mut num_centroids = 0;
+ let len = buffer.len();
+ self.centroids.push(buffer[0]);
+ num_centroids += 1;
+ let mut current = 1;
+ let mut weight_so_far = 0.;
+ while current < len {
+ let c = buffer[current];
+ let proposed_weight = self.centroids[num_centroids - 1].weight() + c.weight();
+ let mut add_this = false;
+ if (current != 1) && (current != (len - 1)) {
+ let centroids_weight = self.centroids_weight as f64;
+ let q0 = weight_so_far / centroids_weight;
+ let q2 = (weight_so_far + proposed_weight) / centroids_weight;
+ let normalizer = scale_function::normalizer((2 * self.k) as f64, centroids_weight);
+ add_this = proposed_weight
+ <= (centroids_weight
+ * scale_function::max(q0, normalizer)
+ .min(scale_function::max(q2, normalizer)));
+ }
+ if add_this {
+ // merge into existing centroid
+ self.centroids[num_centroids - 1].add(c);
+ } else {
+ // copy to a new centroid
+ weight_so_far += self.centroids[num_centroids - 1].weight();
+ self.centroids.push(c);
+ num_centroids += 1;
+ }
+ current += 1;
+ }
+
+ if self.reverse_merge {
+ self.centroids.reverse();
+ }
+ self.min = self.min.min(self.centroids[0].mean);
+ self.max = self.max.max(self.centroids[num_centroids - 1].mean);
+ self.reverse_merge = !self.reverse_merge;
+ self.buffer.clear();
+ }
+}
+
+/// Immutable (frozen) T-Digest sketch for estimating quantiles and ranks.
+///
+/// See the [module documentation](super) for more details.
+pub struct TDigest {
+ k: u16,
+
+ reverse_merge: bool,
+ min: f64,
+ max: f64,
+
+ centroids: Vec<Centroid>,
+ centroids_weight: u64,
+}
+
+impl TDigest {
+ /// Returns parameter k (compression) that was used to configure this TDigest.
+ pub fn k(&self) -> u16 {
+ self.k
+ }
+
+ /// Returns true if TDigest has not seen any data.
+ pub fn is_empty(&self) -> bool {
+ self.centroids.is_empty()
+ }
+
+ /// Returns minimum value seen by TDigest; `None` if TDigest is empty.
+ pub fn min_value(&self) -> Option<f64> {
+ if self.is_empty() {
+ None
+ } else {
+ Some(self.min)
+ }
+ }
+
+ /// Returns maximum value seen by TDigest; `None` if TDigest is empty.
+ pub fn max_value(&self) -> Option<f64> {
+ if self.is_empty() {
+ None
+ } else {
+ Some(self.max)
+ }
+ }
+
+ /// Returns total weight.
+ pub fn total_weight(&self) -> u64 {
+ self.centroids_weight
+ }
+
+ fn view(&self) -> TDigestView<'_> {
+ TDigestView {
+ min: self.min,
+ max: self.max,
+ centroids: &self.centroids,
+ centroids_weight: self.centroids_weight,
+ }
+ }
+
+ /// Returns an approximation to the Cumulative Distribution Function (CDF), which is the
+ /// cumulative analog of the PMF, of the input stream given a set of split points.
+ ///
+ /// # Arguments
+ ///
+ /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+ /// input domain into _m+1_ consecutive disjoint intervals.
+ ///
+ /// # Returns
+ ///
+ /// An array of m+1 doubles, which are a consecutive approximation to the CDF of the input
+ /// stream given the split points. The value at array position j of the returned CDF array
+ /// is the sum of the returned values in positions 0 through j of the returned PMF array.
+ /// This can be viewed as an array of ranks of the given split points plus one more value that
+ /// is always 1.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+ pub fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+ self.view().cdf(split_points)
+ }
+
+ /// Returns an approximation to the Probability Mass Function (PMF) of the input stream
+ /// given a set of split points.
+ ///
+ /// # Arguments
+ ///
+ /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+ /// input domain into _m+1_ consecutive disjoint intervals (bins).
+ ///
+ /// # Returns
+ ///
+ /// An array of m+1 doubles, each of which is an approximation to the fraction of the input
+ /// stream values (the mass) that fall into one of those intervals.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+ pub fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+ self.view().pmf(split_points)
+ }
+
+ /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If the value is `NaN`.
+ pub fn rank(&self, value: f64) -> Option<f64> {
+ assert!(!value.is_nan(), "value must not be NaN");
+ self.view().rank(value)
+ }
+
+ /// Compute approximate quantile value corresponding to the given normalized rank.
+ ///
+ /// Returns `None` if TDigest is empty.
+ ///
+ /// # Panics
+ ///
+ /// If rank is not in [0.0, 1.0].
+ pub fn quantile(&self, rank: f64) -> Option<f64> {
+ assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+ self.view().quantile(rank)
+ }
+
+ /// Converts this immutable TDigest into a mutable one.
+ pub fn unfreeze(self) -> TDigestMut {
+ TDigestMut::make(
+ self.k,
+ self.reverse_merge,
+ self.min,
+ self.max,
+ self.centroids,
+ self.centroids_weight,
+ vec![],
+ )
+ }
+}
+
+struct TDigestView<'a> {
+ min: f64,
+ max: f64,
+ centroids: &'a [Centroid],
+ centroids_weight: u64,
+}
+
+impl TDigestView<'_> {
+ fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+ let mut buckets = self.cdf(split_points)?;
+ for i in (1..buckets.len()).rev() {
+ buckets[i] -= buckets[i - 1];
+ }
+ Some(buckets)
+ }
+
+ fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+ check_split_points(split_points);
+
+ if self.centroids.is_empty() {
+ return None;
+ }
+
+ let mut ranks = Vec::with_capacity(split_points.len() + 1);
+ for &p in split_points {
+ match self.rank(p) {
+ Some(rank) => ranks.push(rank),
+ None => unreachable!("checked non-empty above"),
+ }
+ }
+ ranks.push(1.0);
+ Some(ranks)
+ }
+
+ fn rank(&self, value: f64) -> Option<f64> {
+ debug_assert!(!value.is_nan(), "value must not be NaN");
+
+ if self.centroids.is_empty() {
+ return None;
+ }
+ if value < self.min {
+ return Some(0.0);
+ }
+ if value > self.max {
+ return Some(1.0);
+ }
+ // one centroid and value == min == max
+ if self.centroids.len() == 1 {
+ return Some(0.5);
+ }
+
+ let centroids_weight = self.centroids_weight as f64;
+ let num_centroids = self.centroids.len();
+
+ // left tail
+ let first_mean = self.centroids[0].mean;
+ if value < first_mean {
+ if first_mean - self.min > 0. {
+ return Some(if value == self.min {
+ 0.5 / centroids_weight
+ } else {
+ (1. + (((value - self.min) / (first_mean - self.min))
+ * ((self.centroids[0].weight() / 2.) - 1.)))
+ / centroids_weight
+ });
+ }
+ return Some(0.); // should never happen
+ }
+
+ // right tail
+ let last_mean = self.centroids[num_centroids - 1].mean;
+ if value > last_mean {
+ if self.max - last_mean > 0. {
+ return Some(if value == self.max {
+ 1. - (0.5 / centroids_weight)
+ } else {
+ 1.0 - ((1.0
+ + (((self.max - value) / (self.max - last_mean))
+ * ((self.centroids[num_centroids - 1].weight() / 2.) - 1.)))
+ / centroids_weight)
+ });
+ }
+ return Some(1.); // should never happen
+ }
+
+ let mut lower = self
+ .centroids
+ .binary_search_by(|c| centroid_lower_bound(c, value))
+ .unwrap_or_else(identity);
+ assert_ne!(lower, num_centroids, "get_rank: lower == end");
+ let mut upper = self
+ .centroids
+ .binary_search_by(|c| centroid_upper_bound(c, value))
+ .unwrap_or_else(identity);
+ assert_ne!(upper, 0, "get_rank: upper == begin");
+ if value < self.centroids[lower].mean {
+ lower -= 1;
+ }
+ if (upper == num_centroids) || (self.centroids[upper - 1].mean >= value) {
+ upper -= 1;
+ }
+
+ let mut weight_below = 0.;
+ let mut i = 0;
+ while i < lower {
+ weight_below += self.centroids[i].weight();
+ i += 1;
+ }
+ weight_below += self.centroids[lower].weight() / 2.;
+
+ let mut weight_delta = 0.;
+ while i < upper {
+ weight_delta += self.centroids[i].weight();
+ i += 1;
+ }
+ weight_delta -= self.centroids[lower].weight() / 2.;
+ weight_delta += self.centroids[upper].weight() / 2.;
+ Some(
+ if self.centroids[upper].mean - self.centroids[lower].mean > 0. {
+ (weight_below
+ + (weight_delta * (value - self.centroids[lower].mean)
+ / (self.centroids[upper].mean - self.centroids[lower].mean)))
+ / centroids_weight
+ } else {
+ (weight_below + weight_delta / 2.) / centroids_weight
+ },
+ )
+ }
+
+ fn quantile(&self, rank: f64) -> Option<f64> {
+ debug_assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+
+ if self.centroids.is_empty() {
+ return None;
+ }
+
+ if self.centroids.len() == 1 {
+ return Some(self.centroids[0].mean);
+ }
+
+ // at least 2 centroids
+ let centroids_weight = self.centroids_weight as f64;
+ let num_centroids = self.centroids.len();
+ let weight = rank * centroids_weight;
+ if weight < 1. {
+ return Some(self.min);
+ }
+ if weight > centroids_weight - 1. {
+ return Some(self.max);
+ }
+ let first_weight = self.centroids[0].weight();
+ if first_weight > 1. && weight < first_weight / 2. {
+ return Some(
+ self.min
+ + (((weight - 1.) / ((first_weight / 2.) - 1.))
+ * (self.centroids[0].mean - self.min)),
+ );
+ }
+ let last_weight = self.centroids[num_centroids - 1].weight();
+ if last_weight > 1. && (centroids_weight - weight <= last_weight / 2.) {
+ return Some(
+ self.max
+ + (((centroids_weight - weight - 1.) / ((last_weight / 2.) - 1.))
+ * (self.max - self.centroids[num_centroids - 1].mean)),
+ );
+ }
+
+ // interpolate between extremes
+ let mut weight_so_far = first_weight / 2.;
+ for i in 0..(num_centroids - 1) {
+ let dw = (self.centroids[i].weight() + self.centroids[i + 1].weight()) / 2.;
+ if weight_so_far + dw > weight {
+ // the target weight is between centroids i and i+1
+ let mut left_weight = 0.;
+ if self.centroids[i].weight.get() == 1 {
+ if weight - weight_so_far < 0.5 {
+ return Some(self.centroids[i].mean);
+ }
+ left_weight = 0.5;
+ }
+ let mut right_weight = 0.;
+ if self.centroids[i + 1].weight.get() == 1 {
+ if weight_so_far + dw - weight < 0.5 {
+ return Some(self.centroids[i + 1].mean);
+ }
+ right_weight = 0.5;
+ }
+ let w1 = weight - weight_so_far - left_weight;
+ let w2 = weight_so_far + dw - weight - right_weight;
+ return Some(weighted_average(
+ self.centroids[i].mean,
+ w1,
+ self.centroids[i + 1].mean,
+ w2,
+ ));
+ }
+ weight_so_far += dw;
+ }
+
+ let w1 = weight - (centroids_weight - ((self.centroids[num_centroids - 1].weight()) / 2.));
+ let w2 = (self.centroids[num_centroids - 1].weight() / 2.) - w1;
+ Some(weighted_average(
+ self.centroids[num_centroids - 1].mean,
+ w1,
+ self.max,
+ w2,
+ ))
+ }
+}
+
+/// Checks the sequential validity of the given array of double values.
+/// They must be unique, monotonically increasing and not NaN.
+#[track_caller]
+fn check_split_points(split_points: &[f64]) {
+ let len = split_points.len();
+ if len == 1 && split_points[0].is_nan() {
+ panic!("split_points must not contain NaN values: {split_points:?}");
+ }
+ for i in 0..len.saturating_sub(1) {
+ if split_points[i] < split_points[i + 1] {
+ // we must use this positive condition because NaN comparisons are always false
+ continue;
+ }
+ panic!("split_points must be unique and monotonically increasing: {split_points:?}");
+ }
+}
+
+fn centroid_cmp(a: &Centroid, b: &Centroid) -> Ordering {
+ match a.mean.partial_cmp(&b.mean) {
+ Some(order) => order,
+ None => unreachable!("NaN values should never be present in
centroids"),
+ }
+}
+
+fn centroid_lower_bound(c: &Centroid, value: f64) -> Ordering {
+ if c.mean < value {
+ Ordering::Less
+ } else {
+ Ordering::Greater
+ }
+}
+
+fn centroid_upper_bound(c: &Centroid, value: f64) -> Ordering {
+ if c.mean > value {
+ Ordering::Greater
+ } else {
+ Ordering::Less
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+struct Centroid {
+ mean: f64,
+ weight: NonZeroU64,
+}
+
+impl Centroid {
+ fn add(&mut self, other: Centroid) {
+ let (self_weight, other_weight) = (self.weight(), other.weight());
+ let total_weight = self_weight + other_weight;
+ self.weight = self.weight.saturating_add(other.weight.get());
+
+ let (self_mean, other_mean) = (self.mean, other.mean);
+ let ratio_self = self_weight / total_weight;
+ let ratio_other = other_weight / total_weight;
+ self.mean = self_mean.mul_add(ratio_self, other_mean * ratio_other);
+ debug_assert!(
+ !self.mean.is_nan(),
+ "NaN values should never be present in centroids; self: {}, other:
{}",
+ self_mean,
+ other_mean
+ );
+ }
+
+ fn weight(&self) -> f64 {
+ self.weight.get() as f64
+ }
+}
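
Centroid::add above is a running weighted mean: merging a centroid of
(mean 1.0, weight 1) into one of (mean 3.0, weight 3) yields weight 4 and
mean 3.0 * (3/4) + 1.0 * (1/4) = 2.5 (a worked example, not from the patch).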
+
+fn check_non_nan(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+ if value.is_nan() {
+ return Err(SerdeError::MalformedData(format!("{tag} cannot be NaN")));
+ }
+ Ok(())
+}
+
+fn check_non_infinite(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+ if value.is_infinite() {
+ return Err(SerdeError::MalformedData(format!(
+ "{tag} cannot be infinite"
+ )));
+ }
+ Ok(())
+}
+
+fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64, SerdeError> {
+ NonZeroU64::new(value).ok_or_else(|| SerdeError::MalformedData(format!("{tag} cannot be zero")))
+}
+
+/// Generates cluster sizes proportional to `q*(1-q)`.
+///
+/// The use of a normalizing function results in a strictly bounded number of clusters no matter
+/// how many samples.
+///
+/// Corresponds to K_2 in the reference implementation.
+mod scale_function {
+ pub(super) fn max(q: f64, normalizer: f64) -> f64 {
+ q * (1. - q) / normalizer
+ }
+
+ pub(super) fn normalizer(compression: f64, n: f64) -> f64 {
+ compression / z(compression, n)
+ }
+
+ pub(super) fn z(compression: f64, n: f64) -> f64 {
+ 4. * (n / compression).ln() + 24.
+ }
+}
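
In the notation of the t-digest paper (delta is the compression, which do_merge
above passes as 2k; n is the total weight), these helpers implement the K_2
cluster-size bound; a sketch of the math, not text from the patch:

    Z(\delta, n) = 4 \ln(n / \delta) + 24
    \mathrm{normalizer}(\delta, n) = \delta / Z(\delta, n)
    \mathrm{max}(q) = \frac{q (1 - q)}{\mathrm{normalizer}(\delta, n)}

so a centroid near quantile q may absorb at most n * q(1-q) * Z(delta, n) / delta
of the total weight; the bound shrinks toward the extremes (q -> 0 or 1) and is
largest at the median, matching the accuracy profile described in the module docs.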
+
+const fn weighted_average(x1: f64, w1: f64, x2: f64, w2: f64) -> f64 {
+ (x1 * w1 + x2 * w2) / (w1 + w2)
+}
diff --git a/tests/common.rs b/tests/common.rs
new file mode 100644
index 0000000..e97b920
--- /dev/null
+++ b/tests/common.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::path::PathBuf;
+
+#[allow(dead_code)] // false-positive
+pub fn test_data(name: &str) -> PathBuf {
+ const TEST_DATA_DIR: &str = "tests/test_data";
+
+ PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join(TEST_DATA_DIR)
+ .join(name)
+}
+
+pub fn serialization_test_data(sub_dir: &str, name: &str) -> PathBuf {
+ const SERDE_TEST_DATA_DIR: &str = "tests/serialization_test_data";
+
+ let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join(SERDE_TEST_DATA_DIR)
+ .join(sub_dir)
+ .join(name);
+
+ if !path.exists() {
+ panic!(
+ r#"serialization test data file not found: {}
+
+ Please ensure test data files are present in the repository. Generally, you can
+ run the following commands from the project root to regenerate the test data files
+ if they are missing:
+
+ $ ./tools/generate_serialization_test_data.py
+ "#,
+ path.display(),
+ );
+ }
+
+ path
+}
diff --git a/tests/hll_serialization_test.rs b/tests/hll_serialization_test.rs
index a3c397b..fc1969c 100644
--- a/tests/hll_serialization_test.rs
+++ b/tests/hll_serialization_test.rs
@@ -24,36 +24,14 @@
//! Test data is generated by the reference implementations and stored in:
//! `tests/serialization_test_data/`
+mod common;
+
use std::fs;
use std::path::PathBuf;
+use common::serialization_test_data;
use datasketches::hll::HllSketch;
-const TEST_DATA_DIR: &str = "tests/serialization_test_data";
-
-fn get_test_data_path(sub_dir: &str, name: &str) -> PathBuf {
- let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
- .join(TEST_DATA_DIR)
- .join(sub_dir)
- .join(name);
-
- if !path.exists() {
- panic!(
- r#"serialization test data file not found: {}
-
- Please ensure test data files are present in the repository. Generally, you can
- run the following commands from the project root to regenerate the test data files
- if they are missing:
-
- $ ./tools/generate_serialization_test_data.py
- "#,
- path.display(),
- );
- }
-
- path
-}
-
fn test_sketch_file(path: PathBuf, expected_cardinality: usize, expected_lg_k: u8) {
let expected = expected_cardinality as f64;
@@ -133,7 +111,7 @@ fn test_java_hll4_compatibility() {
for n in test_cases {
let filename = format!("hll4_n{}_java.sk", n);
- let path = get_test_data_path("java_generated_files", &filename);
+ let path = serialization_test_data("java_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -144,7 +122,7 @@ fn test_java_hll6_compatibility() {
for n in test_cases {
let filename = format!("hll6_n{}_java.sk", n);
- let path = get_test_data_path("java_generated_files", &filename);
+ let path = serialization_test_data("java_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -155,7 +133,7 @@ fn test_java_hll8_compatibility() {
for n in test_cases {
let filename = format!("hll8_n{}_java.sk", n);
- let path = get_test_data_path("java_generated_files", &filename);
+ let path = serialization_test_data("java_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -166,7 +144,7 @@ fn test_cpp_hll4_compatibility() {
for n in test_cases {
let filename = format!("hll4_n{}_cpp.sk", n);
- let path = get_test_data_path("cpp_generated_files", &filename);
+ let path = serialization_test_data("cpp_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -177,7 +155,7 @@ fn test_cpp_hll6_compatibility() {
for n in test_cases {
let filename = format!("hll6_n{}_cpp.sk", n);
- let path = get_test_data_path("cpp_generated_files", &filename);
+ let path = serialization_test_data("cpp_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -188,7 +166,7 @@ fn test_cpp_hll8_compatibility() {
for n in test_cases {
let filename = format!("hll8_n{}_cpp.sk", n);
- let path = get_test_data_path("cpp_generated_files", &filename);
+ let path = serialization_test_data("cpp_generated_files", &filename);
test_sketch_file(path, n, 12);
}
}
@@ -208,7 +186,7 @@ fn test_estimate_accuracy() {
println!("{:-<40}", "");
for (dir, file, expected) in test_cases {
- let path = get_test_data_path(dir, file);
+ let path = serialization_test_data(dir, file);
let bytes = fs::read(&path).unwrap();
let sketch = HllSketch::deserialize(&bytes).unwrap();
let estimate = sketch.estimate();
diff --git a/tests/tdigest_serialization_test.rs b/tests/tdigest_serialization_test.rs
new file mode 100644
index 0000000..0ad68e4
--- /dev/null
+++ b/tests/tdigest_serialization_test.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod common;
+
+use std::fs;
+use std::path::PathBuf;
+
+use common::serialization_test_data;
+use common::test_data;
+use datasketches::tdigest::TDigestMut;
+use googletest::assert_that;
+use googletest::prelude::{eq, near};
+
+fn test_sketch_file(path: PathBuf, n: u64, with_buffer: bool, is_f32: bool) {
+ let bytes = fs::read(&path).unwrap();
+ let td = TDigestMut::deserialize(&bytes, is_f32).unwrap();
+ let td = td.freeze();
+
+ let path = path.display();
+ if n == 0 {
+ assert!(td.is_empty(), "filepath: {path}");
+ assert_eq!(td.total_weight(), 0, "filepath: {path}");
+ } else {
+ assert!(!td.is_empty(), "filepath: {path}");
+ assert_eq!(td.total_weight(), n, "filepath: {path}");
+ assert_eq!(td.min_value(), Some(1.0), "filepath: {path}");
+ assert_eq!(td.max_value(), Some(n as f64), "filepath: {path}");
+ assert_eq!(td.rank(0.0), Some(0.0), "filepath: {path}");
+ assert_eq!(td.rank((n + 1) as f64), Some(1.0), "filepath: {path}");
+ if n == 1 {
+ assert_eq!(td.rank(n as f64), Some(0.5), "filepath: {path}");
+ } else {
+ assert_that!(
+ td.rank(n as f64 / 2.).unwrap(),
+ near(0.5, 0.05),
+ "filepath: {path}",
+ );
+ }
+ }
+
+ if !with_buffer && !is_f32 {
+ let mut td = td.unfreeze();
+ let roundtrip_bytes = td.serialize();
+ assert_eq!(bytes, roundtrip_bytes, "filepath: {path}");
+ }
+}
+
+#[test]
+fn test_deserialize_from_cpp_snapshots() {
+ let ns = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000];
+ for n in ns {
+ let filename = format!("tdigest_double_n{}_cpp.sk", n);
+ let path = serialization_test_data("cpp_generated_files", &filename);
+ test_sketch_file(path, n, false, false);
+ }
+ for n in ns {
+ let filename = format!("tdigest_double_buf_n{}_cpp.sk", n);
+ let path = serialization_test_data("cpp_generated_files", &filename);
+ test_sketch_file(path, n, true, false);
+ }
+ for n in ns {
+ let filename = format!("tdigest_float_n{}_cpp.sk", n);
+ let path = serialization_test_data("cpp_generated_files", &filename);
+ test_sketch_file(path, n, false, true);
+ }
+ for n in ns {
+ let filename = format!("tdigest_float_buf_n{}_cpp.sk", n);
+ let path = serialization_test_data("cpp_generated_files", &filename);
+ test_sketch_file(path, n, true, true);
+ }
+}
+
+#[test]
+fn test_deserialize_from_reference_implementation() {
+ for filename in [
+ "tdigest_ref_k100_n10000_double.sk",
+ "tdigest_ref_k100_n10000_float.sk",
+ ] {
+ let path = test_data(filename);
+ let bytes = fs::read(&path).unwrap();
+ let td = TDigestMut::deserialize(&bytes, false).unwrap();
+ let td = td.freeze();
+
+ let n = 10000;
+ let path = path.display();
+ assert_eq!(td.k(), 100, "filepath: {path}");
+ assert_eq!(td.total_weight(), n, "filepath: {path}");
+ assert_eq!(td.min_value(), Some(0.0), "filepath: {path}");
+ assert_eq!(td.max_value(), Some((n - 1) as f64), "filepath: {path}");
+ assert_that!(td.rank(0.0).unwrap(), near(0.0, 0.0001), "filepath:
{path}");
+ assert_that!(
+ td.rank(n as f64 / 4.).unwrap(),
+ near(0.25, 0.0001),
+ "filepath: {path}"
+ );
+ assert_that!(
+ td.rank(n as f64 / 2.).unwrap(),
+ near(0.5, 0.0001),
+ "filepath: {path}"
+ );
+ assert_that!(
+ td.rank((n * 3) as f64 / 4.).unwrap(),
+ near(0.75, 0.0001),
+ "filepath: {path}"
+ );
+ assert_that!(td.rank(n as f64).unwrap(), eq(1.0), "filepath: {path}");
+ }
+}
+
+#[test]
+fn test_deserialize_from_java_snapshots() {
+ let ns = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000];
+ for n in ns {
+ let filename = format!("tdigest_double_n{}_java.sk", n);
+ let path = serialization_test_data("java_generated_files", &filename);
+ test_sketch_file(path, n, false, false);
+ }
+}
+
+#[test]
+fn test_empty() {
+ let mut td = TDigestMut::new(100);
+ assert!(td.is_empty());
+
+ let bytes = td.serialize();
+ assert_eq!(bytes.len(), 8);
+ let td = td.freeze();
+
+ let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+ let deserialized_td = deserialized_td.freeze();
+ assert_eq!(td.k(), deserialized_td.k());
+ assert_eq!(td.total_weight(), deserialized_td.total_weight());
+ assert!(td.is_empty());
+ assert!(deserialized_td.is_empty());
+}
+
+#[test]
+fn test_single_value() {
+ let mut td = TDigestMut::default();
+ td.update(123.0);
+
+ let bytes = td.serialize();
+ assert_eq!(bytes.len(), 16);
+
+ let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+ let deserialized_td = deserialized_td.freeze();
+ assert_eq!(deserialized_td.k(), 200);
+ assert_eq!(deserialized_td.total_weight(), 1);
+ assert!(!deserialized_td.is_empty());
+ assert_eq!(deserialized_td.min_value(), Some(123.0));
+ assert_eq!(deserialized_td.max_value(), Some(123.0));
+}
+
+#[test]
+fn test_many_values() {
+ let mut td = TDigestMut::new(100);
+ for i in 0..1000 {
+ td.update(i as f64);
+ }
+
+ let bytes = td.serialize();
+ assert_eq!(bytes.len(), 1584);
+ let td = td.freeze();
+
+ let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+ let deserialized_td = deserialized_td.freeze();
+ assert_eq!(td.k(), deserialized_td.k());
+ assert_eq!(td.total_weight(), deserialized_td.total_weight());
+ assert_eq!(td.is_empty(), deserialized_td.is_empty());
+ assert_eq!(td.min_value(), deserialized_td.min_value());
+ assert_eq!(td.max_value(), deserialized_td.max_value());
+ assert_eq!(td.rank(500.0), deserialized_td.rank(500.0));
+ assert_eq!(td.quantile(0.5), deserialized_td.quantile(0.5));
+}
diff --git a/tests/tdigest_test.rs b/tests/tdigest_test.rs
new file mode 100644
index 0000000..1ae1ae3
--- /dev/null
+++ b/tests/tdigest_test.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datasketches::tdigest::TDigestMut;
+use googletest::assert_that;
+use googletest::prelude::{eq, near};
+
+#[test]
+fn test_empty() {
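+ // The mutable sketch and its frozen view must report the same empty state.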
+ let mut tdigest = TDigestMut::new(10);
+ assert!(tdigest.is_empty());
+ assert_eq!(tdigest.k(), 10);
+ assert_eq!(tdigest.total_weight(), 0);
+ assert_eq!(tdigest.min_value(), None);
+ assert_eq!(tdigest.max_value(), None);
+ assert_eq!(tdigest.rank(0.0), None);
+ assert_eq!(tdigest.quantile(0.5), None);
+
+ let split_points = [0.0];
+ assert_eq!(tdigest.pmf(&split_points), None);
+ assert_eq!(tdigest.cdf(&split_points), None);
+
+ let tdigest = TDigestMut::new(10).freeze();
+ assert!(tdigest.is_empty());
+ assert_eq!(tdigest.k(), 10);
+ assert_eq!(tdigest.total_weight(), 0);
+ assert_eq!(tdigest.min_value(), None);
+ assert_eq!(tdigest.max_value(), None);
+ assert_eq!(tdigest.rank(0.0), None);
+ assert_eq!(tdigest.quantile(0.5), None);
+
+ let split_points = [0.0];
+ assert_eq!(tdigest.pmf(&split_points), None);
+ assert_eq!(tdigest.cdf(&split_points), None);
+}
+
+#[test]
+fn test_one_value() {
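+ // A query equal to a stored value receives the midpoint of that value's
+ // weight, so the rank of the single stored value is 0.5.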
+ let mut tdigest = TDigestMut::new(100);
+ tdigest.update(1.0);
+ assert_eq!(tdigest.k(), 100);
+ assert_eq!(tdigest.total_weight(), 1);
+ assert_eq!(tdigest.min_value(), Some(1.0));
+ assert_eq!(tdigest.max_value(), Some(1.0));
+ assert_eq!(tdigest.rank(0.99), Some(0.0));
+ assert_eq!(tdigest.rank(1.0), Some(0.5));
+ assert_eq!(tdigest.rank(1.01), Some(1.0));
+ assert_eq!(tdigest.quantile(0.0), Some(1.0));
+ assert_eq!(tdigest.quantile(0.5), Some(1.0));
+ assert_eq!(tdigest.quantile(1.0), Some(1.0));
+}
+
+#[test]
+fn test_many_values() {
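+ // With 10k sequential updates, quartile ranks should be within 1e-4 of
+ // exact and upper quantiles within a few percent.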
+ let n = 10000;
+
+ let mut tdigest = TDigestMut::default();
+ for i in 0..n {
+ tdigest.update(i as f64);
+ }
+
+ assert!(!tdigest.is_empty());
+ assert_eq!(tdigest.total_weight(), n);
+ assert_eq!(tdigest.min_value(), Some(0.0));
+ assert_eq!(tdigest.max_value(), Some((n - 1) as f64));
+
+ assert_that!(tdigest.rank(0.0).unwrap(), near(0.0, 0.0001));
+ assert_that!(tdigest.rank((n / 4) as f64).unwrap(), near(0.25, 0.0001));
+ assert_that!(tdigest.rank((n / 2) as f64).unwrap(), near(0.5, 0.0001));
+ assert_that!(
+ tdigest.rank((n * 3 / 4) as f64).unwrap(),
+ near(0.75, 0.0001)
+ );
+ assert_that!(tdigest.rank(n as f64).unwrap(), eq(1.0));
+ assert_that!(tdigest.quantile(0.0).unwrap(), eq(0.0));
+ assert_that!(
+ tdigest.quantile(0.5).unwrap(),
+ near((n / 2) as f64, 0.03 * (n / 2) as f64)
+ );
+ assert_that!(
+ tdigest.quantile(0.9).unwrap(),
+ near((n as f64) * 0.9, 0.01 * (n as f64) * 0.9)
+ );
+ assert_that!(
+ tdigest.quantile(0.95).unwrap(),
+ near((n as f64) * 0.95, 0.01 * (n as f64) * 0.95)
+ );
+ assert_that!(tdigest.quantile(1.0).unwrap(), eq((n - 1) as f64));
+
+ let split_points = [n as f64 / 2.0];
+ let pmf = tdigest.pmf(&split_points).unwrap();
+ assert_eq!(pmf.len(), 2);
+ assert_that!(pmf[0], near(0.5, 0.0001));
+ assert_that!(pmf[1], near(0.5, 0.0001));
+ let cdf = tdigest.cdf(&split_points).unwrap();
+ assert_eq!(cdf.len(), 2);
+ assert_that!(cdf[0], near(0.5, 0.0001));
+ assert_that!(cdf[1], eq(1.0));
+}
+
+#[test]
+fn test_rank_two_values() {
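+ // Ranks interpolate linearly between the two midpoints 0.25 and 0.75.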
+ let mut tdigest = TDigestMut::new(100);
+ tdigest.update(1.0);
+ tdigest.update(2.0);
+ assert_eq!(tdigest.rank(0.99), Some(0.0));
+ assert_eq!(tdigest.rank(1.0), Some(0.25));
+ assert_eq!(tdigest.rank(1.25), Some(0.375));
+ assert_eq!(tdigest.rank(1.5), Some(0.5));
+ assert_eq!(tdigest.rank(1.75), Some(0.625));
+ assert_eq!(tdigest.rank(2.0), Some(0.75));
+ assert_eq!(tdigest.rank(2.01), Some(1.0));
+}
+
+#[test]
+fn test_rank_repeated_values() {
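+ // Four equal values: the rank at that value is still the midpoint, 0.5.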
+ let mut tdigest = TDigestMut::new(100);
+ tdigest.update(1.0);
+ tdigest.update(1.0);
+ tdigest.update(1.0);
+ tdigest.update(1.0);
+ assert_eq!(tdigest.rank(0.99), Some(0.0));
+ assert_eq!(tdigest.rank(1.0), Some(0.5));
+ assert_eq!(tdigest.rank(1.01), Some(1.0));
+}
+
+#[test]
+fn test_repeated_blocks() {
+ let mut tdigest = TDigestMut::new(100);
+ tdigest.update(1.0);
+ tdigest.update(2.0);
+ tdigest.update(2.0);
+ tdigest.update(3.0);
+ assert_eq!(tdigest.rank(0.99), Some(0.0));
+ assert_eq!(tdigest.rank(1.0), Some(0.125));
+ assert_eq!(tdigest.rank(2.0), Some(0.5));
+ assert_eq!(tdigest.rank(3.0), Some(0.875));
+ assert_eq!(tdigest.rank(3.01), Some(1.0));
+}
+
+#[test]
+fn test_merge_small() {
+ let mut td1 = TDigestMut::new(10);
+ td1.update(1.0);
+ td1.update(2.0);
+ let mut td2 = TDigestMut::new(10);
+ td2.update(2.0);
+ td2.update(3.0);
+ td1.merge(&td2);
+ assert_eq!(td1.min_value(), Some(1.0));
+ assert_eq!(td1.max_value(), Some(3.0));
+ assert_eq!(td1.total_weight(), 4);
+ assert_eq!(td1.rank(0.99), Some(0.0));
+ assert_eq!(td1.rank(1.0), Some(0.125));
+ assert_eq!(td1.rank(2.0), Some(0.5));
+ assert_eq!(td1.rank(3.0), Some(0.875));
+ assert_eq!(td1.rank(3.01), Some(1.0));
+}
+
+#[test]
+fn test_merge_large() {
+ let n = 10000;
+
+ let mut td1 = TDigestMut::new(10);
+ let mut td2 = TDigestMut::new(10);
+ let sup = n / 2;
+ for i in 0..sup {
+ td1.update(i as f64);
+ td2.update((sup + i) as f64);
+ }
+ td1.merge(&td2);
+
+ assert_eq!(td1.total_weight(), n);
+ assert_eq!(td1.min_value(), Some(0.0));
+ assert_eq!(td1.max_value(), Some((n - 1) as f64));
+
+ assert_that!(td1.rank(0.0).unwrap(), near(0.0, 0.0001));
+ assert_that!(td1.rank((n / 4) as f64).unwrap(), near(0.25, 0.0001));
+ assert_that!(td1.rank((n / 2) as f64).unwrap(), near(0.5, 0.0001));
+ assert_that!(td1.rank((n * 3 / 4) as f64).unwrap(), near(0.75, 0.0001));
+ assert_that!(td1.rank(n as f64).unwrap(), eq(1.0));
+}
+
+#[test]
+fn test_invalid_inputs() {
+ let n = 100;
+
+ let mut td = TDigestMut::new(10);
+ for _ in 0..n {
+ td.update(f64::NAN);
+ }
+ assert!(td.is_empty());
+
+ let mut td = TDigestMut::new(10);
+ for _ in 0..n {
+ td.update(f64::INFINITY);
+ }
+ assert!(td.is_empty());
+
+ let mut td = TDigestMut::new(10);
+ for _ in 0..n {
+ td.update(f64::NEG_INFINITY);
+ }
+ assert!(td.is_empty());
+
+ let mut td = TDigestMut::new(10);
+ for i in 0..n {
+ if i % 2 == 0 {
+ td.update(f64::INFINITY);
+ } else {
+ td.update(f64::NEG_INFINITY);
+ }
+ }
+ assert!(td.is_empty());
+}
diff --git a/tests/test_data/tdigest_ref_k100_n10000_double.sk b/tests/test_data/tdigest_ref_k100_n10000_double.sk
new file mode 100644
index 0000000..f6f4510
Binary files /dev/null and b/tests/test_data/tdigest_ref_k100_n10000_double.sk differ
diff --git a/tests/test_data/tdigest_ref_k100_n10000_float.sk b/tests/test_data/tdigest_ref_k100_n10000_float.sk
new file mode 100644
index 0000000..16d7981
Binary files /dev/null and b/tests/test_data/tdigest_ref_k100_n10000_float.sk differ
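
For reviewers who want to try the new module without reading the full diff, here is a minimal usage sketch assembled from the calls exercised in the tests above. Only method names that appear in the tests are used; the expected values in the comments are approximate, and the boolean flag to deserialize mirrors the tests rather than documenting its full meaning.

    use datasketches::tdigest::TDigestMut;

    fn main() {
        // Build a mutable sketch; k is the compression parameter.
        let mut td = TDigestMut::new(100);
        for i in 0..10_000 {
            td.update(i as f64); // NaN and infinite inputs are ignored
        }

        // Serialize the mutable sketch, then freeze it into an immutable view.
        let bytes = td.serialize();
        let td = td.freeze();

        assert_eq!(td.total_weight(), 10_000);
        println!("median ~ {:?}", td.quantile(0.5));    // close to 5000.0
        println!("rank(2500) ~ {:?}", td.rank(2500.0)); // close to 0.25

        // Restore from bytes and query the restored sketch.
        let restored = TDigestMut::deserialize(&bytes, false).unwrap().freeze();
        assert_eq!(restored.total_weight(), td.total_weight());
    }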
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]