tisonkun commented on code in PR #23: URL: https://github.com/apache/datasketches-rust/pull/23#discussion_r2625336137
########## src/tdigest/sketch.rs: ########## @@ -0,0 +1,1138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use byteorder::{BE, LE, ReadBytesExt}; +use std::cmp::Ordering; +use std::convert::identity; +use std::io::Cursor; + +use crate::error::SerdeError; +use crate::tdigest::serialization::*; + +/// The default value of K if one is not specified. +const DEFAULT_K: u16 = 200; +/// Multiplier for buffer size relative to centroids capacity. +const BUFFER_MULTIPLIER: usize = 4; + +/// T-Digest sketch for estimating quantiles and ranks. +/// +/// See the [module documentation](super) for more details. +#[derive(Debug, Clone)] +pub struct TDigestMut { + k: u16, + + reverse_merge: bool, + min: f64, + max: f64, + + centroids: Vec<Centroid>, + centroids_weight: u64, + centroids_capacity: usize, + buffer: Vec<f64>, +} + +impl Default for TDigestMut { + fn default() -> Self { + TDigestMut::new(DEFAULT_K) + } +} + +impl TDigestMut { + /// Creates a tdigest instance with the given value of k. 
+ /// + /// # Panics + /// + /// If k is less than 10 + pub fn new(k: u16) -> Self { + Self::make( + k, + false, + f64::INFINITY, + f64::NEG_INFINITY, + vec![], + 0, + vec![], + ) + } + + // for deserialization + fn make( + k: u16, + reverse_merge: bool, + min: f64, + max: f64, + mut centroids: Vec<Centroid>, + centroids_weight: u64, + mut buffer: Vec<f64>, + ) -> Self { + assert!(k >= 10, "k must be at least 10"); + + let fudge = if k < 30 { 30 } else { 10 }; + let centroids_capacity = (k as usize * 2) + fudge; + + centroids.reserve(centroids_capacity); + buffer.reserve(centroids_capacity * BUFFER_MULTIPLIER); + + TDigestMut { + k, + reverse_merge, + min, + max, + centroids, + centroids_weight, + centroids_capacity, + buffer, + } + } + + /// Update this TDigest with the given value (`NaN` values are ignored). + pub fn update(&mut self, value: f64) { + if value.is_nan() { + return; + } + + if self.buffer.len() == self.centroids_capacity * BUFFER_MULTIPLIER { + self.compress(); + } + + self.buffer.push(value); + self.min = self.min.min(value); + self.max = self.max.max(value); + } + + /// Returns parameter k (compression) that was used to configure this TDigest. + pub fn k(&self) -> u16 { + self.k + } + + /// Returns true if TDigest has not seen any data. + pub fn is_empty(&self) -> bool { + self.centroids.is_empty() && self.buffer.is_empty() + } + + /// Returns minimum value seen by TDigest; `None` if TDigest is empty. + pub fn min_value(&self) -> Option<f64> { + if self.is_empty() { + None + } else { + Some(self.min) + } + } + + /// Returns maximum value seen by TDigest; `None` if TDigest is empty. + pub fn max_value(&self) -> Option<f64> { + if self.is_empty() { + None + } else { + Some(self.max) + } + } + + /// Returns total weight. 
+ pub fn total_weight(&self) -> u64 { + self.centroids_weight + (self.buffer.len() as u64) + } + + /// Merge the given TDigest into this one + pub fn merge(&mut self, other: &TDigestMut) { + if other.is_empty() { + return; + } + + let mut tmp = Vec::with_capacity( + self.centroids.len() + self.buffer.len() + other.centroids.len() + other.buffer.len(), + ); + for &v in &self.buffer { + tmp.push(Centroid { mean: v, weight: 1 }); + } + for &v in &other.buffer { + tmp.push(Centroid { mean: v, weight: 1 }); + } + for &c in &other.centroids { + tmp.push(c); + } + self.do_merge(tmp, self.buffer.len() as u64 + other.total_weight()) + } + + /// Freezes this TDigest into an immutable one. + pub fn freeze(mut self) -> TDigest { + self.compress(); + TDigest { + k: self.k, + reverse_merge: self.reverse_merge, + min: self.min, + max: self.max, + centroids: self.centroids, + centroids_weight: self.centroids_weight, + } + } + + fn view(&mut self) -> TDigestView<'_> { + self.compress(); // side effect + TDigestView { + min: self.min, + max: self.max, + centroids: &self.centroids, + centroids_weight: self.centroids_weight, + } + } + + /// Returns an approximation to the Cumulative Distribution Function (CDF), which is the + /// cumulative analog of the PMF, of the input stream given a set of split points. + /// + /// # Arguments + /// + /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the + /// input domain into _m+1_ consecutive disjoint intervals. + /// + /// # Returns + /// + /// An array of m+1 doubles, which are a consecutive approximation to the CDF of the input + /// stream given the split points. The value at array position j of the returned CDF array + /// is the sum of the returned values in positions 0 through j of the returned PMF array. + /// This can be viewed as array of ranks of the given split points plus one more value that + /// is always 1. + /// + /// Returns `None` if TDigest is empty. 
+ /// + /// # Panics + /// + /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values. + pub fn cdf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> { + check_split_points(split_points); + + if self.is_empty() { + return None; + } + + self.view().cdf(split_points) + } + + /// Returns an approximation to the Probability Mass Function (PMF) of the input stream + /// given a set of split points. + /// + /// # Arguments + /// + /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the + /// input domain into _m+1_ consecutive disjoint intervals (bins). + /// + /// # Returns + /// + /// An array of m+1 doubles each of which is an approximation to the fraction of the input + /// stream values (the mass) that fall into one of those intervals. + /// + /// Returns `None` if TDigest is empty. + /// + /// # Panics + /// + /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values. + pub fn pmf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> { + check_split_points(split_points); + + if self.is_empty() { + return None; + } + + self.view().pmf(split_points) + } + + /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value. + /// + /// Returns `None` if TDigest is empty. + /// + /// # Panics + /// + /// If the value is `NaN`. + pub fn rank(&mut self, value: f64) -> Option<f64> { + assert!(!value.is_nan(), "value must not be NaN"); + + if self.is_empty() { + return None; + } + if value < self.min { + return Some(0.0); + } + if value > self.max { + return Some(1.0); + } + // one centroid and value == min == max + if self.centroids.len() + self.buffer.len() == 1 { + return Some(0.5); + } + + self.view().rank(value) + } + + /// Compute approximate quantile value corresponding to the given normalized rank. + /// + /// Returns `None` if TDigest is empty. + /// + /// # Panics + /// + /// If rank is not in [0.0, 1.0]. 
+ pub fn quantile(&mut self, rank: f64) -> Option<f64> { + assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]"); + + if self.is_empty() { + return None; + } + + self.view().quantile(rank) + } + + /// Serializes this TDigest to bytes. + pub fn serialize(&mut self) -> Vec<u8> { + self.compress(); + + let mut total_size = 0; + if self.is_empty() || self.is_single_value() { + // 1 byte preamble + // + 1 byte serial version + // + 1 byte family + // + 2 bytes k + // + 1 byte flags + // + 2 bytes unused + total_size += size_of::<u64>(); + } else { + // all of the above + // + 4 bytes num centroids + // + 4 bytes num buffered + total_size += size_of::<u64>() * 2; + } + if self.is_empty() { + // nothing more + } else if self.is_single_value() { + // + 8 bytes single value + total_size += size_of::<f64>(); + } else { + // + 8 bytes min + // + 8 bytes max + total_size += size_of::<f64>() * 2; + // + (8+8) bytes per centroid + total_size += self.centroids.len() * (size_of::<f64>() + size_of::<u64>()); + } + + let mut bytes = Vec::with_capacity(total_size); + bytes.push(match self.total_weight() { + 0 => PREAMBLE_LONGS_EMPTY_OR_SINGLE, + 1 => PREAMBLE_LONGS_EMPTY_OR_SINGLE, + _ => PREAMBLE_LONGS_MULTIPLE, + }); + bytes.push(SERIAL_VERSION); + bytes.push(TDIGEST_FAMILY_ID); + bytes.extend_from_slice(&self.k.to_le_bytes()); + bytes.push({ + let mut flags = 0; + if self.is_empty() { + flags |= FLAGS_IS_EMPTY; + } + if self.is_single_value() { + flags |= FLAGS_IS_SINGLE_VALUE; + } + if self.reverse_merge { + flags |= FLAGS_REVERSE_MERGE; + } + flags + }); + bytes.extend_from_slice(&0u16.to_le_bytes()); // unused + if self.is_empty() { + return bytes; + } + if self.is_single_value() { + bytes.extend_from_slice(&self.min.to_le_bytes()); + return bytes; + } + bytes.extend_from_slice(&(self.centroids.len() as u32).to_le_bytes()); + bytes.extend_from_slice(&0u32.to_le_bytes()); // unused + bytes.extend_from_slice(&self.min.to_le_bytes()); + 
bytes.extend_from_slice(&self.max.to_le_bytes()); + for centroid in &self.centroids { + bytes.extend_from_slice(¢roid.mean.to_le_bytes()); + bytes.extend_from_slice(¢roid.weight.to_le_bytes()); + } + bytes + } + + /// Deserializes a TDigest from bytes. + /// + /// Supports reading compact format with (float, int) centroids as opposed to (double, long) to + /// represent (mean, weight). [^1] + /// + /// Supports reading format of the reference implementation (auto-detected) [^2]. + /// + /// [^1]: This is to support reading the `tdigest<float>` format from the C++ implementation. + /// [^2]: <https://github.com/tdunning/t-digest> + pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, SerdeError> { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError { + move |_| SerdeError::InsufficientData(tag.to_string()) + } + + fn check_non_nan(value: f64, tag: &'static str) -> Result<(), SerdeError> { + if value.is_nan() { + return Err(SerdeError::MalformedData(format!("{tag} cannot be NaN"))); + } + Ok(()) + } + + fn check_nonzero(value: u64, tag: &'static str) -> Result<(), SerdeError> { + if value == 0 { + return Err(SerdeError::MalformedData(format!("{tag} cannot be zero"))); + } + Ok(()) + } + + let mut cursor = Cursor::new(bytes); + + let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + if family_id != TDIGEST_FAMILY_ID { + if preamble_longs == 0 && serial_version == 0 && family_id == 0 { + return Self::deserialize_compat(bytes); + } + return Err(SerdeError::InvalidFamily(format!( + "expected {} (TDigest), got {}", + TDIGEST_FAMILY_ID, family_id + ))); + } + if serial_version != SERIAL_VERSION { + return Err(SerdeError::UnsupportedVersion(format!( + "expected {}, got {}", + SERIAL_VERSION, serial_version + ))); + } + let k = 
+ cursor.read_u16::<LE>().map_err(make_error("k"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let is_empty = (flags & FLAGS_IS_EMPTY) != 0; + let is_single_value = (flags & FLAGS_IS_SINGLE_VALUE) != 0; + let expected_preamble_longs = if is_empty || is_single_value { + PREAMBLE_LONGS_EMPTY_OR_SINGLE + } else { + PREAMBLE_LONGS_MULTIPLE + }; + if preamble_longs != expected_preamble_longs { + return Err(SerdeError::MalformedData(format!( + "expected preamble_longs to be {}, got {}", + expected_preamble_longs, preamble_longs + ))); + } + cursor.read_u16::<LE>().map_err(make_error("<unused>"))?; // unused + if is_empty { + return Ok(TDigestMut::new(k)); + } + + let reverse_merge = (flags & FLAGS_REVERSE_MERGE) != 0; + if is_single_value { + let value = if is_f32 { + cursor + .read_f32::<LE>() + .map_err(make_error("single_value"))? as f64 + } else { + cursor + .read_f64::<LE>() + .map_err(make_error("single_value"))? + }; + check_non_nan(value, "single_value")?; + return Ok(TDigestMut::make( + k, + reverse_merge, + value, + value, + vec![Centroid { + mean: value, + weight: 1, + }], + 1, + vec![], + )); + } + let num_centroids = cursor + .read_u32::<LE>() + .map_err(make_error("num_centroids"))? as usize; + let num_buffered = cursor + .read_u32::<LE>() + .map_err(make_error("num_buffered"))? as usize; + let (min, max) = if is_f32 { + ( + cursor.read_f32::<LE>().map_err(make_error("min"))? as f64, + cursor.read_f32::<LE>().map_err(make_error("max"))? as f64, + ) + } else { + ( + cursor.read_f64::<LE>().map_err(make_error("min"))?, + cursor.read_f64::<LE>().map_err(make_error("max"))?, + ) + }; + check_non_nan(min, "min")?; + check_non_nan(max, "max")?; Review Comment: min & max should also satisfy other criteria, such as min <= max, and should be consistent with the centroids and the buffer. But I'd leave that for later, since the non_nan check is already complex enough here. I don't want to make this PR too hard to review. -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
