tisonkun commented on code in PR #23: URL: https://github.com/apache/datasketches-rust/pull/23#discussion_r2625321627
########## src/tdigest/sketch.rs: ########## @@ -0,0 +1,673 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use byteorder::{ByteOrder, LE, ReadBytesExt}; +use std::cmp::Ordering; +use std::convert::identity; +use std::io::Cursor; + +use crate::error::SerdeError; +use crate::tdigest::serialization::*; + +const BUFFER_MULTIPLIER: usize = 4; + +/// T-Digest sketch for estimating quantiles and ranks. +/// +/// See the [module documentation](self) for more details. +#[derive(Debug, Clone, PartialEq)] +pub struct TDigest { + k: u16, + + reverse_merge: bool, + min: f64, + max: f64, + + centroids: Vec<Centroid>, + centroids_weight: u64, + centroids_capacity: usize, + buffer: Vec<f64>, +} + +impl Default for TDigest { + fn default() -> Self { + TDigest::new(Self::DEFAULT_K) + } +} + +impl TDigest { + /// The default value of K if one is not specified. + pub const DEFAULT_K: u16 = 200; + + /// Creates a tdigest instance with the given value of k. + /// + /// # Panics + /// + /// If k is less than 10 + pub fn new(k: u16) -> Self { + Self::make( + k, + false, + f64::INFINITY, + f64::NEG_INFINITY, + vec![], + 0, + vec![], + ) + } + + // for deserialization + fn make( + k: u16, + reverse_merge: bool, + min: f64, + max: f64, + mut centroids: Vec<Centroid>, + centroids_weight: u64, + mut buffer: Vec<f64>, + ) -> Self { + assert!(k >= 10, "k must be at least 10"); + + let fudge = if k < 30 { 30 } else { 10 }; + let centroids_capacity = (k as usize * 2) + fudge; + + centroids.reserve(centroids_capacity); + buffer.reserve(centroids_capacity * BUFFER_MULTIPLIER); + + TDigest { + k, + reverse_merge, + min, + max, + centroids, + centroids_weight, + centroids_capacity, + buffer, + } + } + + /// Update this TDigest with the given value (`NaN` values are ignored). + pub fn update(&mut self, value: f64) { + if value.is_nan() { + return; + } + + if self.buffer.len() == self.centroids_capacity * BUFFER_MULTIPLIER { + self.compress(); + } + + self.buffer.push(value); + self.min = self.min.min(value); + self.max = self.max.max(value); + } + + /// Returns parameter k (compression) that was used to configure this TDigest. + pub fn k(&self) -> u16 { + self.k + } + + /// Returns true if TDigest has not seen any data. + pub fn is_empty(&self) -> bool { + self.centroids.is_empty() && self.buffer.is_empty() + } + + /// Returns minimum value seen by TDigest; `None` if TDigest is empty. + pub fn min_value(&self) -> Option<f64> { + if self.is_empty() { + None + } else { + Some(self.min) + } + } + + /// Returns maximum value seen by TDigest; `None` if TDigest is empty. + pub fn max_value(&self) -> Option<f64> { + if self.is_empty() { + None + } else { + Some(self.max) + } + } + + /// Returns total weight. 
+ pub fn total_weight(&self) -> u64 { + self.centroids_weight + (self.buffer.len() as u64) + } + + /// Merge the given t-Digest into this one + pub fn merge(&mut self, other: &TDigest) { + if other.is_empty() { + return; + } + + let mut tmp = Vec::with_capacity( + self.centroids.len() + self.buffer.len() + other.centroids.len() + other.buffer.len(), + ); + for &v in &self.buffer { + tmp.push(Centroid { mean: v, weight: 1 }); + } + for &v in &other.buffer { + tmp.push(Centroid { mean: v, weight: 1 }); + } + for &c in &other.centroids { + tmp.push(c); + } + self.do_merge(tmp, self.buffer.len() as u64 + other.total_weight()) + } + + /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value. + /// + /// Returns `None` if TDigest is empty. + /// + /// # Panics + /// + /// If the value is `NaN`. + pub fn get_rank(&mut self, value: f64) -> Option<f64> { + assert!(!value.is_nan(), "value must not be NaN"); + + if self.is_empty() { + return None; + } + if value < self.min { + return Some(0.0); + } + if value > self.max { + return Some(1.0); + } + // one centroid and value == min == max + if self.centroids.len() + self.buffer.len() == 1 { + return Some(0.5); + } + + self.compress(); // side effect + let centroids_weight = self.centroids_weight as f64; + let num_centroids = self.centroids.len(); + + // left tail + let first_mean = self.centroids[0].mean; + if value < first_mean { + if first_mean - self.min > 0. { + return Some(if value == self.min { + 0.5 / centroids_weight + } else { + 1. + (((value - self.min) / (first_mean - self.min)) + * ((self.centroids[0].weight as f64 / 2.) - 1.)) + }); + } + return Some(0.); // should never happen + } + + // right tail + let last_mean = self.centroids[num_centroids - 1].mean; + if value > last_mean { + if self.max - last_mean > 0. { + return Some(if value == self.max { + 1. - (0.5 / centroids_weight) + } else { + 1.0 - ((1.0 + + (((self.max - value) / (self.max - last_mean)) + * ((self.centroids[num_centroids - 1].weight as f64 / 2.) - 1.))) + / centroids_weight) + }); + } + return Some(1.); // should never happen + } + + let mut lower = self + .centroids + .binary_search_by(|c| centroid_lower_bound(c, value)) + .unwrap_or_else(identity); + debug_assert_ne!(lower, num_centroids, "get_rank: lower == end"); + let mut upper = self + .centroids + .binary_search_by(|c| centroid_upper_bound(c, value)) + .unwrap_or_else(identity); + debug_assert_ne!(upper, 0, "get_rank: upper == begin"); + if value < self.centroids[lower].mean { + lower -= 1; + } + if (upper == num_centroids) || (self.centroids[upper - 1].mean >= value) { + upper -= 1; + } + + let mut weight_below = 0.; + let mut i = 0; + while i < lower { + weight_below += self.centroids[i].weight as f64; + i += 1; + } + weight_below += self.centroids[lower].weight as f64 / 2.; + + let mut weight_delta = 0.; + while i < upper { + weight_delta += self.centroids[i].weight as f64; + i += 1; + } + weight_delta -= self.centroids[lower].weight as f64 / 2.; + weight_delta += self.centroids[upper].weight as f64 / 2.; + Some( + if self.centroids[upper].mean - self.centroids[lower].mean > 0. { + (weight_below + + (weight_delta * (value - self.centroids[lower].mean) + / (self.centroids[upper].mean - self.centroids[lower].mean))) + / centroids_weight + } else { + (weight_below + weight_delta / 2.) / centroids_weight + }, + ) + } + + /// Compute approximate quantile value corresponding to the given normalized rank. + /// + /// Returns `None` if TDigest is empty. 
+ /// + /// # Panics + /// + /// If rank is not in [0.0, 1.0]. + pub fn get_quantile(&mut self, rank: f64) -> Option<f64> { + assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]"); + + if self.is_empty() { + return None; + } + + self.compress(); // side effect + if self.centroids.len() == 1 { + return Some(self.centroids[0].mean); + } + + // at least 2 centroids + let centroids_weight = self.centroids_weight as f64; + let num_centroids = self.centroids.len(); + let weight = rank * centroids_weight; + if weight < 1. { + return Some(self.min); + } + if weight > centroids_weight - 1. { + return Some(self.max); + } + let first_weight = self.centroids[0].weight as f64; + if first_weight > 1. && weight < first_weight / 2. { + return Some( + self.min + + (((weight - 1.) / ((first_weight / 2.) - 1.)) + * (self.centroids[0].mean - self.min)), + ); + } + let last_weight = self.centroids[num_centroids - 1].weight as f64; + if last_weight > 1. && (centroids_weight - weight <= last_weight / 2.) { + return Some( + self.max + + (((centroids_weight - weight - 1.) / ((last_weight / 2.) - 1.)) + * (self.max - self.centroids[num_centroids - 1].mean)), + ); + } + + // interpolate between extremes + let mut weight_so_far = first_weight / 2.; + for i in 0..(num_centroids - 1) { + let dw = (self.centroids[i].weight + self.centroids[i + 1].weight) as f64 / 2.; + if weight_so_far + dw > weight { + // the target weight is between centroids i and i+1 + let mut left_weight = 0.; + if self.centroids[i].weight == 1 { + if weight - weight_so_far < 0.5 { + return Some(self.centroids[i].mean); + } + left_weight = 0.5; + } + let mut right_weight = 0.; + if self.centroids[i + 1].weight == 1 { + if weight_so_far + dw - weight < 0.5 { + return Some(self.centroids[i + 1].mean); + } + right_weight = 0.5; + } + let w1 = weight - weight_so_far - left_weight; + let w2 = weight_so_far + dw - weight - right_weight; + return Some(weighted_average( + self.centroids[i].mean, + w1, + self.centroids[i + 1].mean, + w2, + )); + } + weight_so_far += dw; + } + + let w1 = weight + - (self.centroids_weight as f64) + - ((self.centroids[num_centroids - 1].weight as f64) / 2.); + let w2 = (self.centroids[num_centroids - 1].weight as f64 / 2.) - w1; + Some(weighted_average( + self.centroids[num_centroids - 1].mean, + w1, + self.max, + w2, + )) + } + + /// Serializes this TDigest to bytes. + pub fn serialize(&mut self) -> Vec<u8> { + self.compress(); + + let mut bytes = vec![]; + bytes.push(match self.total_weight() { + 0 => PREAMBLE_LONGS_EMPTY_OR_SINGLE, + 1 => PREAMBLE_LONGS_EMPTY_OR_SINGLE, + _ => PREAMBLE_LONGS_MULTIPLE, + }); + bytes.push(SERIAL_VERSION); + bytes.push(TDIGEST_FAMILY_ID); + LE::write_u16(&mut bytes, self.k); + bytes.push({ + let mut flags = 0; + if self.is_empty() { + flags |= FLAGS_IS_EMPTY; + } + if self.is_single_value() { + flags |= FLAGS_IS_SINGLE_VALUE; + } + if self.reverse_merge { + flags |= FLAGS_REVERSE_MERGE; + } + flags + }); + LE::write_u16(&mut bytes, 0); // unused + if self.is_empty() { + return bytes; + } + if self.is_single_value() { + LE::write_f64(&mut bytes, self.min); + return bytes; + } + LE::write_u32(&mut bytes, self.centroids.len() as u32); + LE::write_u32(&mut bytes, 0); // unused + LE::write_f64(&mut bytes, self.min); + LE::write_f64(&mut bytes, self.max); + for centroid in &self.centroids { + LE::write_f64(&mut bytes, centroid.mean); + LE::write_u64(&mut bytes, centroid.weight); + } + bytes + } + + /// Deserializes a TDigest from bytes. 
+ /// + /// Supports reading compact format with (float, int) centroids as opposed to (double, long) to + /// represent (mean, weight). [^1] + /// + /// [^1]: This is to support reading the `tdigest<float>` format from the C++ implementation. + pub fn deserialize(bytes: &[u8], is_float: bool) -> Result<Self, SerdeError> { + let make_error = |tag: &'static str| move |_| SerdeError::InsufficientData(tag.to_string()); + let mut cursor = Cursor::new(bytes); + + let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + if family_id != TDIGEST_FAMILY_ID { + // TODO: Support reading format of the reference implementation + return Err(SerdeError::InvalidFamily(format!( + "expected {} (TDigest), got {}", + TDIGEST_FAMILY_ID, family_id + ))); + } + if serial_version != SERIAL_VERSION { + return Err(SerdeError::UnsupportedVersion(format!( + "expected {}, got {}", + SERIAL_VERSION, serial_version + ))); + } + let k = cursor.read_u16::<LE>().map_err(make_error("k"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let is_empty = (flags & FLAGS_IS_EMPTY) != 0; + let is_single_value = (flags & FLAGS_IS_SINGLE_VALUE) != 0; + let expected_preamble_longs = if is_empty || is_single_value { + PREAMBLE_LONGS_EMPTY_OR_SINGLE + } else { + PREAMBLE_LONGS_MULTIPLE + }; + if preamble_longs != expected_preamble_longs { + return Err(SerdeError::MalformedData(format!( + "expected preamble_longs to be {}, got {}", + expected_preamble_longs, preamble_longs + ))); + } + cursor.read_u16::<LE>().map_err(make_error("<unused>"))?; // unused + if is_empty { + return Ok(TDigest::new(k)); + } + + let reverse_merge = (flags & FLAGS_REVERSE_MERGE) != 0; + if is_single_value { + let value = if is_float { + cursor + .read_f32::<LE>() + .map_err(make_error("single_value"))? as f64 + } else { + cursor + .read_f64::<LE>() + .map_err(make_error("single_value"))? + }; + return Ok(TDigest::make( + k, + reverse_merge, + value, + value, + vec![Centroid { + mean: value, + weight: 1, + }], + 1, + vec![], + )); + } + let num_centroids = cursor + .read_u32::<LE>() + .map_err(make_error("num_centroids"))? as usize; + let num_buffered = cursor + .read_u32::<LE>() + .map_err(make_error("num_buffered"))? as usize; + let (min, max) = if is_float { + ( + cursor.read_f32::<LE>().map_err(make_error("min"))? as f64, + cursor.read_f32::<LE>().map_err(make_error("max"))? as f64, + ) + } else { + ( + cursor.read_f64::<LE>().map_err(make_error("min"))?, + cursor.read_f64::<LE>().map_err(make_error("max"))?, + ) + }; + let mut centroids = Vec::with_capacity(num_centroids); + let mut centroids_weight = 0; + for _ in 0..num_centroids { + let (mean, weight) = if is_float { + ( + cursor.read_f32::<LE>().map_err(make_error("mean"))? as f64, + cursor.read_u32::<LE>().map_err(make_error("weight"))? as u64, + ) + } else { + ( + cursor.read_f64::<LE>().map_err(make_error("mean"))?, + cursor.read_u64::<LE>().map_err(make_error("weight"))?, + ) + }; + centroids_weight += weight; + centroids.push(Centroid { mean, weight }); + } + let mut buffer = Vec::with_capacity(num_buffered); + for _ in 0..num_buffered { + buffer.push(if is_float { + cursor + .read_f32::<LE>() + .map_err(make_error("buffered_value"))? as f64 + } else { + cursor + .read_f64::<LE>() + .map_err(make_error("buffered_value"))? 
+ }) + Ok(TDigest::make( + k, + reverse_merge, + min, + max, + centroids, + centroids_weight, + buffer, + )) + } + + fn is_single_value(&self) -> bool { + self.total_weight() == 1 + } + + /// Process buffered values and merge centroids if needed. + fn compress(&mut self) { + if self.buffer.is_empty() { + return; + } + let mut tmp = Vec::with_capacity(self.buffer.len() + self.centroids.len()); + for &v in &self.buffer { + tmp.push(Centroid { mean: v, weight: 1 }); + } + self.do_merge(tmp, self.buffer.len() as u64) + } + + /// Merges the given buffer of centroids into this TDigest. + /// + /// # Contract + /// + /// * `buffer` must have at least one centroid. + /// * `buffer` is generated from `self.buffer`, and thus: + /// * No `NAN` values are present in `buffer`. + /// * We should clear `self.buffer` after merging. + fn do_merge(&mut self, mut buffer: Vec<Centroid>, weight: u64) { + buffer.extend(std::mem::take(&mut self.centroids)); + buffer.sort_by(centroid_cmp); + if self.reverse_merge { + buffer.reverse(); + } + self.centroids_weight += weight; + + let mut num_centroids = 0; + let len = buffer.len(); + self.centroids.push(buffer[0]); + num_centroids += 1; + let mut current = 1; + let mut weight_so_far = 0.; + while current < len { + let c = buffer[current]; + let proposed_weight = (self.centroids[num_centroids - 1].weight + c.weight) as f64; + let mut add_this = false; + if (current != 1) && (current != (len - 1)) { + let centroids_weight = self.centroids_weight as f64; + let q0 = weight_so_far / centroids_weight; + let q2 = (weight_so_far + proposed_weight) / centroids_weight; + let normalizer = scale_function::normalizer((2 * self.k) as f64, centroids_weight); + add_this = proposed_weight + <= (centroids_weight + * scale_function::max(q0, normalizer) + .min(scale_function::max(q2, normalizer))); + } + if add_this { + // merge into existing centroid + self.centroids[num_centroids - 1].add(c); + } else { + // copy to a new centroid + weight_so_far += self.centroids[num_centroids - 1].weight as f64; + self.centroids.push(c); + num_centroids += 1; + } + current += 1; + } + + if self.reverse_merge { + self.centroids.reverse(); + } + self.min = self.min.min(self.centroids[0].mean); + self.max = self.max.max(self.centroids[num_centroids - 1].mean); + self.reverse_merge = !self.reverse_merge; + self.buffer.clear(); + } +} + +fn centroid_cmp(a: &Centroid, b: &Centroid) -> Ordering { + match a.mean.partial_cmp(&b.mean) { + Some(order) => order, + None => unreachable!("NaN values should never be present in centroids"),

Review Comment:
Seems inevitable. For example, using the API defined in this file:

```rust
let mut td = TDigest::new(10);
for i in 0..10000 {
    if i % 2 == 0 {
        td.update(f64::INFINITY);
    } else {
        td.update(f64::NEG_INFINITY);
    }
}
assert_eq!(td.get_quantile(0.5), Some(f64::NEG_INFINITY));
```

This would end up merging a `-inf` centroid with an `inf` centroid, and the weighted mean of `-inf` and `inf` is NaN.
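For illustration only, here is a minimal sketch (using a stand-in `Centroid` struct rather than the one in this PR) of how `f64::total_cmp`, which implements the IEEE 754 totalOrder predicate, yields a comparator that stays total even when a NaN mean slips in, so the sort itself cannot reach a panic. Whether ordering NaN is the right behavior for the sketch, as opposed to preventing NaN means from forming in the first place, is a separate design question:

```rust
use std::cmp::Ordering;

// Stand-in for the PR's Centroid type, for illustration only.
#[derive(Clone, Copy)]
struct Centroid {
    mean: f64,
    weight: u64,
}

// Total ordering over centroid means: `f64::total_cmp` (Rust 1.62+) follows
// IEEE 754 totalOrder, so a NaN mean is ordered rather than panicking.
fn centroid_cmp(a: &Centroid, b: &Centroid) -> Ordering {
    a.mean.total_cmp(&b.mean)
}

fn main() {
    let mut centroids = vec![
        Centroid { mean: f64::NAN, weight: 2 }, // e.g. the result of merging -inf and inf
        Centroid { mean: f64::INFINITY, weight: 1 },
        Centroid { mean: f64::NEG_INFINITY, weight: 1 },
        Centroid { mean: 1.0, weight: 1 },
    ];
    centroids.sort_by(centroid_cmp);

    // Sorted as -inf, 1.0, inf, with the (positive) NaN ordered last; no panic.
    assert_eq!(centroids[0].mean, f64::NEG_INFINITY);
    assert_eq!(centroids[1].mean, 1.0);
    assert_eq!(centroids[2].mean, f64::INFINITY);
    assert!(centroids[3].mean.is_nan());
    assert_eq!(centroids.iter().map(|c| c.weight).sum::<u64>(), 5);
}
```

Note that a total order only avoids the panic; a NaN mean would still propagate into rank and quantile estimates, so the underlying merge of opposite infinities may need separate handling.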
